Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Classes and functions used to construct graphs."""
     16 # pylint: disable=g-bad-name
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import copy
     23 import re
     24 import sys
     25 import threading
     26 
     27 import numpy as np
     28 import six
     29 from six.moves import xrange  # pylint: disable=redefined-builtin
     30 
     31 from tensorflow.core.framework import attr_value_pb2
     32 from tensorflow.core.framework import function_pb2
     33 from tensorflow.core.framework import graph_pb2
     34 from tensorflow.core.framework import node_def_pb2
     35 from tensorflow.core.framework import op_def_pb2
     36 from tensorflow.core.framework import versions_pb2
     37 from tensorflow.core.protobuf import config_pb2
     38 from tensorflow.python import pywrap_tensorflow as c_api
     39 from tensorflow.python import tf2
     40 from tensorflow.python.eager import context
     41 from tensorflow.python.eager import core
     42 from tensorflow.python.eager import tape
     43 from tensorflow.python.framework import c_api_util
     44 from tensorflow.python.framework import composite_tensor
     45 from tensorflow.python.framework import device as pydev
     46 from tensorflow.python.framework import dtypes
     47 from tensorflow.python.framework import errors
     48 from tensorflow.python.framework import op_def_registry
     49 from tensorflow.python.framework import registry
     50 from tensorflow.python.framework import tensor_shape
     51 from tensorflow.python.framework import traceable_stack
     52 from tensorflow.python.framework import versions
     53 from tensorflow.python.ops import control_flow_util
     54 from tensorflow.python.platform import app
     55 from tensorflow.python.platform import tf_logging as logging
     56 from tensorflow.python.util import compat
     57 from tensorflow.python.util import decorator_utils
     58 from tensorflow.python.util import deprecation
     59 from tensorflow.python.util import function_utils
     60 from tensorflow.python.util import lock_util
     61 from tensorflow.python.util import memory
     62 from tensorflow.python.util import tf_contextlib
     63 from tensorflow.python.util import tf_stack
     64 from tensorflow.python.util.deprecation import deprecated_args
     65 from tensorflow.python.util.tf_export import tf_export
     66 
     67 
     68 # Temporary global switches determining if we should enable the work-in-progress
     69 # calls to the C API. These will be removed once all functionality is supported.
     70 _USE_C_API = True
     71 _USE_C_SHAPES = True
     72 
     73 
     74 def tensor_id(tensor):
     75   """Returns a unique identifier for this Tensor."""
     76   return tensor._id  # pylint: disable=protected-access
     77 
     78 
     79 class _UserDeviceSpec(object):
     80   """Store user-specified device and provide computation of merged device."""
     81 
     82   def __init__(self, device_name_or_function):
     83     self._device_name_or_function = device_name_or_function
     84 
     85     self.display_name = str(self._device_name_or_function)
     86     if callable(self._device_name_or_function):
     87       dev_func = self._device_name_or_function
     88       func_name = function_utils.get_func_name(dev_func)
     89       func_code = function_utils.get_func_code(dev_func)
     90       if func_code:
     91         fname = func_code.co_filename
     92         lineno = func_code.co_firstlineno
     93       else:
     94         fname = "unknown"
     95         lineno = -1
     96       self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
     97 
     98     self.raw_string = None
     99 
    100     self.function = self._device_name_or_function
    101     if not (self._device_name_or_function is None or
    102             callable(self._device_name_or_function)):
    103       self.raw_string = self._device_name_or_function
    104       self.function = pydev.merge_device(self._device_name_or_function)
    105 
    106 
    107 class NullContextmanager(object):
    108 
    109   def __init__(self, *args, **kwargs):
    110     pass
    111 
    112   def __enter__(self):
    113     pass
    114 
    115   def __exit__(self, type_arg, value_arg, traceback_arg):
    116     return False  # False values do not suppress exceptions
    117 
    118 
    119 def _override_helper(clazz_object, operator, func):
    120   """Overrides (string) operator on Tensors to call func.
    121 
    122   Args:
    123     clazz_object: the class to override for; either Tensor or SparseTensor.
    124     operator: the string name of the operator to override.
    125     func: the function that replaces the overridden operator.
    126 
    127   Raises:
    128     ValueError: If operator has already been overwritten,
    129       or if operator is not allowed to be overwritten.
    130   """
    131   existing = getattr(clazz_object, operator, None)
    132   if existing is not None:
    133     # Check to see if this is a default method-wrapper or slot wrapper which
    134     # will be true for the comparison operators.
    135     if not isinstance(existing, type(object.__lt__)):
    136       raise ValueError("operator %s cannot be overwritten again on class %s." %
    137                        (operator, clazz_object))
    138   if operator not in Tensor.OVERLOADABLE_OPERATORS:
    139     raise ValueError("Overriding %s is disallowed" % operator)
    140   setattr(clazz_object, operator, func)
    141 
    142 
    143 def _as_graph_element(obj):
    144   """Convert `obj` to a graph element if possible, otherwise return `None`.
    145 
    146   Args:
    147     obj: Object to convert.
    148 
    149   Returns:
    150     The result of `obj._as_graph_element()` if that method is available;
    151         otherwise `None`.
    152   """
    153   conv_fn = getattr(obj, "_as_graph_element", None)
    154   if conv_fn and callable(conv_fn):
    155     return conv_fn()
    156   return None
    157 
    158 
    159 _TENSOR_LIKE_TYPES = tuple()
    160 
    161 
    162 def is_dense_tensor_like(t):
    163   """EXPERIMENTAL: Returns true if `t` implements the tensor interface.
    164 
    165   See `register_dense_tensor_like_type()` for the current definition of a
    166   "tensor-like type".
    167 
    168   Args:
    169     t: An object.
    170 
    171   Returns:
    172     True iff `t` is an instance of one of the registered "tensor-like" types.
    173   """
    174   return isinstance(t, _TENSOR_LIKE_TYPES)
    175 
    176 
    177 def register_dense_tensor_like_type(tensor_type):
    178   """EXPERIMENTAL: Registers `tensor_type` as implementing the tensor interface.
    179 
    180   A "tensor-like type" can represent a single dense tensor, and implements
    181   the `name` and `dtype` properties.
    182 
    183   Args:
    184     tensor_type: A type implementing the tensor interface.
    185 
    186   Raises:
    187     TypeError: If `tensor_type` does not implement the tensor interface.
    188   """
    189   try:
    190     if not isinstance(tensor_type.name, property):
    191       raise TypeError("Type %s does not define a `name` property" %
    192                       tensor_type.__name__)
    193   except AttributeError:
    194     raise TypeError("Type %s does not define a `name` property" %
    195                     tensor_type.__name__)
    196   try:
    197     if not isinstance(tensor_type.dtype, property):
    198       raise TypeError("Type %s does not define a `dtype` property" %
    199                       tensor_type.__name__)
    200   except AttributeError:
    201     raise TypeError("Type %s does not define a `dtype` property" %
    202                     tensor_type.__name__)
    203   # We expect this list to be small, so choose quadratic complexity
    204   # for registration, so that we have a tuple that can be used for
    205   # more efficient `isinstance` checks later.
    206   global _TENSOR_LIKE_TYPES
    207   _TENSOR_LIKE_TYPES = tuple(list(_TENSOR_LIKE_TYPES) + [tensor_type])
    208 
    209 
    210 def uid():
    211   """A unique (within this program execution) integer."""
    212   return c_api.TFE_Py_UID()
    213 
    214 
    215 def numpy_text(tensor, is_repr=False):
    216   """Human readable representation of a tensor's numpy value."""
    217   if tensor.dtype.is_numpy_compatible:
    218     text = repr(tensor.numpy()) if is_repr else str(tensor.numpy())
    219   else:
    220     text = "<unprintable>"
    221   if "\n" in text:
    222     text = "\n" + text
    223   return text
    224 
    225 
    226 # NOTE(ebrevdo): Do not subclass this.  If you do, I will break you on purpose.
    227 class _TensorLike(object):
    228   """Internal cls for grouping Tensor, SparseTensor, ..., for is_instance."""
    229   pass
    230 
    231 
@tf_export("Tensor")
class Tensor(_TensorLike):
  """Represents one of the outputs of an `Operation`.

  A `Tensor` is a symbolic handle to one of the outputs of an
  `Operation`. It does not hold the values of that operation's output,
  but instead provides a means of computing those values in a
  TensorFlow `tf.Session`.

  This class has two primary purposes:

  1. A `Tensor` can be passed as an input to another `Operation`.
     This builds a dataflow connection between operations, which
     enables TensorFlow to execute an entire `Graph` that represents a
     large, multi-step computation.

  2. After the graph has been launched in a session, the value of the
     `Tensor` can be computed by passing it to
     `tf.Session.run`.
     `t.eval()` is a shortcut for calling
     `tf.get_default_session().run(t)`.

  In the following example, `c`, `d`, and `e` are symbolic `Tensor`
  objects, whereas `result` is a numpy array that stores a concrete
  value:

  ```python
  # Build a dataflow graph.
  c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
  d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
  e = tf.matmul(c, d)

  # Construct a `Session` to execute the graph.
  sess = tf.Session()

  # Execute the graph and store the value that `e` represents in `result`.
  result = sess.run(e)
  ```
  """

  # List of Python operators that we allow to override.
  # `_override_helper` rejects any operator name that is not in this set.
  OVERLOADABLE_OPERATORS = {
      # Binary.
      "__add__",
      "__radd__",
      "__sub__",
      "__rsub__",
      "__mul__",
      "__rmul__",
      "__div__",
      "__rdiv__",
      "__truediv__",
      "__rtruediv__",
      "__floordiv__",
      "__rfloordiv__",
      "__mod__",
      "__rmod__",
      "__lt__",
      "__le__",
      "__gt__",
      "__ge__",
      "__and__",
      "__rand__",
      "__or__",
      "__ror__",
      "__xor__",
      "__rxor__",
      "__getitem__",
      "__pow__",
      "__rpow__",
      # Unary.
      "__invert__",
      "__neg__",
      "__abs__",
      # Binary (matrix multiplication).
      "__matmul__",
      "__rmatmul__"
  }

  def __init__(self, op, value_index, dtype):
    """Creates a new `Tensor`.

    Args:
      op: An `Operation`. `Operation` that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that produces
        this tensor.
      dtype: A `DType`. Type of elements stored in this tensor.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if not isinstance(op, Operation):
      raise TypeError("op needs to be an Operation: %s" % op)
    self._op = op
    self._value_index = value_index
    self._dtype = dtypes.as_dtype(dtype)
    # This will be set by self._as_tf_output().
    self._tf_output = None
    # This will be set by self.shape().
    self._shape_val = None
    # List of operations that use this Tensor as input.  We maintain this list
    # to easily navigate a computation graph.
    self._consumers = []
    # Identifier returned by `tensor_id()`; drawn from `uid()`.
    self._id = uid()

  @property
  def op(self):
    """The `Operation` that produces this tensor as an output."""
    return self._op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._dtype

  @property
  def graph(self):
    """The `Graph` that contains this tensor."""
    return self._op.graph

  @property
  def name(self):
    """The string name of this tensor."""
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    return "%s:%d" % (self._op.name, self._value_index)

  @property
  def device(self):
    """The name of the device on which this tensor will be produced, or None."""
    return self._op.device

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of this tensor.

    The shape is computed using shape inference functions that are
    registered in the Op for each `Operation`.  See
    `tf.TensorShape`
    for more details of what a shape represents.

    The inferred shape of a tensor is used to provide shape
    information without having to launch the graph in a session. This
    can be used for debugging, and providing early error messages. For
    example:

    ```python
    c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    print(c.shape)
    ==> TensorShape([Dimension(2), Dimension(3)])

    d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])

    print(d.shape)
    ==> TensorShape([Dimension(4), Dimension(2)])

    # Raises a ValueError, because `c` and `d` do not have compatible
    # inner dimensions.
    e = tf.matmul(c, d)

    f = tf.matmul(c, d, transpose_a=True, transpose_b=True)

    print(f.shape)
    ==> TensorShape([Dimension(3), Dimension(4)])
    ```

    In some cases, the inferred shape may have unknown dimensions. If
    the caller has additional information about the values of these
    dimensions, `Tensor.set_shape()` can be used to augment the
    inferred shape.

    Returns:
      A `TensorShape` representing the shape of this tensor.

    """
    # Read lazily from the C graph and cached; `set_shape()` clears the cache.
    if self._shape_val is None:
      self._shape_val = self._c_api_shape()
    return self._shape_val

  def _get_input_ops_without_shapes(self, target_op):
    """Returns ops needing shape inference to compute target_op's shape."""
    # NOTE(review): `target_op` is never read; the walk always starts from
    # `self._op`. Confirm that callers pass this tensor's own op.
    # Iterative depth-first walk over producer ops. The start op is always
    # included; the walk descends only into inputs whose shape is not cached.
    result = []
    stack = [self._op]
    visited = set()
    while stack:
      op = stack.pop()
      if op in visited: continue
      result.append(op)
      stack.extend(t.op for t in op.inputs if t._shape_val is None)
      visited.add(op)
    return result

  def _c_api_shape(self):
    """Returns the TensorShape of this tensor according to the C API."""
    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
        c_graph, self._as_tf_output())
    if unknown_shape:
      return tensor_shape.unknown_shape()
    else:
      # The C API reports an unknown dimension as -1; TensorShape uses None.
      shape_vector = [None if d == -1 else d for d in shape_vector]
      return tensor_shape.TensorShape(shape_vector)

  # Private alias of `shape`, kept for backwards compatibility. Every read logs
  # a warning; assignment always raises.
  @property
  def _shape(self):
    logging.warning("Tensor._shape is private, use Tensor.shape "
                    "instead. Tensor._shape will eventually be removed.")
    return self.shape

  @_shape.setter
  def _shape(self, value):
    raise ValueError(
        "Tensor._shape cannot be assigned, use Tensor.set_shape instead.")

  def __iter__(self):
    # This is a generator function, so the checks below run on the first
    # `next()` call, not when `iter()` is called.
    if not context.executing_eagerly():
      raise TypeError(
          "Tensor objects are only iterable when eager execution is "
          "enabled. To iterate over this tensor use tf.map_fn.")
    shape = self._shape_tuple()
    if shape is None:
      raise TypeError("Cannot iterate over a tensor with unknown shape.")
    if not shape:
      raise TypeError("Cannot iterate over a scalar tensor.")
    if shape[0] is None:
      raise TypeError(
          "Cannot iterate over a tensor with unknown first dimension.")
    for i in xrange(shape[0]):
      yield self[i]

  def _shape_as_list(self):
    # Returns a list of ints/None per dimension, or None if the rank is unknown.
    if self.shape.ndims is not None:
      return [dim.value for dim in self.shape.dims]
    else:
      return None

  def _shape_tuple(self):
    # Tuple form of `_shape_as_list()`; None if the rank is unknown.
    shape = self._shape_as_list()
    if shape is None:
      return None
    return tuple(shape)

  def _rank(self):
    """Integer rank of this Tensor, if known, else None.

    Returns:
      Integer rank or None
    """
    return self.shape.ndims

  def get_shape(self):
    """Alias of Tensor.shape."""
    return self.shape

  def set_shape(self, shape):
    """Updates the shape of this tensor.

    This method can be called multiple times, and will merge the given
    `shape` with the current shape of this tensor. It can be used to
    provide additional information about the shape of this tensor that
    cannot be inferred from the graph alone. For example, this can be used
    to provide additional information about the shapes of images:

    ```python
    _, image_data = tf.TFRecordReader(...).read(...)
    image = tf.image.decode_png(image_data, channels=3)

    # The height and width dimensions of `image` are data dependent, and
    # cannot be computed without executing the op.
    print(image.shape)
    ==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])

    # We know that each image in this dataset is 28 x 28 pixels.
    image.set_shape([28, 28, 3])
    print(image.shape)
    ==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
    ```

    NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
    result in inconsistencies between the statically-known graph and the runtime
    value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
    instead.

    Args:
      shape: A `TensorShape` representing the shape of this tensor, a
      `TensorShapeProto`, a list, a tuple, or None.

    Raises:
      ValueError: If `shape` is not compatible with the current shape of
        this tensor.
    """
    # Reset cached shape so the next `shape` read re-queries the C graph.
    # (This happens even if the C API call below rejects the new shape.)
    self._shape_val = None

    # We want set_shape to be reflected in the C API graph for when we run it.
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    dim_list = []
    if shape.dims is None:
      # Unknown rank: the C API is told the whole shape is unknown.
      unknown_shape = True
    else:
      unknown_shape = False
      for dim in shape.dims:
        # The C API encodes an unknown dimension as -1.
        if dim.value is None:
          dim_list.append(-1)
        else:
          dim_list.append(dim.value)
    try:
      c_api.TF_GraphSetTensorShape_wrapper(
          self._op._graph._c_graph,  # pylint: disable=protected-access
          self._as_tf_output(),
          dim_list,
          unknown_shape)
    except errors.InvalidArgumentError as e:
      # Convert to ValueError for backwards compatibility.
      raise ValueError(str(e))

  @property
  def value_index(self):
    """The index of this tensor in the outputs of its `Operation`."""
    return self._value_index

  def consumers(self):
    """Returns a list of `Operation`s that consume this tensor.

    Returns:
      A list of `Operation`s.
    """
    # The C API returns consumer op names; map them back to Python ops.
    consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
        self._as_tf_output())
    # pylint: disable=protected-access
    return [
        self.graph._get_operation_by_name_unsafe(name)
        for name in consumer_names
    ]
    # pylint: enable=protected-access

  def _as_node_def_input(self):
    """Return a value to use for the NodeDef "input" attribute.

    The returned string can be used in a NodeDef "input" attribute
    to indicate that the NodeDef uses this Tensor as input.

    Raises:
      ValueError: if this Tensor's Operation does not have a name.

    Returns:
      a string.
    """
    if not self._op.name:
      raise ValueError("Operation was not named: %s" % self._op)
    # Output 0 is referenced by the bare op name; other outputs as "name:i".
    if self._value_index == 0:
      return self._op.name
    else:
      return "%s:%d" % (self._op.name, self._value_index)

  def _as_tf_output(self):
    # pylint: disable=protected-access
    # NOTE: Beyond preventing unnecessary (re-)allocation, the cached object
    # also guarantees that a dictionary of tf_output objects will retain a
    # deterministic (yet unsorted) order which prevents memory blowup in the
    # cache of executor(s) stored for every session.
    if self._tf_output is None:
      self._tf_output = c_api_util.tf_output(self.op._c_op, self.value_index)
    return self._tf_output
    # pylint: enable=protected-access

  def __str__(self):
    # Shape, dtype and device are each included only when known/set.
    return "Tensor(\"%s\"%s%s%s)" % (
        self.name, (", shape=%s" % self.get_shape())
        if self.get_shape().ndims is not None else "",
        (", dtype=%s" % self._dtype.name)
        if self._dtype else "", (", device=%s" % self.device)
        if self.device else "")

  def __repr__(self):
    return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(),
                                                   self._dtype.name)

  def __hash__(self):
    # Necessary to support Python's collection membership operators
    return id(self)

  def __eq__(self, other):
    # Necessary to support Python's collection membership operators
    # Identity-based: two distinct Tensor objects never compare equal.
    return id(self) == id(other)

  def __copy__(self):
    # TODO(b/77597810): get rid of Tensor copies.
    # Shallow copy: the new object shares every attribute value with the
    # original, including `_id`, so `tensor_id()` is the same for both.
    cls = self.__class__
    result = cls.__new__(cls)
    result.__dict__.update(self.__dict__)
    return result

  # NOTE(mrry): This enables the Tensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # TODO(mrry): Convert this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  @staticmethod
  def _override_operator(operator, func):
    # Installs `func` for `operator` on Tensor via `_override_helper`.
    _override_helper(Tensor, operator, func)

  def __bool__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This overload raises a `TypeError` when the user inadvertently
    treats a `Tensor` as a boolean (e.g. in an `if` statement). For
    example:

    ```python
    if tf.constant(True):  # Will raise.
      # ...

    if tf.constant(5) < tf.constant(7):  # Will raise.
      # ...
    ```

    This disallows ambiguities between testing the Python value vs testing the
    dynamic condition of the `Tensor`.

    Raises:
      `TypeError`.
    """
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
                    "Use `if t is not None:` instead of `if t:` to test if a "
                    "tensor is defined, and use TensorFlow ops such as "
                    "tf.cond to execute subgraphs conditioned on the value of "
                    "a tensor.")

  def __nonzero__(self):
    """Dummy method to prevent a tensor from being used as a Python `bool`.

    This is the Python 2.x counterpart to `__bool__()` above.

    Raises:
      `TypeError`.
    """
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
                    "Use `if t is not None:` instead of `if t:` to test if a "
                    "tensor is defined, and use TensorFlow ops such as "
                    "tf.cond to execute subgraphs conditioned on the value of "
                    "a tensor.")

  def eval(self, feed_dict=None, session=None):
    """Evaluates this tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `Tensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values.
        See `tf.Session.run` for a
        description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this tensor. If
        none, the default session will be used.

    Returns:
      A numpy array corresponding to the value of this tensor.

    """
    return _eval_using_default_session(self, feed_dict, self.graph, session)
    703 
    704 
    705 # TODO(agarwal): consider getting rid of this.
    706 class _EagerTensorBase(Tensor):
    707   """Base class for EagerTensor."""
    708 
    709   @property
    710   def dtype(self):
    711     # Note: using the intern table directly here as this is
    712     # performance-sensitive in some models.
    713     return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
    714 
    715   def numpy(self):
    716     """Returns a numpy array or a scalar with the same contents as the Tensor.
    717 
    718     TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
    719     buffer but instead always explicitly copy? Note that currently it may or may
    720     not copy based on whether the numpy data is properly aligned or not.
    721 
    722     Returns:
    723       A numpy array or a scalar. Numpy array may share memory with the
    724       Tensor object. Any changes to one may be reflected in the other. A scalar
    725       value is returned when self has rank 0.
    726 
    727     Raises:
    728       ValueError: if the type of this Tensor is not representable in numpy.
    729     """
    730     if self.dtype == dtypes.resource:
    731       raise ValueError("Resource handles are not convertible to numpy.")
    732     return self._cpu_nograd()._numpy()  # pylint: disable=protected-access
    733 
    734   # __int__, __float__ and __index__ may copy the tensor to CPU and
    735   # only work for scalars; values are cast as per numpy.
    736   def __int__(self):
    737     return int(self.numpy())
    738 
    739   def __float__(self):
    740     return float(self.numpy())
    741 
    742   def __index__(self):
    743     return int(self.numpy())
    744 
    745   def __array__(self, dtype=None):
    746     return np.array(self.numpy(), dtype=dtype)
    747 
    748   def __format__(self, format_spec):
    749     return self.numpy().__format__(format_spec)
    750 
    751   def __reduce__(self):
    752     return (convert_to_tensor, (self.numpy(),))
    753 
    754   def _numpy(self):
    755     raise NotImplementedError()
    756 
    757   @property
    758   def backing_device(self):
    759     """Returns the name of the device holding this tensor's memory.
    760 
    761     `.backing_device` is usually the same as `.device`, which returns
    762     the device on which the kernel of the operation that produced this tensor
    763     ran. However, some operations can produce tensors on a different device
    764     (e.g., an operation that executes on the GPU but produces output tensors
    765     in host memory).
    766     """
    767     raise NotImplementedError()
    768 
    769   def __copy__(self):
    770     # Eager Tensors are immutable so it's safe to return themselves as a copy.
    771     return self
    772 
    773   def __deepcopy__(self, memo):
    774     # Eager Tensors are immutable so it's safe to return themselves as a copy.
    775     del memo
    776     return self
    777 
    778   def _datatype_enum(self):
    779     raise NotImplementedError()
    780 
    781   def _shape_tuple(self):
    782     """The shape of this Tensor, as a tuple.
    783 
    784     This is more performant than tuple(shape().as_list()) as it avoids
    785     two list and one object creation. Marked private for now as from an API
    786     perspective, it would be better to have a single performant way of
    787     getting a shape rather than exposing shape() and shape_tuple()
    788     (and heaven forbid, shape_list() etc. as well!). Punting on that for now,
    789     but ideally one would work things out and remove the need for this method.
    790 
    791     Returns:
    792       tuple with the shape.
    793     """
    794     raise NotImplementedError()
    795 
    796   def _rank(self):
    797     """Integer rank of this Tensor.
    798 
    799     Unlike regular Tensors, the rank is always known for EagerTensors.
    800 
    801     This is more performant than len(self._shape_tuple())
    802 
    803     Returns:
    804       Integer rank
    805     """
    806     raise NotImplementedError()
    807 
    808   def _num_elements(self):
    809     """Number of elements of this Tensor.
    810 
    811     Unlike regular Tensors, the number of elements is always known for
    812     EagerTensors.
    813 
    814     This is more performant than tensor.shape.num_elements
    815 
    816     Returns:
    817       Long - num elements in the tensor
    818     """
    819     raise NotImplementedError()
    820 
    821   def _copy_to_device(self, context, device):  # pylint: disable=redefined-outer-name
    822     raise NotImplementedError()
    823 
    824   def __str__(self):
    825     return "tf.Tensor(%s, shape=%s, dtype=%s)" % (numpy_text(self),
    826                                                   self.shape,
    827                                                   self.dtype.name)
    828 
    829   def __repr__(self):
    830     return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
    831         self._id, self.shape, self.dtype.name, numpy_text(self, is_repr=True))
    832 
    833   @staticmethod
    834   def _override_operator(name, func):
    835     setattr(_EagerTensorBase, name, func)
    836 
    837   def _copy_nograd(self, ctx=None, device_name=None):
    838     """Copies tensor to dest device, but doesn't record the operation."""
    839     # pylint: disable=protected-access
    840     # Creates a new tensor on the dest device.
    841     if ctx is None:
    842       ctx = context.context()
    843     if device_name is None:
    844       device_name = ctx.device_name
    845     # pylint: disable=protected-access
    846     try:
    847       new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
    848     except core._NotOkStatusException as e:
    849       six.raise_from(core._status_to_exception(e.code, e.message), None)
    850     return new_tensor
    851 
    852   def _copy(self, ctx=None, device_name=None):
    853     """Copies tensor to dest device."""
    854     new_tensor = self._copy_nograd(ctx, device_name)
    855     # Record the copy on tape and define backprop copy as well.
    856     if context.executing_eagerly():
    857       self_device = self.device
    858       def grad_fun(dresult):
    859         return [dresult._copy(device_name=self_device)]
    860       tape.record_operation("_copy", [new_tensor], [self], grad_fun)
    861     return new_tensor
    862     # pylint: enable=protected-access
    863 
    864   @property
    865   def shape(self):
    866     if self._tensor_shape is None:  # pylint: disable=access-member-before-definition
    867       # `_tensor_shape` is declared and defined in the definition of
    868       # `EagerTensor`, in C.
    869       self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
    870     return self._tensor_shape
    871 
    872   def get_shape(self):
    873     """Alias of Tensor.shape."""
    874     return self.shape
    875 
    876   def _shape_as_list(self):
    877     """The shape of the tensor as a list."""
    878     return list(self._shape_tuple())
    879 
    880   @property
    881   def ndim(self):
    882     """Returns the number of Tensor dimensions."""
    883     return self.shape.ndims
    884 
    885   def __len__(self):
    886     """Returns the length of the first dimension in the Tensor."""
    887     if not self.shape.ndims:
    888       raise TypeError("Scalar tensor has no `len()`")
    889     return self._shape_tuple()[0]
    890 
    891   def _cpu_nograd(self):
    892     """A copy of this Tensor with contents backed by host memory.
    893 
    894     The copy cannot be differentiated through.
    895 
    896     Returns:
    897       A CPU-memory backed Tensor object with the same contents as this Tensor.
    898     """
    899     return self._copy_nograd(context.context(), "CPU:0")
    900 
    901   def cpu(self):
    902     """A copy of this Tensor with contents backed by host memory."""
    903     return self._copy(context.context(), "CPU:0")
    904 
    905   def gpu(self, gpu_index=0):
    906     """A copy of this Tensor with contents backed by memory on the GPU.
    907 
    908     Arguments:
    909       gpu_index: Identifies which GPU to place the contents on the returned
    910         Tensor in.
    911 
    912     Returns:
    913       A GPU-memory backed Tensor object initialized with the same contents
    914       as this Tensor.
    915     """
    916     return self._copy(context.context(), "GPU:" + str(gpu_index))
    917 
    918   def __bool__(self):
    919     return bool(self.numpy())
    920 
    921   def __nonzero__(self):
    922     return self.__bool__()
    923 
    924   def set_shape(self, shape):
    925     if not self.shape.is_compatible_with(shape):
    926       raise ValueError(
    927           "Tensor's shape %s is not compatible with supplied shape %s" %
    928           (self.shape, shape))
    929 
    930   # Methods not supported / implemented for Eager Tensors.
    931   @property
    932   def op(self):
    933     raise AttributeError(
    934         "Tensor.op is meaningless when eager execution is enabled.")
    935 
    936   @property
    937   def graph(self):
    938     raise AttributeError(
    939         "Tensor.graph is meaningless when eager execution is enabled.")
    940 
    941   @property
    942   def name(self):
    943     raise AttributeError(
    944         "Tensor.name is meaningless when eager execution is enabled.")
    945 
    946   @property
    947   def value_index(self):
    948     raise AttributeError(
    949         "Tensor.value_index is meaningless when eager execution is enabled.")
    950 
    951   def consumers(self):
    952     raise NotImplementedError(
    953         "Tensor.consumers is meaningless when eager execution is enabled.")
    954 
    955   def _add_consumer(self, consumer):
    956     raise NotImplementedError(
    957         "_add_consumer not supported when eager execution is enabled.")
    958 
    959   def _as_node_def_input(self):
    960     raise NotImplementedError(
    961         "_as_node_def_input not supported when eager execution is enabled.")
    962 
    963   def _as_tf_output(self):
    964     raise NotImplementedError(
    965         "_as_tf_output not supported when eager execution is enabled.")
    966 
    967   def eval(self, feed_dict=None, session=None):
    968     raise NotImplementedError(
    969         "eval is not supported when eager execution is enabled, "
    970         "is .numpy() what you're looking for?"
    971     )
    972 
    973 
# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and
# registers it with the current module. The returned class is what
# `isinstance(value, EagerTensor)` checks below test against. The base-class
# methods above that raise `NotImplementedError` are presumably implemented by
# this C-defined subclass -- confirm against the C API if relying on it.
EagerTensor = c_api.TFE_Py_InitEagerTensor(_EagerTensorBase)
    977 
    978 
    979 def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
    980   _ = name, as_ref
    981   if dtype and not dtype.is_compatible_with(t.dtype):
    982     raise ValueError(
    983         "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
    984         (dtype.name, t.dtype.name, str(t)))
    985   return t
    986 
    987 
# Registry of tensor conversion functions, keyed by integer priority. Each value
# is a list of `(base_type, conversion_func)` pairs. `internal_convert_to_tensor`
# walks the priorities in ascending order, keeps the pairs whose `base_type`
# matches the value via `isinstance`, and uses the first function that does not
# return `NotImplemented`. Existing `Tensor`s are handled at priority 0 by
# `_TensorTensorConversionFunction`, which only validates the dtype.
_tensor_conversion_func_registry = {
    0: [(Tensor, _TensorTensorConversionFunction)]
}
# Maps an exact Python type to the list of matching `(base_type,
# conversion_func)` pairs; filled lazily by `internal_convert_to_tensor`.
_tensor_conversion_func_cache = {}
# Held by `internal_convert_to_tensor` while it builds a cache entry.
_tensor_conversion_func_lock = threading.Lock()
register_dense_tensor_like_type(Tensor)
    994 
    995 
    996 @tf_export(v1=["convert_to_tensor"])
    997 def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None,
    998                       dtype_hint=None):
    999   """Converts the given `value` to a `Tensor`.
   1000 
   1001   This function converts Python objects of various types to `Tensor`
   1002   objects. It accepts `Tensor` objects, numpy arrays, Python lists,
   1003   and Python scalars. For example:
   1004 
   1005   ```python
   1006   import numpy as np
   1007 
   1008   def my_func(arg):
   1009     arg = tf.convert_to_tensor(arg, dtype=tf.float32)
   1010     return tf.matmul(arg, arg) + arg
   1011 
   1012   # The following calls are equivalent.
   1013   value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
   1014   value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
   1015   value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
   1016   ```
   1017 
   1018   This function can be useful when composing a new operation in Python
   1019   (such as `my_func` in the example above). All standard Python op
   1020   constructors apply this function to each of their Tensor-valued
   1021   inputs, which allows those ops to accept numpy arrays, Python lists,
   1022   and scalars in addition to `Tensor` objects.
   1023 
   1024   Note: This function diverges from default Numpy behavior for `float` and
   1025     `string` types when `None` is present in a Python list or scalar. Rather
   1026     than silently converting `None` values, an error will be thrown.
   1027 
   1028   Args:
   1029     value: An object whose type has a registered `Tensor` conversion function.
   1030     dtype: Optional element type for the returned tensor. If missing, the
   1031       type is inferred from the type of `value`.
   1032     name: Optional name to use if a new `Tensor` is created.
   1033     preferred_dtype: Optional element type for the returned tensor,
   1034       used when dtype is None. In some cases, a caller may not have a
   1035       dtype in mind when converting to a tensor, so preferred_dtype
   1036       can be used as a soft preference.  If the conversion to
   1037       `preferred_dtype` is not possible, this argument has no effect.
   1038     dtype_hint: same meaning as preferred_dtype, and overrides it.
   1039 
   1040   Returns:
   1041     A `Tensor` based on `value`.
   1042 
   1043   Raises:
   1044     TypeError: If no conversion function is registered for `value` to `dtype`.
   1045     RuntimeError: If a registered conversion function returns an invalid value.
   1046     ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
   1047   """
   1048   preferred_dtype = deprecation.deprecated_argument_lookup(
   1049       "dtype_hint", dtype_hint, "preferred_dtype", preferred_dtype)
   1050   return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
   1051 
   1052 
   1053 @tf_export("convert_to_tensor", v1=[])
   1054 def convert_to_tensor_v2(value, dtype=None, dtype_hint=None, name=None):
   1055   """Converts the given `value` to a `Tensor`.
   1056 
   1057   This function converts Python objects of various types to `Tensor`
   1058   objects. It accepts `Tensor` objects, numpy arrays, Python lists,
   1059   and Python scalars. For example:
   1060 
   1061   ```python
   1062   import numpy as np
   1063 
   1064   def my_func(arg):
   1065     arg = tf.convert_to_tensor(arg, dtype=tf.float32)
   1066     return tf.matmul(arg, arg) + arg
   1067 
   1068   # The following calls are equivalent.
   1069   value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
   1070   value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
   1071   value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
   1072   ```
   1073 
   1074   This function can be useful when composing a new operation in Python
   1075   (such as `my_func` in the example above). All standard Python op
   1076   constructors apply this function to each of their Tensor-valued
   1077   inputs, which allows those ops to accept numpy arrays, Python lists,
   1078   and scalars in addition to `Tensor` objects.
   1079 
   1080   Note: This function diverges from default Numpy behavior for `float` and
   1081     `string` types when `None` is present in a Python list or scalar. Rather
   1082     than silently converting `None` values, an error will be thrown.
   1083 
   1084   Args:
   1085     value: An object whose type has a registered `Tensor` conversion function.
   1086     dtype: Optional element type for the returned tensor. If missing, the
   1087       type is inferred from the type of `value`.
   1088     dtype_hint: Optional element type for the returned tensor,
   1089       used when dtype is None. In some cases, a caller may not have a
   1090       dtype in mind when converting to a tensor, so dtype_hint
   1091       can be used as a soft preference.  If the conversion to
   1092       `dtype_hint` is not possible, this argument has no effect.
   1093     name: Optional name to use if a new `Tensor` is created.
   1094 
   1095   Returns:
   1096     A `Tensor` based on `value`.
   1097 
   1098   Raises:
   1099     TypeError: If no conversion function is registered for `value` to `dtype`.
   1100     RuntimeError: If a registered conversion function returns an invalid value.
   1101     ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
   1102   """
   1103   return internal_convert_to_tensor(
   1104       value=value,
   1105       dtype=dtype,
   1106       name=name,
   1107       preferred_dtype=dtype_hint,
   1108       as_ref=False)
   1109 
   1110 
   1111 def _error_prefix(name):
   1112   return "" if name is None else "%s: " % name
   1113 
   1114 
def internal_convert_to_tensor(value,
                               dtype=None,
                               name=None,
                               as_ref=False,
                               preferred_dtype=None,
                               ctx=None,
                               accept_symbolic_tensors=True):
  """Implementation of the public convert_to_tensor.

  Args:
    value: The object to convert.
    dtype: Optional required dtype. It is normalized with `dtypes.as_dtype`,
      and the converted result must be compatible with it.
    name: Optional name for a newly created `Tensor`; also used as a prefix
      in error messages.
    as_ref: Forwarded unchanged to the registered conversion functions.
    preferred_dtype: Optional soft dtype preference, consulted only when
      `dtype` is None. Each conversion function is first called with it. If
      that call raises `TypeError`, `ValueError`, `errors.UnimplementedError`
      or `errors.InvalidArgumentError`, the function is called again with
      `dtype=None`.
    ctx: The eager context; defaults to `context.context()`.
    accept_symbolic_tensors: When False and executing eagerly, a Keras
      symbolic tensor (per `_is_keras_symbolic_tensor`) is rejected with
      `core._SymbolicException`; other symbolic tensors still convert.

  Returns:
    - `value` itself, if it is an `EagerTensor` and `ctx` is executing
      eagerly.
    - `graph.capture(value, name=name)`, if it is an `EagerTensor` while a
      function graph is being built.
    - Otherwise, the `Tensor` returned by the first registered conversion
      function that does not return `NotImplemented`.

  Raises:
    RuntimeError: If an `EagerTensor` is converted in graph mode outside of
      function building, or if a conversion function returns a non-`Tensor`
      or a `Tensor` whose dtype is incompatible with `dtype`.
    TypeError: If no registered conversion function accepts `value`, or if
      the conversion honoring `preferred_dtype` produced a different base
      dtype.
    ValueError: If an `EagerTensor`'s dtype is incompatible with `dtype`
      while executing eagerly (raised by `_TensorTensorConversionFunction`).
  """
  if ctx is None: ctx = context.context()
  if isinstance(value, EagerTensor):
    if ctx.executing_eagerly():
      # Eager tensors need no conversion; only validate the requested dtype.
      if dtype is not None:
        dtype = dtypes.as_dtype(dtype)
        value = _TensorTensorConversionFunction(value, dtype=dtype)
      return value
    else:
      # In graph mode an EagerTensor can only be used by capturing it into a
      # function that is currently being built.
      graph = get_default_graph()
      if not graph.building_function:
        raise RuntimeError("Attempting to capture an EagerTensor without "
                           "building a function.")
      return graph.capture(value, name=name)
  elif ((not accept_symbolic_tensors) and
        isinstance(value, Tensor) and
        ctx.executing_eagerly()):
    # Found a symbolic tensor in an eager context.
    # This happens when we use the Keras functional API (i.e. calling layers
    # on the output of `keras.Input()`, which is symbolic) while eager
    # execution is enabled.
    if _is_keras_symbolic_tensor(value):
      # If the graph of the tensor isn't the Keras graph, we should still
      # fail, for the time being. TODO(fchollet): consider allowing
      # all symbolic tensors to raise this exception in this case.
      raise core._SymbolicException(  # pylint: disable=protected-access
          "Using the symbolic output of a Keras layer during eager execution.")

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  unwrapped_type = type(value)
  # The cache is keyed on the exact type while matching below uses
  # `isinstance`, so one cached list serves every value of that type. The
  # cache is read without the lock. Concurrent misses may each rebuild the
  # entry; the last write wins.
  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
  if conversion_func_list is None:
    with _tensor_conversion_func_lock:
      conversion_func_list = []
      # Collect every matching function, in ascending priority order.
      for _, funcs_at_priority in sorted(
          _tensor_conversion_func_registry.items()):
        for base_type, conversion_func in funcs_at_priority:
          if isinstance(value, base_type):
            conversion_func_list.append((base_type, conversion_func))
      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list

  for base_type, conversion_func in conversion_func_list:
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError, errors.UnimplementedError,
              errors.InvalidArgumentError):
        # Could not coerce the conversion to use the preferred dtype.
        ret = None

      # A function that accepted the preference must actually honor it.
      if ret is not None and ret is not NotImplemented:
        if (ret.dtype.base_dtype !=
            dtypes.as_dtype(preferred_dtype).base_dtype):
          raise TypeError("convert_to_tensor did not convert to "
                          "the preferred dtype: %s vs %s " %
                          (ret.dtype.base_dtype,
                           dtypes.as_dtype(preferred_dtype).base_dtype))

    # Plain conversion (with `dtype`, which may be None). A `NotImplemented`
    # returned by the preferred-dtype attempt is not retried here.
    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    # `NotImplemented` means this function declines; try the next one.
    if ret is NotImplemented:
      continue

    if not isinstance(ret, Tensor):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          "%sConversion function %r for type %s returned incompatible "
          "dtype: requested = %s, actual = %s" %
          (_error_prefix(name), conversion_func, base_type, dtype.name,
           ret.dtype.name))
    return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered." %
                  (_error_prefix(name), value, unwrapped_type))
   1205 
   1206 
   1207 def internal_convert_n_to_tensor(values,
   1208                                  dtype=None,
   1209                                  name=None,
   1210                                  as_ref=False,
   1211                                  preferred_dtype=None,
   1212                                  ctx=None):
   1213   """Converts `values` to a list of `Tensor` objects.
   1214 
   1215   Args:
   1216     values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
   1217     dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
   1218     name: (Optional.) A name prefix to used when a new `Tensor` is
   1219       created, in which case element `i` will be given the name `name
   1220       + '_' + i`.
   1221     as_ref: True if the caller wants the results as ref tensors.
   1222     preferred_dtype: Optional element type for the returned tensors,
   1223       used when dtype is None. In some cases, a caller may not have a
   1224       dtype in mind when converting to a tensor, so preferred_dtype
   1225       can be used as a soft preference.  If the conversion to
   1226       `preferred_dtype` is not possible, this argument has no effect.
   1227     ctx: The value of context.context().
   1228 
   1229   Returns:
   1230     A list of `Tensor` and/or `IndexedSlices` objects.
   1231 
   1232   Raises:
   1233     TypeError: If no conversion function is registered for an element in
   1234       `values`.
   1235     RuntimeError: If a registered conversion function returns an invalid
   1236       value.
   1237   """
   1238   if not isinstance(values, collections.Sequence):
   1239     raise TypeError("values must be a sequence.")
   1240   ret = []
   1241   if ctx is None: ctx = context.context()
   1242   for i, value in enumerate(values):
   1243     n = None if name is None else "%s_%d" % (name, i)
   1244     ret.append(
   1245         internal_convert_to_tensor(
   1246             value,
   1247             dtype=dtype,
   1248             name=n,
   1249             as_ref=as_ref,
   1250             preferred_dtype=preferred_dtype,
   1251             ctx=ctx))
   1252   return ret
   1253 
   1254 
   1255 def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None):
   1256   """Converts `values` to a list of `Tensor` objects.
   1257 
   1258   Args:
   1259     values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
   1260     dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
   1261     name: (Optional.) A name prefix to used when a new `Tensor` is
   1262       created, in which case element `i` will be given the name `name
   1263       + '_' + i`.
   1264     preferred_dtype: Optional element type for the returned tensors,
   1265       used when dtype is None. In some cases, a caller may not have a
   1266       dtype in mind when converting to a tensor, so preferred_dtype
   1267       can be used as a soft preference.  If the conversion to
   1268       `preferred_dtype` is not possible, this argument has no effect.
   1269 
   1270   Returns:
   1271     A list of `Tensor` and/or `IndexedSlices` objects.
   1272 
   1273   Raises:
   1274     TypeError: If no conversion function is registered for an element in
   1275       `values`.
   1276     RuntimeError: If a registered conversion function returns an invalid
   1277       value.
   1278   """
   1279   return internal_convert_n_to_tensor(
   1280       values=values,
   1281       dtype=dtype,
   1282       name=name,
   1283       preferred_dtype=preferred_dtype,
   1284       as_ref=False)
   1285 
   1286 
   1287 @tf_export(v1=["convert_to_tensor_or_indexed_slices"])
   1288 def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
   1289   """Converts the given object to a `Tensor` or an `IndexedSlices`.
   1290 
   1291   If `value` is an `IndexedSlices` or `SparseTensor` it is returned
   1292   unmodified. Otherwise, it is converted to a `Tensor` using
   1293   `convert_to_tensor()`.
   1294 
   1295   Args:
   1296     value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
   1297       by `convert_to_tensor()`.
   1298     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1299       `IndexedSlices`.
   1300     name: (Optional.) A name to use if a new `Tensor` is created.
   1301 
   1302   Returns:
   1303     A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
   1304 
   1305   Raises:
   1306     ValueError: If `dtype` does not match the element type of `value`.
   1307   """
   1308   return internal_convert_to_tensor_or_indexed_slices(
   1309       value=value, dtype=dtype, name=name, as_ref=False)
   1310 
   1311 
   1312 def internal_convert_to_tensor_or_indexed_slices(value,
   1313                                                  dtype=None,
   1314                                                  name=None,
   1315                                                  as_ref=False):
   1316   """Converts the given object to a `Tensor` or an `IndexedSlices`.
   1317 
   1318   If `value` is an `IndexedSlices` or `SparseTensor` it is returned
   1319   unmodified. Otherwise, it is converted to a `Tensor` using
   1320   `convert_to_tensor()`.
   1321 
   1322   Args:
   1323     value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
   1324       by `convert_to_tensor()`.
   1325     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1326       `IndexedSlices`.
   1327     name: (Optional.) A name to use if a new `Tensor` is created.
   1328     as_ref: True if the caller wants the results as ref tensors.
   1329 
   1330   Returns:
   1331     A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
   1332 
   1333   Raises:
   1334     ValueError: If `dtype` does not match the element type of `value`.
   1335   """
   1336   if isinstance(value, EagerTensor) and not context.executing_eagerly():
   1337     return internal_convert_to_tensor(
   1338         value, dtype=dtype, name=name, as_ref=as_ref)
   1339   elif isinstance(value, _TensorLike):
   1340     if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
   1341       raise ValueError(
   1342           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
   1343           (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
   1344     return value
   1345   else:
   1346     return internal_convert_to_tensor(
   1347         value, dtype=dtype, name=name, as_ref=as_ref)
   1348 
   1349 
   1350 def internal_convert_n_to_tensor_or_indexed_slices(values,
   1351                                                    dtype=None,
   1352                                                    name=None,
   1353                                                    as_ref=False):
   1354   """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
   1355 
   1356   Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
   1357   unmodified.
   1358 
   1359   Args:
   1360     values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
   1361       can be consumed by `convert_to_tensor()`.
   1362     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1363       `IndexedSlices`.
   1364     name: (Optional.) A name prefix to used when a new `Tensor` is
   1365       created, in which case element `i` will be given the name `name
   1366       + '_' + i`.
   1367     as_ref: True if the caller wants the results as ref tensors.
   1368 
   1369   Returns:
   1370     A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.
   1371 
   1372   Raises:
   1373     TypeError: If no conversion function is registered for an element in
   1374       `values`.
   1375     RuntimeError: If a registered conversion function returns an invalid
   1376       value.
   1377   """
   1378   if not isinstance(values, collections.Sequence):
   1379     raise TypeError("values must be a sequence.")
   1380   ret = []
   1381   for i, value in enumerate(values):
   1382     if value is None:
   1383       ret.append(value)
   1384     else:
   1385       n = None if name is None else "%s_%d" % (name, i)
   1386       ret.append(
   1387           internal_convert_to_tensor_or_indexed_slices(
   1388               value, dtype=dtype, name=n, as_ref=as_ref))
   1389   return ret
   1390 
   1391 
   1392 def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
   1393   """Converts `values` to a list of `Output` or `IndexedSlices` objects.
   1394 
   1395   Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
   1396   unmodified.
   1397 
   1398   Args:
   1399     values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
   1400       can be consumed by `convert_to_tensor()`.
   1401     dtype: (Optional.) The required `DType` of the returned `Tensor`
   1402       `IndexedSlices`.
   1403     name: (Optional.) A name prefix to used when a new `Tensor` is
   1404       created, in which case element `i` will be given the name `name
   1405       + '_' + i`.
   1406 
   1407   Returns:
   1408     A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
   1409 
   1410   Raises:
   1411     TypeError: If no conversion function is registered for an element in
   1412       `values`.
   1413     RuntimeError: If a registered conversion function returns an invalid
   1414       value.
   1415   """
   1416   return internal_convert_n_to_tensor_or_indexed_slices(
   1417       values=values, dtype=dtype, name=name, as_ref=False)
   1418 
   1419 
   1420 def convert_to_tensor_or_composite(value, dtype=None, name=None):
   1421   """Converts the given object to a `Tensor` or `CompositeTensor`.
   1422 
   1423   If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it
   1424   is converted to a `Tensor` using `convert_to_tensor()`.
   1425 
   1426   Args:
   1427     value: A `CompositeTensor` or an object that can be consumed
   1428       by `convert_to_tensor()`.
   1429     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1430       `CompositeTensor`.
   1431     name: (Optional.) A name to use if a new `Tensor` is created.
   1432 
   1433   Returns:
   1434     A `Tensor` or `CompositeTensor`, based on `value`.
   1435 
   1436   Raises:
   1437     ValueError: If `dtype` does not match the element type of `value`.
   1438   """
   1439   return internal_convert_to_tensor_or_composite(
   1440       value=value, dtype=dtype, name=name, as_ref=False)
   1441 
   1442 
   1443 def internal_convert_to_tensor_or_composite(value,
   1444                                             dtype=None,
   1445                                             name=None,
   1446                                             as_ref=False):
   1447   """Converts the given object to a `Tensor` or `CompositeTensor`.
   1448 
   1449   If `value` is a `CompositeTensor` it is returned unmodified.  Otherwise, it
   1450   is converted to a `Tensor` using `convert_to_tensor()`.
   1451 
   1452   Args:
   1453     value: A `CompositeTensor`, or an object that can be consumed
   1454       by `convert_to_tensor()`.
   1455     dtype: (Optional.) The required `DType` of the returned `Tensor` or
   1456       `CompositeTensor`.
   1457     name: (Optional.) A name to use if a new `Tensor` is created.
   1458     as_ref: True if the caller wants the results as ref tensors.
   1459 
   1460   Returns:
   1461     A `Tensor` or `CompositeTensor`, based on `value`.
   1462 
   1463   Raises:
   1464     ValueError: If `dtype` does not match the element type of `value`.
   1465   """
   1466   if isinstance(value, composite_tensor.CompositeTensor):
   1467     value_dtype = getattr(value, "dtype", None)
   1468     if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype):
   1469       raise ValueError(
   1470           "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
   1471           (dtypes.as_dtype(dtype).name, value.dtype.name, str(value)))
   1472     return value
   1473   else:
   1474     return internal_convert_to_tensor(
   1475         value, dtype=dtype, name=name, as_ref=as_ref)
   1476 
   1477 
   1478 def internal_convert_n_to_tensor_or_composite(values,
   1479                                               dtype=None,
   1480                                               name=None,
   1481                                               as_ref=False):
   1482   """Converts `values` to a list of `Tensor` or `CompositeTensor` objects.
   1483 
   1484   Any `CompositeTensor` objects in `values` are returned unmodified.
   1485 
   1486   Args:
   1487     values: A list of `None`, `CompositeTensor`, or objects that
   1488       can be consumed by `convert_to_tensor()`.
   1489     dtype: (Optional.) The required `DType` of the returned `Tensor`s or
   1490       `CompositeTensor`s.
   1491     name: (Optional.) A name prefix to used when a new `Tensor` is
   1492       created, in which case element `i` will be given the name `name
   1493       + '_' + i`.
   1494     as_ref: True if the caller wants the results as ref tensors.
   1495 
   1496   Returns:
   1497     A list of `Tensor`, `CompositeTensor`, and/or `None` objects.
   1498 
   1499   Raises:
   1500     TypeError: If no conversion function is registered for an element in
   1501       `values`.
   1502     RuntimeError: If a registered conversion function returns an invalid
   1503       value.
   1504   """
   1505   if not isinstance(values, collections.Sequence):
   1506     raise TypeError("values must be a sequence.")
   1507   ret = []
   1508   for i, value in enumerate(values):
   1509     if value is None:
   1510       ret.append(value)
   1511     else:
   1512       n = None if name is None else "%s_%d" % (name, i)
   1513       ret.append(
   1514           internal_convert_to_tensor_or_composite(
   1515               value, dtype=dtype, name=n, as_ref=as_ref))
   1516   return ret
   1517 
   1518 
   1519 def convert_n_to_tensor_or_composite(values, dtype=None, name=None):
   1520   """Converts `values` to a list of `Output` or `CompositeTensor` objects.
   1521 
   1522   Any `CompositeTensor` objects in `values` are returned unmodified.
   1523 
   1524   Args:
   1525     values: A list of `None`, `CompositeTensor``, or objects that
   1526       can be consumed by `convert_to_tensor()`.
   1527     dtype: (Optional.) The required `DType` of the returned `Tensor`s or
   1528       `CompositeTensor`s.
   1529     name: (Optional.) A name prefix to used when a new `Tensor` is
   1530       created, in which case element `i` will be given the name `name
   1531       + '_' + i`.
   1532 
   1533   Returns:
   1534     A list of `Tensor` and/or `CompositeTensor` objects.
   1535 
   1536   Raises:
   1537     TypeError: If no conversion function is registered for an element in
   1538       `values`.
   1539     RuntimeError: If a registered conversion function returns an invalid
   1540       value.
   1541   """
   1542   return internal_convert_n_to_tensor_or_composite(
   1543       values=values, dtype=dtype, name=name, as_ref=False)
   1544 
   1545 
   1546 # TODO(josh11b): Add ctx argument to conversion_func() signature.
   1547 @tf_export("register_tensor_conversion_function")
   1548 def register_tensor_conversion_function(base_type,
   1549                                         conversion_func,
   1550                                         priority=100):
   1551   """Registers a function for converting objects of `base_type` to `Tensor`.
   1552 
   1553   The conversion function must have the following signature:
   1554 
   1555   ```python
   1556       def conversion_func(value, dtype=None, name=None, as_ref=False):
   1557         # ...
   1558   ```
   1559 
   1560   It must return a `Tensor` with the given `dtype` if specified. If the
   1561   conversion function creates a new `Tensor`, it should use the given
   1562   `name` if specified. All exceptions will be propagated to the caller.
   1563 
   1564   The conversion function may return `NotImplemented` for some
   1565   inputs. In this case, the conversion process will continue to try
   1566   subsequent conversion functions.
   1567 
   1568   If `as_ref` is true, the function must return a `Tensor` reference,
   1569   such as a `Variable`.
   1570 
   1571   NOTE: The conversion functions will execute in order of priority,
   1572   followed by order of registration. To ensure that a conversion function
   1573   `F` runs before another conversion function `G`, ensure that `F` is
   1574   registered with a smaller priority than `G`.
   1575 
   1576   Args:
   1577     base_type: The base type or tuple of base types for all objects that
   1578       `conversion_func` accepts.
   1579     conversion_func: A function that converts instances of `base_type` to
   1580       `Tensor`.
   1581     priority: Optional integer that indicates the priority for applying this
   1582       conversion function. Conversion functions with smaller priority values
   1583       run earlier than conversion functions with larger priority values.
   1584       Defaults to 100.
   1585 
   1586   Raises:
   1587     TypeError: If the arguments do not have the appropriate type.
   1588 
   1589   """
   1590   global _tensor_conversion_func_cache
   1591   with _tensor_conversion_func_lock:
   1592     if not (isinstance(base_type, type) or
   1593             (isinstance(base_type, tuple) and
   1594              all(isinstance(x, type) for x in base_type))):
   1595       raise TypeError("base_type must be a type or a tuple of types.")
   1596     if not callable(conversion_func):
   1597       raise TypeError("conversion_func must be callable.")
   1598 
   1599     # context._context is checked so that we don't inadvertently create it.
   1600     # This is because enable_eager_execution will fail when called from the main
   1601     # function if the context._context is already created, and the
   1602     # register_tensor_conversion_function calls happen when the module is
   1603     # imported.
   1604     if context._context is not None and context.executing_eagerly(
   1605     ) and isinstance(base_type, six.integer_types + (
   1606         float,
   1607         np.ndarray,
   1608     )):
   1609       # TODO(nareshmodi): consider setting a context variable which disables the
   1610       # fastpath instead.
   1611       raise TypeError(
   1612           "Cannot register conversions for numpy arrays, python number types "
   1613           "when executing eagerly.")
   1614 
   1615     try:
   1616       funcs_at_priority = _tensor_conversion_func_registry[priority]
   1617     except KeyError:
   1618       funcs_at_priority = []
   1619       _tensor_conversion_func_registry[priority] = funcs_at_priority
   1620     funcs_at_priority.append((base_type, conversion_func))
   1621     _tensor_conversion_func_cache = {}
   1622 
   1623 
   1624 @tf_export("IndexedSlices")
   1625 class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
   1626   """A sparse representation of a set of tensor slices at given indices.
   1627 
   1628   This class is a simple wrapper for a pair of `Tensor` objects:
   1629 
   1630   * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
   1631   * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
   1632 
   1633   An `IndexedSlices` is typically used to represent a subset of a larger
   1634   tensor `dense` of shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`.
   1635   The values in `indices` are the indices in the first dimension of
   1636   the slices that have been extracted from the larger tensor.
   1637 
   1638   The dense tensor `dense` represented by an `IndexedSlices` `slices` has
   1639 
   1640   ```python
   1641   dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
   1642   ```
   1643 
   1644   The `IndexedSlices` class is used principally in the definition of
   1645   gradients for operations that have sparse gradients
   1646   (e.g. `tf.gather`).
   1647 
   1648   Contrast this representation with
   1649   `tf.SparseTensor`,
   1650   which uses multi-dimensional indices and scalar values.
   1651   """
   1652 
   1653   def __init__(self, values, indices, dense_shape=None):
   1654     """Creates an `IndexedSlices`."""
   1655     _get_graph_from_inputs([values, indices, dense_shape])
   1656     self._values = values
   1657     self._indices = indices
   1658     self._dense_shape = dense_shape
   1659 
   1660   @property
   1661   def values(self):
   1662     """A `Tensor` containing the values of the slices."""
   1663     return self._values
   1664 
   1665   @property
   1666   def indices(self):
   1667     """A 1-D `Tensor` containing the indices of the slices."""
   1668     return self._indices
   1669 
   1670   @property
   1671   def dense_shape(self):
   1672     """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
   1673     return self._dense_shape
   1674 
   1675   @property
   1676   def name(self):
   1677     """The name of this `IndexedSlices`."""
   1678     return self.values.name
   1679 
   1680   @property
   1681   def device(self):
   1682     """The name of the device on which `values` will be produced, or `None`."""
   1683     return self.values.device
   1684 
   1685   @property
   1686   def op(self):
   1687     """The `Operation` that produces `values` as an output."""
   1688     return self.values.op
   1689 
   1690   @property
   1691   def dtype(self):
   1692     """The `DType` of elements in this tensor."""
   1693     return self.values.dtype
   1694 
   1695   @property
   1696   def graph(self):
   1697     """The `Graph` that contains the values, indices, and shape tensors."""
   1698     return self._values.graph
   1699 
   1700   def __str__(self):
   1701     return "IndexedSlices(indices=%s, values=%s%s)" % (
   1702         self._indices, self._values, (", dense_shape=%s" % self._dense_shape)
   1703         if self._dense_shape is not None else "")
   1704 
   1705   def __neg__(self):
   1706     return IndexedSlices(-self.values, self.indices, self.dense_shape)
   1707 
   1708   def _to_components(self):
   1709     if self._dense_shape is None:
   1710       return (self._values, self._indices)
   1711     else:
   1712       return (self._values, self._indices, self._dense_shape)
   1713 
   1714   @classmethod
   1715   def _from_components(cls, components):
   1716     return cls(*components)
   1717 
   1718   def _shape_invariant_to_components(self, shape=None):
   1719     if shape is None:
   1720       shape = self._values.shape
   1721     if self._dense_shape is None:
   1722       return [shape, shape[:1]]  # values, indices
   1723     else:
   1724       # values, indices, dense_shape
   1725       return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]
   1726 
   1727   @property
   1728   def _is_graph_tensor(self):
   1729     return hasattr(self._values, 'graph')
   1730 
   1731 
   1732 IndexedSlicesValue = collections.namedtuple(
   1733     "IndexedSlicesValue", ["values", "indices", "dense_shape"])
   1734 
   1735 
   1736 def _device_string(dev_spec):
   1737   if isinstance(dev_spec, pydev.DeviceSpec):
   1738     return dev_spec.to_string()
   1739   else:
   1740     return dev_spec
   1741 
   1742 
   1743 def _NodeDef(op_type, name, device=None, attrs=None):  # pylint: disable=redefined-outer-name
   1744   """Create a NodeDef proto.
   1745 
   1746   Args:
   1747     op_type: Value for the "op" attribute of the NodeDef proto.
   1748     name: Value for the "name" attribute of the NodeDef proto.
   1749     device: string, device, or function from NodeDef to string.
   1750       Value for the "device" attribute of the NodeDef proto.
   1751     attrs: Optional dictionary where the key is the attribute name (a string)
   1752       and the value is the respective "attr" attribute of the NodeDef proto (an
   1753       AttrValue).
   1754 
   1755   Returns:
   1756     A node_def_pb2.NodeDef protocol buffer.
   1757   """
   1758   node_def = node_def_pb2.NodeDef()
   1759   node_def.op = compat.as_bytes(op_type)
   1760   node_def.name = compat.as_bytes(name)
   1761   if attrs is not None:
   1762     for k, v in six.iteritems(attrs):
   1763       node_def.attr[k].CopyFrom(v)
   1764   if device is not None:
   1765     if callable(device):
   1766       node_def.device = device(node_def)
   1767     else:
   1768       node_def.device = _device_string(device)
   1769   return node_def
   1770 
   1771 
   1772 # Copied from core/framework/node_def_util.cc
   1773 # TODO(mrry,josh11b): Consolidate this validation in C++ code.
   1774 _VALID_OP_NAME_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*$")
   1775 _VALID_SCOPE_NAME_REGEX = re.compile("^[A-Za-z0-9_.\\-/]*$")
   1776 
   1777 
   1778 def _create_c_op(graph, node_def, inputs, control_inputs):
   1779   """Creates a TF_Operation.
   1780 
   1781   Args:
   1782     graph: a `Graph`.
   1783     node_def: `node_def_pb2.NodeDef` for the operation to create.
   1784     inputs: A list of `Tensor`s (corresponding to scalar inputs) and lists of
   1785       `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
   1786       "list(int64)"). The length of the list should be equal to the number of
   1787       inputs specified by this operation's op def.
   1788     control_inputs: A list of `Operation`s to set as control dependencies.
   1789 
   1790   Returns:
   1791     A wrapped TF_Operation*.
   1792   """
   1793   # pylint: disable=protected-access
   1794   op_desc = c_api.TF_NewOperation(graph._c_graph,
   1795                                   compat.as_str(node_def.op),
   1796                                   compat.as_str(node_def.name))
   1797   if node_def.device:
   1798     c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
   1799   # Add inputs
   1800   for op_input in inputs:
   1801     if isinstance(op_input, (list, tuple)):
   1802       c_api.TF_AddInputList(op_desc, [t._as_tf_output() for t in op_input])
   1803     else:
   1804       c_api.TF_AddInput(op_desc, op_input._as_tf_output())
   1805 
   1806   # Add control inputs
   1807   for control_input in control_inputs:
   1808     c_api.TF_AddControlInput(op_desc, control_input._c_op)
   1809   # pylint: enable=protected-access
   1810 
   1811   # Add attrs
   1812   for name, attr_value in node_def.attr.items():
   1813     serialized = attr_value.SerializeToString()
   1814     # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
   1815     # It might be worth creating a convenient way to re-use the same status.
   1816     c_api.TF_SetAttrValueProto(op_desc, compat.as_str(name), serialized)
   1817 
   1818   try:
   1819     c_op = c_api.TF_FinishOperation(op_desc)
   1820   except errors.InvalidArgumentError as e:
   1821     # Convert to ValueError for backwards compatibility.
   1822     raise ValueError(str(e))
   1823 
   1824   return c_op
   1825 
   1826 
   1827 @tf_export("Operation")
   1828 class Operation(object):
   1829   """Represents a graph node that performs computation on tensors.
   1830 
   1831   An `Operation` is a node in a TensorFlow `Graph` that takes zero or
   1832   more `Tensor` objects as input, and produces zero or more `Tensor`
   1833   objects as output. Objects of type `Operation` are created by
   1834   calling a Python op constructor (such as
   1835   `tf.matmul`)
   1836   or `tf.Graph.create_op`.
   1837 
   1838   For example `c = tf.matmul(a, b)` creates an `Operation` of type
   1839   "MatMul" that takes tensors `a` and `b` as input, and produces `c`
   1840   as output.
   1841 
   1842   After the graph has been launched in a session, an `Operation` can
   1843   be executed by passing it to
   1844   `tf.Session.run`.
   1845   `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
   1846   """
   1847 
  def __init__(self,
               node_def,
               g,
               inputs=None,
               output_types=None,
               control_inputs=None,
               input_types=None,
               original_op=None,
               op_def=None):
    r"""Creates an `Operation`.

    NOTE: This constructor validates the name of the `Operation` (passed
    as `node_def.name`). Valid `Operation` names match the following
    regular expression:

        [A-Za-z0-9.][A-Za-z0-9_.\\-/]*

    The constructor also assigns the op an id from `g`, registers it with `g`
    (via `g._add_op`), and, for ops built from a `NodeDef`, runs
    `_control_flow_post_processing`.

    Args:
      node_def: `node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.
        Used for attributes of `node_def_pb2.NodeDef`, typically `name`,
        `op`, and `device`.  The `input` attribute is irrelevant here
        as it will be computed when generating the model.
      g: `Graph`. The parent graph.
      inputs: list of `Tensor` objects. The inputs to this `Operation`.
      output_types: list of `DType` objects.  List of the types of the
        `Tensors` computed by this operation.  NOTE: this argument is not
        used to build the outputs; the output types are read back from the
        created C operation instead.
      control_inputs: list of `Operation`, `Tensor`, or `IndexedSlices`
        objects from which to have a control dependency. For a `Tensor` or
        `IndexedSlices`, its producing op (`.op`) is used.
      input_types: List of `DType` objects representing the
        types of the tensors accepted by the `Operation`.  By default
        uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
        reference-typed inputs must specify these explicitly.  In this
        constructor it is only used to check compatibility with `inputs`.
      original_op: Optional. Used to associate the new `Operation` with an
        existing `Operation` (for example, a replica with the op that was
        replicated).
      op_def: Optional. The `op_def_pb2.OpDef` proto that describes the
        op type that this `Operation` represents. If omitted, it is looked
        up from `g` using `node_def.op`.

    Raises:
      TypeError: if control inputs are not Operations, Tensors, or
        IndexedSlices,
        or if `node_def` is not a `NodeDef`,
        or if `g` is not a `Graph`,
        or if `inputs` is not a list or contains non-`Tensor` objects,
        or if `inputs` and `input_types` are incompatible.
      ValueError: if the `node_def` name is not valid, if `node_def` is 2GB
        or larger, or if the C API rejects the operation (`_create_c_op`
        converts `errors.InvalidArgumentError` to `ValueError`).
    """
    # For internal use only: `node_def` can be set to a TF_Operation to create
    # an Operation for that op. This is useful for creating Operations for ops
    # indirectly created by C API methods, e.g. the ops created by
    # TF_ImportGraphDef. When `node_def` is a TF_Operation, all optional fields
    # should be None.

    if isinstance(node_def, node_def_pb2.NodeDef):
      # Reject NodeDefs of 2GB or more. NOTE(review): the `< 0` test
      # presumably guards against an overflowed size -- confirm.
      if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0:
        raise ValueError(
            "Cannot create a tensor proto whose content is larger than 2GB.")
      if not _VALID_OP_NAME_REGEX.match(node_def.name):
        raise ValueError("'%s' is not a valid node name" % node_def.name)
      c_op = None
    elif type(node_def).__name__ == "SwigPyObject":
      # An existing TF_Operation wrapped by SWIG: nothing else may be passed.
      assert inputs is None
      assert output_types is None
      assert control_inputs is None
      assert input_types is None
      assert original_op is None
      assert op_def is None
      c_op = node_def
    else:
      raise TypeError("node_def needs to be a NodeDef: %s" % node_def)

    if not isinstance(g, Graph):
      raise TypeError("g needs to be a Graph: %s" % g)
    self._graph = g

    if inputs is None:
      inputs = []
    elif not isinstance(inputs, list):
      raise TypeError("inputs needs to be a list of Tensors: %s" % inputs)
    for a in inputs:
      if not isinstance(a, Tensor):
        raise TypeError("input needs to be a Tensor: %s" % a)
    if input_types is None:
      input_types = [i.dtype.base_dtype for i in inputs]
    else:
      # NOTE(review): zip() stops at the shorter list, so a length mismatch
      # between `inputs` and `input_types` is not detected here.
      if not all(
          x.is_compatible_with(i.dtype)
          for i, x in zip(inputs, input_types)):
        raise TypeError("In op '%s', input types (%s) are not compatible "
                        "with expected types (%s)" %
                        (node_def.name, [i.dtype for i in inputs],
                         input_types))

    # Build the list of control inputs.
    control_input_ops = []
    if control_inputs:
      for c in control_inputs:
        control_op = None
        if isinstance(c, Operation):
          control_op = c
        elif isinstance(c, (Tensor, IndexedSlices)):
          # Depend on the op that produces the tensor.
          control_op = c.op
        else:
          raise TypeError("Control input must be an Operation, "
                          "a Tensor, or IndexedSlices: %s" % c)
        control_input_ops.append(control_op)

    # This will be set by self.inputs.
    self._inputs_val = None

    # pylint: disable=protected-access
    self._id_value = self._graph._next_id()
    self._original_op = original_op
    self._traceback = tf_stack.extract_stack()

    # List of _UserDevSpecs holding code location of device context manager
    # invocations and the user's original argument to them.
    self._device_code_locations = None
    # Dict mapping op name to file and line information for op colocation
    # context managers.
    self._colocation_code_locations = None
    self._control_flow_context = self.graph._get_control_flow_context()
    # pylint: enable=protected-access

    # Initialize self._c_op.
    if c_op:
      self._c_op = c_op
    else:
      if op_def is None:
        op_def = self._graph._get_op_def(node_def.op)
      # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
      # Refactor so we don't have to do this here.
      grouped_inputs = self._reconstruct_sequence_inputs(
          op_def, inputs, node_def.attr)
      self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
                                control_input_ops)

    # Initialize self._outputs.
    # The `output_types` argument is deliberately overwritten: the output
    # types are taken from the C operation itself.
    num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
    output_types = [
        c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
        for i in range(num_outputs)]
    self._outputs = [
        Tensor(self, i, output_type)
        for i, output_type in enumerate(output_types)
    ]

    self._graph._add_op(self)  # pylint: disable=protected-access

    # Control-flow post-processing only runs for ops created here from a
    # NodeDef, not for pre-existing TF_Operations.
    if not c_op:
      self._control_flow_post_processing()
   1999 
   2000   def _control_flow_post_processing(self):
   2001     """Add this op to its control flow context.
   2002 
   2003     This may add new ops and change this op's inputs. self.inputs must be
   2004     available before calling this method.
   2005     """
   2006     for input_tensor in self.inputs:
   2007       control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
   2008     if self._control_flow_context is not None:
   2009       self._control_flow_context.AddOp(self)
   2010 
   2011   def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
   2012     """Regroups a flat list of input tensors into scalar and sequence inputs.
   2013 
   2014     Args:
   2015       op_def: The `op_def_pb2.OpDef` (for knowing the input types)
   2016       inputs: a list of input `Tensor`s to the op.
   2017       attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define
   2018         how long each sequence is)
   2019 
   2020     Returns:
   2021       A list of `Tensor`s (corresponding to scalar inputs) and lists of
   2022       `Tensor`s (corresponding to sequence inputs).
   2023     """
   2024     grouped_inputs = []
   2025     i = 0
   2026     for input_arg in op_def.input_arg:
   2027       if input_arg.number_attr:
   2028         input_len = attrs[input_arg.number_attr].i
   2029         is_sequence = True
   2030       elif input_arg.type_list_attr:
   2031         input_len = len(attrs[input_arg.type_list_attr].list.type)
   2032         is_sequence = True
   2033       else:
   2034         input_len = 1
   2035         is_sequence = False
   2036 
   2037       if is_sequence:
   2038         grouped_inputs.append(inputs[i:i + input_len])
   2039       else:
   2040         grouped_inputs.append(inputs[i])
   2041       i += input_len
   2042 
   2043     assert i == len(inputs)
   2044     return grouped_inputs
   2045 
   2046   def colocation_groups(self):
   2047     """Returns the list of colocation groups of the op."""
   2048     default_colocation_group = [
   2049         compat.as_bytes("loc:@%s" % self.name)
   2050     ]
   2051     try:
   2052       class_attr = self.get_attr("_class")
   2053     except ValueError:
   2054       # This op has no explicit colocation group, so it is itself its
   2055       # own root of a colocation group.
   2056       return default_colocation_group
   2057 
   2058     attr_groups = [
   2059         class_name for class_name in class_attr
   2060         if class_name.startswith(b"loc:@")
   2061     ]
   2062 
   2063     # If there are no colocation groups in the explicit _class field,
   2064     # return the default colocation group.
   2065     return attr_groups if attr_groups else default_colocation_group
   2066 
   2067   def values(self):
   2068     """DEPRECATED: Use outputs."""
   2069     return tuple(self.outputs)
   2070 
   2071   def _get_control_flow_context(self):
   2072     """Returns the control flow context of this op.
   2073 
   2074     Returns:
   2075       A context object.
   2076     """
   2077     return self._control_flow_context
   2078 
   2079   def _set_control_flow_context(self, ctx):
   2080     """Sets the current control flow context of this op.
   2081 
   2082     Args:
   2083       ctx: a context object.
   2084     """
   2085     self._control_flow_context = ctx
   2086 
   2087   @property
   2088   def name(self):
   2089     """The full name of this operation."""
   2090     return c_api.TF_OperationName(self._c_op)
   2091 
   2092   @property
   2093   def _id(self):
   2094     """The unique integer id of this operation."""
   2095     return self._id_value
   2096 
   2097   @property
   2098   def device(self):
   2099     """The name of the device to which this op has been assigned, if any.
   2100 
   2101     Returns:
   2102       The string name of the device to which this op has been
   2103       assigned, or an empty string if it has not been assigned to a
   2104       device.
   2105     """
   2106     return c_api.TF_OperationDevice(self._c_op)
   2107 
   2108   @property
   2109   def _device_assignments(self):
   2110     """Code locations for device context managers active at op creation.
   2111 
   2112     This property will return a list of traceable_stack.TraceableObject
   2113     instances where .obj is a string representing the assigned device
   2114     (or information about the function that would be applied to this op
   2115     to compute the desired device) and the filename and lineno members
   2116     record the location of the relevant device context manager.
   2117 
   2118     For example, suppose file_a contained these lines:
   2119 
   2120       file_a.py:
   2121         15: with tf.device('/gpu:0'):
   2122         16:   node_b = tf.constant(4, name='NODE_B')
   2123 
   2124     Then a TraceableObject t_obj representing the device context manager
   2125     would have these member values:
   2126 
   2127       t_obj.obj -> '/gpu:0'
   2128       t_obj.filename = 'file_a.py'
   2129       t_obj.lineno = 15
   2130 
   2131     and node_b.op._device_assignments would return the list [t_obj].
   2132 
   2133     Returns:
   2134       [str: traceable_stack.TraceableObject, ...] as per this method's
   2135       description, above.
   2136     """
   2137     return self._device_code_locations or []
   2138 
   2139   @property
   2140   def _colocation_dict(self):
   2141     """Code locations for colocation context managers active at op creation.
   2142 
   2143     This property will return a dictionary for which the keys are nodes with
   2144     which this Operation is colocated, and for which the values are
   2145     traceable_stack.TraceableObject instances.  The TraceableObject instances
   2146     record the location of the relevant colocation context manager but have the
   2147     "obj" field set to None to prevent leaking private data.
   2148 
   2149     For example, suppose file_a contained these lines:
   2150 
   2151       file_a.py:
   2152         14: node_a = tf.constant(3, name='NODE_A')
   2153         15: with tf.colocate_with(node_a):
   2154         16:   node_b = tf.constant(4, name='NODE_B')
   2155 
   2156     Then a TraceableObject t_obj representing the colocation context manager
   2157     would have these member values:
   2158 
   2159       t_obj.obj -> None
   2160       t_obj.filename = 'file_a.py'
   2161       t_obj.lineno = 15
   2162 
   2163     and node_b.op._colocation_dict would return the dictionary
   2164 
   2165       { 'NODE_A': t_obj }
   2166 
   2167     Returns:
   2168       {str: traceable_stack.TraceableObject} as per this method's description,
   2169       above.
   2170     """
   2171     locations_dict = self._colocation_code_locations or {}
   2172     return locations_dict.copy()
   2173 
   2174   @property
   2175   def _output_types(self):
   2176     """List this operation's output types.
   2177 
   2178     Returns:
   2179       List of the types of the Tensors computed by this operation.
   2180       Each element in the list is an integer whose value is one of
   2181       the TF_DataType enums defined in c_api.h
   2182       The length of this list indicates the number of output endpoints
   2183       of the operation.
   2184     """
   2185     num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
   2186     output_types = [
   2187         c_api.TF_OperationOutputType(self._tf_output(i))
   2188         for i in xrange(num_outputs)
   2189     ]
   2190     # In all the tests we have output_types that are passed into
   2191     # Operation.__init__ are a list of ints (which is illegal according
   2192     # to the docstring), but input_types are instances of DType.
   2193     # This extra assert is to catch if we ever use DType for output_types.
   2194     if output_types:
   2195       assert isinstance(output_types[0], int)
   2196     return output_types
   2197 
   2198   def _tf_output(self, output_idx):
   2199     """Create and return a new TF_Output for output_idx'th output of this op."""
   2200     tf_output = c_api.TF_Output()
   2201     tf_output.oper = self._c_op
   2202     tf_output.index = output_idx
   2203     return tf_output
   2204 
   2205   def _tf_input(self, input_idx):
   2206     """Create and return a new TF_Input for input_idx'th input of this op."""
   2207     tf_input = c_api.TF_Input()
   2208     tf_input.oper = self._c_op
   2209     tf_input.index = input_idx
   2210     return tf_input
   2211 
   2212   def _set_device(self, device):  # pylint: disable=redefined-outer-name
   2213     """Set the device of this operation.
   2214 
   2215     Args:
   2216       device: string or device..  The device to set.
   2217     """
   2218     c_api.SetRequestedDevice(
   2219         self._graph._c_graph,  # pylint: disable=protected-access
   2220         self._c_op,  # pylint: disable=protected-access
   2221         compat.as_str(_device_string(device)))
   2222 
   2223   def _update_input(self, index, tensor):
   2224     """Update the input to this operation at the given index.
   2225 
   2226     NOTE: This is for TF internal use only. Please don't use it.
   2227 
   2228     Args:
   2229       index: the index of the input to update.
   2230       tensor: the Tensor to be used as the input at the given index.
   2231 
   2232     Raises:
   2233       TypeError: if tensor is not a Tensor,
   2234         or if input tensor type is not convertible to dtype.
   2235       ValueError: if the Tensor is from a different graph.
   2236     """
   2237     if not isinstance(tensor, Tensor):
   2238       raise TypeError("tensor must be a Tensor: %s" % tensor)
   2239     _assert_same_graph(self, tensor)
   2240 
   2241     # Reset cached inputs.
   2242     self._inputs_val = None
   2243     c_api.UpdateEdge(
   2244         self._graph._c_graph,  # pylint: disable=protected-access
   2245         tensor._as_tf_output(),  # pylint: disable=protected-access
   2246         self._tf_input(index))
   2247 
   2248   def _add_while_inputs(self, tensors):
   2249     """See AddWhileInputHack in python_api.h.
   2250 
   2251     NOTE: This is for TF internal use only. Please don't use it.
   2252 
   2253     Args:
   2254       tensors: list of Tensors
   2255 
   2256     Raises:
   2257       TypeError: if tensor is not a Tensor,
   2258         or if input tensor type is not convertible to dtype.
   2259       ValueError: if the Tensor is from a different graph.
   2260     """
   2261     for tensor in tensors:
   2262       if not isinstance(tensor, Tensor):
   2263         raise TypeError("tensor must be a Tensor: %s" % tensor)
   2264       _assert_same_graph(self, tensor)
   2265 
   2266       # Reset cached inputs.
   2267       self._inputs_val = None
   2268       c_api.AddWhileInputHack(
   2269           self._graph._c_graph,  # pylint: disable=protected-access
   2270           tensor._as_tf_output(),  # pylint: disable=protected-access
   2271           self._c_op)
   2272 
   2273   def _add_control_inputs(self, ops):
   2274     """Add a list of new control inputs to this operation.
   2275 
   2276     Args:
   2277       ops: the list of Operations to add as control input.
   2278 
   2279     Raises:
   2280       TypeError: if ops is not a list of Operations.
   2281       ValueError: if any op in ops is from a different graph.
   2282     """
   2283     for op in ops:
   2284       if not isinstance(op, Operation):
   2285         raise TypeError("op must be an Operation: %s" % op)
   2286       c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
   2287 
   2288   def _add_control_input(self, op):
   2289     """Add a new control input to this operation.
   2290 
   2291     Args:
   2292       op: the Operation to add as control input.
   2293 
   2294     Raises:
   2295       TypeError: if op is not an Operation.
   2296       ValueError: if op is from a different graph.
   2297     """
   2298     if not isinstance(op, Operation):
   2299       raise TypeError("op must be an Operation: %s" % op)
   2300     c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op)  # pylint: disable=protected-access
   2301 
   2302   def _remove_all_control_inputs(self):
   2303     """Removes any control inputs to this operation."""
   2304     c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op)  # pylint: disable=protected-access
   2305 
   2306   def _add_outputs(self, types, shapes):
   2307     """Adds new Tensors to self.outputs.
   2308 
   2309     Note: this is generally unsafe to use. This is used in certain situations in
   2310     conjunction with _set_type_list_attr.
   2311 
   2312     Arguments:
   2313       types: list of DTypes
   2314       shapes: list of TensorShapes
   2315     """
   2316     assert len(types) == len(shapes)
   2317     orig_num_outputs = len(self.outputs)
   2318     for i in range(len(types)):
   2319       t = Tensor(self, orig_num_outputs + i, types[i])
   2320       self._outputs.append(t)
   2321       t.set_shape(shapes[i])
   2322 
   2323   def __str__(self):
   2324     return str(self.node_def)
   2325 
   2326   def __repr__(self):
   2327     return "<tf.Operation '%s' type=%s>" % (self.name, self.type)
   2328 
   2329   @property
   2330   def outputs(self):
   2331     """The list of `Tensor` objects representing the outputs of this op."""
   2332     return self._outputs
   2333 
   2334 # pylint: disable=protected-access
   2335 
   2336   class _InputList(object):
   2337     """Immutable input list wrapper."""
   2338 
   2339     def __init__(self, inputs):
   2340       self._inputs = inputs
   2341 
   2342     def __iter__(self):
   2343       return iter(self._inputs)
   2344 
   2345     def __len__(self):
   2346       return len(self._inputs)
   2347 
   2348     def __bool__(self):
   2349       return bool(self._inputs)
   2350 
   2351     # Python 3 wants __bool__, Python 2.7 wants __nonzero__
   2352     __nonzero__ = __bool__
   2353 
   2354     def __getitem__(self, i):
   2355       return self._inputs[i]
   2356 
   2357 # pylint: enable=protected-access
   2358 
   2359   @property
   2360   def inputs(self):
   2361     """The list of `Tensor` objects representing the data inputs of this op."""
   2362     if self._inputs_val is None:
   2363       tf_outputs = c_api.GetOperationInputs(self._c_op)
   2364       # pylint: disable=protected-access
   2365       retval = [
   2366           self.graph._get_tensor_by_tf_output(tf_output)
   2367           for tf_output in tf_outputs
   2368       ]
   2369       # pylint: enable=protected-access
   2370       self._inputs_val = Operation._InputList(retval)
   2371     return self._inputs_val
   2372 
   2373   @property
   2374   def _inputs(self):
   2375     logging.warning("Operation._inputs is private, use Operation.inputs "
   2376                     "instead. Operation._inputs will eventually be removed.")
   2377     return self.inputs
   2378 
   2379   @_inputs.setter
   2380   def _inputs(self, value):
   2381     raise ValueError("Cannot assign _inputs")
   2382 
   2383   @property
   2384   def _input_types(self):
   2385     num_inputs = c_api.TF_OperationNumInputs(self._c_op)
   2386     input_types = [
   2387         dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
   2388         for i in xrange(num_inputs)
   2389     ]
   2390     return input_types
   2391 
   2392   @_input_types.setter
   2393   def _input_types(self, value):
   2394     raise ValueError("Cannot assign _input_types")
   2395 
   2396   @property
   2397   def control_inputs(self):
   2398     """The `Operation` objects on which this op has a control dependency.
   2399 
   2400     Before this op is executed, TensorFlow will ensure that the
   2401     operations in `self.control_inputs` have finished executing. This
   2402     mechanism can be used to run ops sequentially for performance
   2403     reasons, or to ensure that the side effects of an op are observed
   2404     in the correct order.
   2405 
   2406     Returns:
   2407       A list of `Operation` objects.
   2408 
   2409     """
   2410     control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
   2411     # pylint: disable=protected-access
   2412     return [
   2413         self.graph._get_operation_by_name_unsafe(
   2414             c_api.TF_OperationName(c_op)) for c_op in control_c_ops
   2415     ]
   2416     # pylint: enable=protected-access
   2417 
   2418   @property
   2419   def _control_outputs(self):
   2420     """The `Operation` objects which have a control dependency on this op.
   2421 
   2422     Before any of the ops in self._control_outputs can execute tensorflow will
   2423     ensure self has finished executing.
   2424 
   2425     Returns:
   2426       A list of `Operation` objects.
   2427 
   2428     """
   2429     control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
   2430     # pylint: disable=protected-access
   2431     return [
   2432         self.graph._get_operation_by_name_unsafe(
   2433             c_api.TF_OperationName(c_op)) for c_op in control_c_ops
   2434     ]
   2435     # pylint: enable=protected-access
   2436 
   2437   @property
   2438   def _control_inputs(self):
   2439     logging.warning("Operation._control_inputs is private, use "
   2440                     "Operation.control_inputs instead. "
   2441                     "Operation._control_inputs will eventually be removed.")
   2442     return self.control_inputs
   2443 
   2444   @_control_inputs.setter
   2445   def _control_inputs(self, value):
   2446     logging.warning("Operation._control_inputs is private, use "
   2447                     "Operation.control_inputs instead. "
   2448                     "Operation._control_inputs will eventually be removed.")
   2449     # Copy value because it may be self._control_inputs_val (in particular if
   2450     # this is called from self._control_inputs += ...), and we don't want to
   2451     # clear value below.
   2452     value = copy.copy(value)
   2453     self._remove_all_control_inputs()
   2454     self._add_control_inputs(value)
   2455 
   2456   @property
   2457   def type(self):
   2458     """The type of the op (e.g. `"MatMul"`)."""
   2459     return c_api.TF_OperationOpType(self._c_op)
   2460 
   2461   @property
   2462   def graph(self):
   2463     """The `Graph` that contains this operation."""
   2464     return self._graph
   2465 
   2466   @property
   2467   def node_def(self):
   2468     # pylint: disable=line-too-long
   2469     """Returns the `NodeDef` representation of this operation.
   2470 
   2471     Returns:
   2472       A
   2473       [`NodeDef`](https://www.tensorflow.org/code/tensorflow/core/framework/node_def.proto)
   2474       protocol buffer.
   2475     """
   2476     # pylint: enable=line-too-long
   2477     with c_api_util.tf_buffer() as buf:
   2478       c_api.TF_OperationToNodeDef(self._c_op, buf)
   2479       data = c_api.TF_GetBuffer(buf)
   2480     node_def = node_def_pb2.NodeDef()
   2481     node_def.ParseFromString(compat.as_bytes(data))
   2482     return node_def
   2483 
   2484   @property
   2485   def _node_def(self):
   2486     logging.warning("Operation._node_def is private, use Operation.node_def "
   2487                     "instead. Operation._node_def will eventually be removed.")
   2488     return self.node_def
   2489 
   2490   @property
   2491   def op_def(self):
   2492     # pylint: disable=line-too-long
   2493     """Returns the `OpDef` proto that represents the type of this op.
   2494 
   2495     Returns:
   2496       An
   2497       [`OpDef`](https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto)
   2498       protocol buffer.
   2499     """
   2500     # pylint: enable=line-too-long
   2501     return self._graph._get_op_def(self.type)
   2502 
   2503   @property
   2504   def _op_def(self):
   2505     logging.warning("Operation._op_def is private, use Operation.op_def "
   2506                     "instead. Operation._op_def will eventually be removed.")
   2507     return self.op_def
   2508 
   2509   @property
   2510   def traceback(self):
   2511     """Returns the call stack from when this operation was constructed."""
   2512     return tf_stack.convert_stack(self._traceback)
   2513 
   2514   @property
   2515   def traceback_with_start_lines(self):
   2516     """Same as traceback but includes start line of function definition.
   2517 
   2518     Returns:
   2519       A list of 5-tuples (filename, lineno, name, code, func_start_lineno).
   2520     """
   2521     return tf_stack.convert_stack(self._traceback,
   2522                                   include_func_start_lineno=True)
   2523 
   2524   def _set_attr(self, attr_name, attr_value):
   2525     """Private method used to set an attribute in the node_def."""
   2526     buf = c_api.TF_NewBufferFromString(
   2527         compat.as_bytes(attr_value.SerializeToString()))
   2528     try:
   2529       # pylint: disable=protected-access
   2530       c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
   2531       # pylint: enable=protected-access
   2532     finally:
   2533       c_api.TF_DeleteBuffer(buf)
   2534 
   2535   def _set_func_attr(self, attr_name, func_name):
   2536     """Private method used to set a function attribute in the node_def."""
   2537     func = attr_value_pb2.NameAttrList(name=func_name)
   2538     self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func))
   2539 
   2540   def _set_type_list_attr(self, attr_name, types):
   2541     """Private method used to set a function attribute in the node_def."""
   2542     if not types: return
   2543     if isinstance(types[0], dtypes.DType):
   2544       types = [dt.as_datatype_enum for dt in types]
   2545     types_list = attr_value_pb2.AttrValue.ListValue(type=types)
   2546     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list))
   2547 
   2548   def _set_shape_list_attr(self, attr_name, shapes):
   2549     """Private method used to set a function attribute in the node_def."""
   2550     shapes = [s.as_proto() for s in shapes]
   2551     shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes)
   2552     self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list))
   2553 
   2554   def _clear_attr(self, attr_name):
   2555     """Private method used to clear an attribute in the node_def."""
   2556     # pylint: disable=protected-access
   2557     c_api.ClearAttr(self._graph._c_graph, self._c_op, attr_name)
   2558     # pylint: enable=protected-access
   2559 
   2560   def get_attr(self, name):
   2561     """Returns the value of the attr of this op with the given `name`.
   2562 
   2563     Args:
   2564       name: The name of the attr to fetch.
   2565 
   2566     Returns:
   2567       The value of the attr, as a Python object.
   2568 
   2569     Raises:
   2570       ValueError: If this op does not have an attr with the given `name`.
   2571     """
   2572     fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func")
   2573     try:
   2574       with c_api_util.tf_buffer() as buf:
   2575         c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
   2576         data = c_api.TF_GetBuffer(buf)
   2577     except errors.InvalidArgumentError as e:
   2578       # Convert to ValueError for backwards compatibility.
   2579       raise ValueError(str(e))
   2580     x = attr_value_pb2.AttrValue()
   2581     x.ParseFromString(data)
   2582 
   2583     oneof_value = x.WhichOneof("value")
   2584     if oneof_value is None:
   2585       return []
   2586     if oneof_value == "list":
   2587       for f in fields:
   2588         if getattr(x.list, f):
   2589           if f == "type":
   2590             return [dtypes.as_dtype(t) for t in x.list.type]
   2591           else:
   2592             return list(getattr(x.list, f))
   2593       return []
   2594     if oneof_value == "type":
   2595       return dtypes.as_dtype(x.type)
   2596     assert oneof_value in fields, "Unsupported field type in " + str(x)
   2597     return getattr(x, oneof_value)
   2598 
   2599   def run(self, feed_dict=None, session=None):
   2600     """Runs this operation in a `Session`.
   2601 
   2602     Calling this method will execute all preceding operations that
   2603     produce the inputs needed for this operation.
   2604 
   2605     *N.B.* Before invoking `Operation.run()`, its graph must have been
   2606     launched in a session, and either a default session must be
   2607     available, or `session` must be specified explicitly.
   2608 
   2609     Args:
   2610       feed_dict: A dictionary that maps `Tensor` objects to feed values.
   2611         See `tf.Session.run`
   2612         for a description of the valid feed values.
   2613       session: (Optional.) The `Session` to be used to run to this operation. If
   2614         none, the default session will be used.
   2615     """
   2616     _run_using_default_session(self, feed_dict, self.graph, session)
   2617 
# Registry of gradient functions keyed by op type string. RegisterGradient adds
# real gradient functions, no_gradient registers None for an op type, and
# get_gradient_function looks entries up.
_gradient_registry = registry.Registry("gradient")
   2619 
   2620 
   2621 @tf_export("RegisterGradient")
   2622 class RegisterGradient(object):
   2623   """A decorator for registering the gradient function for an op type.
   2624 
   2625   This decorator is only used when defining a new op type. For an op
   2626   with `m` inputs and `n` outputs, the gradient function is a function
   2627   that takes the original `Operation` and `n` `Tensor` objects
   2628   (representing the gradients with respect to each output of the op),
   2629   and returns `m` `Tensor` objects (representing the partial gradients
   2630   with respect to each input of the op).
   2631 
   2632   For example, assuming that operations of type `"Sub"` take two
   2633   inputs `x` and `y`, and return a single output `x - y`, the
   2634   following gradient function would be registered:
   2635 
   2636   ```python
   2637   @tf.RegisterGradient("Sub")
   2638   def _sub_grad(unused_op, grad):
   2639     return grad, tf.negative(grad)
   2640   ```
   2641 
   2642   The decorator argument `op_type` is the string type of an
   2643   operation. This corresponds to the `OpDef.name` field for the proto
   2644   that defines the operation.
   2645   """
   2646 
   2647   def __init__(self, op_type):
   2648     """Creates a new decorator with `op_type` as the Operation type.
   2649 
   2650     Args:
   2651       op_type: The string type of an operation. This corresponds to the
   2652         `OpDef.name` field for the proto that defines the operation.
   2653     """
   2654     if not isinstance(op_type, six.string_types):
   2655       raise TypeError("op_type must be a string")
   2656     self._op_type = op_type
   2657 
   2658   def __call__(self, f):
   2659     """Registers the function `f` as gradient function for `op_type`."""
   2660     _gradient_registry.register(f, self._op_type)
   2661     return f
   2662 
   2663 
   2664 @deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient")
   2665 @tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"])
   2666 def no_gradient(op_type):
   2667   """Specifies that ops of type `op_type` is not differentiable.
   2668 
   2669   This function should *not* be used for operations that have a
   2670   well-defined gradient that is not yet implemented.
   2671 
   2672   This function is only used when defining a new op type. It may be
   2673   used for ops such as `tf.size()` that are not differentiable.  For
   2674   example:
   2675 
   2676   ```python
   2677   tf.NotDifferentiable("Size")
   2678   ```
   2679 
   2680   The gradient computed for 'op_type' will then propagate zeros.
   2681 
   2682   For ops that have a well-defined gradient but are not yet implemented,
   2683   no declaration should be made, and an error *must* be thrown if
   2684   an attempt to request its gradient is made.
   2685 
   2686   Args:
   2687     op_type: The string type of an operation. This corresponds to the
   2688       `OpDef.name` field for the proto that defines the operation.
   2689 
   2690   Raises:
   2691     TypeError: If `op_type` is not a string.
   2692 
   2693   """
   2694   if not isinstance(op_type, six.string_types):
   2695     raise TypeError("op_type must be a string")
   2696   _gradient_registry.register(None, op_type)
   2697 
   2698 
   2699 # Aliases for the old names, will be eventually removed.
   2700 NoGradient = no_gradient
   2701 NotDifferentiable = no_gradient
   2702 
   2703 
   2704 def get_gradient_function(op):
   2705   """Returns the function that computes gradients for "op"."""
   2706   if not op.inputs:
   2707     return None
   2708   try:
   2709     op_type = op.get_attr("_gradient_op_type")
   2710   except ValueError:
   2711     op_type = op.type
   2712   return _gradient_registry.lookup(op_type)
   2713 
   2714 
   2715 _shape_registry = registry.Registry("shape functions")
   2716 _default_shape_function_registry = registry.Registry("default shape functions")
   2717 
   2718 # These are set to common_shapes.call_cpp_shape_fn by op generated code
   2719 # (generated by python_op_gen.cc).
   2720 # It is set outside ops.py to avoid a circular dependency.
   2721 _call_cpp_shape_fn = None
   2722 _call_cpp_shape_fn_and_require_op = None
   2723 
   2724 
   2725 def _set_call_cpp_shape_fn(call_cpp_shape_fn):
   2726   """Sets default shape fns from passed common_shapes.call_cpp_shape_fn."""
   2727   global _call_cpp_shape_fn, _call_cpp_shape_fn_and_require_op
   2728   if _call_cpp_shape_fn:
   2729     return  # already registered
   2730 
   2731   def call_without_requiring(op):
   2732     return call_cpp_shape_fn(op, require_shape_fn=False)
   2733 
   2734   _call_cpp_shape_fn = call_without_requiring
   2735 
   2736   def call_with_requiring(op):
   2737     return call_cpp_shape_fn(op, require_shape_fn=True)
   2738 
   2739   _call_cpp_shape_fn_and_require_op = call_with_requiring
   2740 
   2741 
   2742 class RegisterShape(object):
   2743   """No longer used.  Was: A decorator for registering a shape function.
   2744 
   2745   Shape functions must now be registered via the SetShapeFn on the
   2746   original Op specification in C++.
   2747 
   2748   """
   2749 
   2750   def __init__(self, op_type):
   2751     """Saves the `op_type` as the `Operation` type."""
   2752     if not isinstance(op_type, six.string_types):
   2753       raise TypeError("op_type must be a string")
   2754     self._op_type = op_type
   2755 
   2756   def __call__(self, f):
   2757     """Registers "f" as the shape function for "op_type"."""
   2758     if f is None:
   2759       assert _call_cpp_shape_fn
   2760 
   2761       # None is a special "weak" value that provides a default shape function,
   2762       # and can be overridden by a non-None registration.
   2763       try:
   2764         _default_shape_function_registry.register(_call_cpp_shape_fn,
   2765                                                   self._op_type)
   2766       except KeyError:
   2767         # Ignore duplicate registrations of the weak value. This can
   2768         # occur if the op library input to wrapper generation
   2769         # inadvertently links in one or more of the standard op
   2770         # libraries.
   2771         pass
   2772     else:
   2773       _shape_registry.register(f, self._op_type)
   2774     return f
   2775 
   2776 
   2777 def set_shape_and_handle_data_for_outputs(_):
   2778   """No op. TODO(b/74620627): Remove this."""
   2779   pass
   2780 
   2781 
   2782 class OpStats(object):
   2783   """A holder for statistics about an operator.
   2784 
   2785   This class holds information about the resource requirements for an op,
   2786   including the size of its weight parameters on-disk and how many FLOPS it
   2787   requires to execute forward inference.
   2788 
   2789   If you define a new operation, you can create a function that will return a
   2790   set of information about its usage of the CPU and disk space when serialized.
   2791   The function itself takes a Graph object that's been set up so you can call
   2792   methods like get_tensor_by_name to help calculate the results, and a NodeDef
   2793   argument.
   2794 
   2795   """
   2796 
   2797   def __init__(self, statistic_type, value=None):
   2798     """Sets up the initial placeholders for the statistics."""
   2799     self.statistic_type = statistic_type
   2800     self.value = value
   2801 
   2802   @property
   2803   def statistic_type(self):
   2804     return self._statistic_type
   2805 
   2806   @statistic_type.setter
   2807   def statistic_type(self, statistic_type):
   2808     self._statistic_type = statistic_type
   2809 
   2810   @property
   2811   def value(self):
   2812     return self._value
   2813 
   2814   @value.setter
   2815   def value(self, value):
   2816     self._value = value
   2817 
   2818   def __iadd__(self, other):
   2819     if other.statistic_type != self.statistic_type:
   2820       raise ValueError("Can't add an OpStat of type %s to one of %s." %
   2821                        (self.statistic_type, other.statistic_type))
   2822     if self.value is None:
   2823       self.value = other.value
   2824     elif other.value is not None:
   2825       self._value += other.value
   2826     return self
   2827 
   2828 
# Registry of statistics functions keyed by "<op_type>,<statistic_type>".
# RegisterStatistics populates it and get_stats_for_node_def queries it.
_stats_registry = registry.Registry("statistical functions")
   2830 
   2831 
   2832 class RegisterStatistics(object):
   2833   """A decorator for registering the statistics function for an op type.
   2834 
   2835   This decorator can be defined for an op type so that it gives a
   2836   report on the resources used by an instance of an operator, in the
   2837   form of an OpStats object.
   2838 
   2839   Well-known types of statistics include these so far:
   2840 
   2841   - flops: When running a graph, the bulk of the computation happens doing
   2842     numerical calculations like matrix multiplications. This type allows a node
   2843     to return how many floating-point operations it takes to complete. The
   2844     total number of FLOPs for a graph is a good guide to its expected latency.
   2845 
   2846   You can add your own statistics just by picking a new type string, registering
   2847   functions for the ops you care about, and then calling get_stats_for_node_def.
   2848 
   2849   If a statistic for an op is registered multiple times, a KeyError will be
   2850   raised.
   2851 
   2852   Since the statistics is counted on a per-op basis. It is not suitable for
   2853   model parameters (capacity), which is expected to be counted only once, even
   2854   if it is shared by multiple ops. (e.g. RNN)
   2855 
   2856   For example, you can define a new metric called doohickey for a Foo operation
   2857   by placing this in your code:
   2858 
   2859   ```python
   2860   @ops.RegisterStatistics("Foo", "doohickey")
   2861   def _calc_foo_bojangles(unused_graph, unused_node_def):
   2862     return ops.OpStats("doohickey", 20)
   2863   ```
   2864 
   2865   Then in client code you can retrieve the value by making this call:
   2866 
   2867   ```python
   2868   doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey")
   2869   ```
   2870 
   2871   If the NodeDef is for an op with a registered doohickey function, you'll get
   2872   back the calculated amount in doohickey.value, or None if it's not defined.
   2873 
   2874   """
   2875 
   2876   def __init__(self, op_type, statistic_type):
   2877     """Saves the `op_type` as the `Operation` type."""
   2878     if not isinstance(op_type, six.string_types):
   2879       raise TypeError("op_type must be a string.")
   2880     if "," in op_type:
   2881       raise TypeError("op_type must not contain a comma.")
   2882     self._op_type = op_type
   2883     if not isinstance(statistic_type, six.string_types):
   2884       raise TypeError("statistic_type must be a string.")
   2885     if "," in statistic_type:
   2886       raise TypeError("statistic_type must not contain a comma.")
   2887     self._statistic_type = statistic_type
   2888 
   2889   def __call__(self, f):
   2890     """Registers "f" as the statistics function for "op_type"."""
   2891     _stats_registry.register(f, self._op_type + "," + self._statistic_type)
   2892     return f
   2893 
   2894 
   2895 def get_stats_for_node_def(graph, node, statistic_type):
   2896   """Looks up the node's statistics function in the registry and calls it.
   2897 
   2898   This function takes a Graph object and a NodeDef from a GraphDef, and if
   2899   there's an associated statistics method, calls it and returns a result. If no
   2900   function has been registered for the particular node type, it returns an empty
   2901   statistics object.
   2902 
   2903   Args:
   2904     graph: A Graph object that's been set up with the node's graph.
   2905     node: A NodeDef describing the operator.
   2906     statistic_type: A string identifying the statistic we're interested in.
   2907   Returns:
   2908     An OpStats object containing information about resource usage.
   2909   """
   2910 
   2911   try:
   2912     stats_func = _stats_registry.lookup(node.op + "," + statistic_type)
   2913     result = stats_func(graph, node)
   2914   except LookupError:
   2915     result = OpStats(statistic_type)
   2916   return result
   2917 
   2918 
   2919 def _name_from_scope_name(name):
   2920   """Returns the name of an op given the name of its scope.
   2921 
   2922   Args:
   2923     name: the name of the scope.
   2924 
   2925   Returns:
   2926     the name of the op (equal to scope name minus any trailing slash).
   2927   """
   2928   return name[:-1] if (name and name[-1] == "/") else name
   2929 
   2930 
# Group indices, presumably for the lock_util.GroupLock that Graph.__init__
# creates with num_groups=2 (Graph._group_lock): one group for code that
# creates/mutates ops, the other for Session.run calls. NOTE(review): the call
# sites that pass these indices are not visible here; confirm.
_MUTATION_LOCK_GROUP = 0
_SESSION_RUN_LOCK_GROUP = 1
   2933 
   2934 @tf_export("Graph")
   2935 class Graph(object):
   2936   """A TensorFlow computation, represented as a dataflow graph.
   2937 
   2938   A `Graph` contains a set of
   2939   `tf.Operation` objects,
   2940   which represent units of computation; and
   2941   `tf.Tensor` objects, which represent
   2942   the units of data that flow between operations.
   2943 
   2944   A default `Graph` is always registered, and accessible by calling
   2945   `tf.get_default_graph`.
   2946   To add an operation to the default graph, simply call one of the functions
   2947   that defines a new `Operation`:
   2948 
   2949   ```python
   2950   c = tf.constant(4.0)
   2951   assert c.graph is tf.get_default_graph()
   2952   ```
   2953 
   2954   Another typical usage involves the
   2955   `tf.Graph.as_default`
   2956   context manager, which overrides the current default graph for the
   2957   lifetime of the context:
   2958 
   2959   ```python
   2960   g = tf.Graph()
   2961   with g.as_default():
   2962     # Define operations and tensors in `g`.
   2963     c = tf.constant(30.0)
   2964     assert c.graph is g
   2965   ```
   2966 
   2967   Important note: This class *is not* thread-safe for graph construction. All
   2968   operations should be created from a single thread, or external
   2969   synchronization must be provided. Unless otherwise specified, all methods
   2970   are not thread-safe.
   2971 
   2972   A `Graph` instance supports an arbitrary number of "collections"
   2973   that are identified by name. For convenience when building a large
   2974   graph, collections can store groups of related objects: for
   2975   example, the `tf.Variable` uses a collection (named
   2976   `tf.GraphKeys.GLOBAL_VARIABLES`) for
   2977   all variables that are created during the construction of a graph. The caller
   2978   may define additional collections by specifying a new name.
   2979   """
   2980 
   2981   def __init__(self):
   2982     """Creates a new, empty Graph."""
   2983     # Protects core state that can be returned via public accessors.
   2984     # Thread-safety is provided on a best-effort basis to support buggy
   2985     # programs, and is not guaranteed by the public `tf.Graph` API.
   2986     #
   2987     # NOTE(mrry): This does not protect the various stacks. A warning will
   2988     # be reported if these are used from multiple threads
   2989     self._lock = threading.RLock()
   2990     # The group lock synchronizes Session.run calls with methods that create
   2991     # and mutate ops (e.g. Graph.create_op()). This synchronization is
   2992     # necessary because it's illegal to modify an operation after it's been run.
   2993     # The group lock allows any number of threads to mutate ops at the same time
   2994     # but if any modification is going on, all Session.run calls have to wait.
   2995     # Similarly, if one or more Session.run calls are going on, all mutate ops
   2996     # have to wait until all Session.run calls have finished.
   2997     self._group_lock = lock_util.GroupLock(num_groups=2)
   2998     self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
   2999     self._next_id_counter = 0  # GUARDED_BY(self._lock)
   3000     self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
   3001     self._version = 0  # GUARDED_BY(self._lock)
   3002     # Maps a name used in the graph to the next id to use for that name.
   3003     self._names_in_use = {}
   3004     self._stack_state_is_thread_local = False
   3005     self._thread_local = threading.local()
   3006     # Functions that will be applied to choose a device if none is specified.
   3007     # In TF2.x or after switch_to_thread_local(),
   3008     # self._thread_local._device_function_stack is used instead.
   3009     self._graph_device_function_stack = traceable_stack.TraceableStack()
   3010     # Default original_op applied to new ops.
   3011     self._default_original_op = None
   3012     # Current control flow context. It could be either CondContext or
   3013     # WhileContext defined in ops/control_flow_ops.py
   3014     self._control_flow_context = None
   3015     # A new node will depend of the union of all of the nodes in the stack.
   3016     # In TF2.x or after switch_to_thread_local(),
   3017     # self._thread_local._control_dependencies_stack is used instead.
   3018     self._graph_control_dependencies_stack = []
   3019     # Arbitrary collections of objects.
   3020     self._collections = {}
   3021     # The graph-level random seed
   3022     self._seed = None
   3023     # A dictionary of attributes that should be applied to all ops.
   3024     self._attr_scope_map = {}
   3025     # A map from op type to the kernel label that should be used.
   3026     self._op_to_kernel_label_map = {}
   3027     # A map from op type to an alternative op type that should be used when
   3028     # computing gradients.
   3029     self._gradient_override_map = {}
   3030     # True if the graph is considered "finalized".  In that case no
   3031     # new operations can be added.
   3032     self._finalized = False
   3033     # Functions defined in the graph
   3034     self._functions = collections.OrderedDict()
   3035     # Default GraphDef versions
   3036     self._graph_def_versions = versions_pb2.VersionDef(
   3037         producer=versions.GRAPH_DEF_VERSION,
   3038         min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
   3039     self._building_function = False
   3040     # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
   3041     # self._thread_local._colocation_stack is used instead.
   3042     self._graph_colocation_stack = traceable_stack.TraceableStack()
   3043     # Set of tensors that are dangerous to feed!
   3044     self._unfeedable_tensors = set()
   3045     # Set of operations that are dangerous to fetch!
   3046     self._unfetchable_ops = set()
   3047     # A map of tensor handle placeholder to tensor dtype.
   3048     self._handle_feeders = {}
   3049     # A map from tensor handle to its read op.
   3050     self._handle_readers = {}
   3051     # A map from tensor handle to its move op.
   3052     self._handle_movers = {}
   3053     # A map from tensor handle to its delete op.
   3054     self._handle_deleters = {}
   3055     # Allow optimizers and other objects to pseudo-uniquely key graphs (this key
   3056     # will be shared when defining function graphs, for example, so optimizers
   3057     # being called inside function definitions behave as if they were seeing the
   3058     # actual outside graph).
   3059     self._graph_key = "grap-key-%d/" % (uid(),)
   3060     self._container = ""
   3061     self._registered_ops = op_def_registry.get_registered_ops()
   3062     # Set to True if this graph is being built in an
   3063     # AutomaticControlDependencies context.
   3064     self._add_control_dependencies = False
   3065 
   3066     # TODO(skyewm): fold as much of the above as possible into the C
   3067     # implementation
   3068     self._scoped_c_graph = c_api_util.ScopedTFGraph()
   3069     # The C API requires all ops to have shape functions. Disable this
   3070     # requirement (many custom ops do not have shape functions, and we don't
   3071     # want to break these existing cases).
   3072     c_api.SetRequireShapeInferenceFns(self._c_graph, False)
   3073     if tf2.enabled():
   3074       self.switch_to_thread_local()
   3075 
   3076   # Note: this method is private because the API of tf.Graph() is public and
   3077   # frozen, and this functionality is still not ready for public visibility.
   3078   @tf_contextlib.contextmanager
   3079   def _variable_creator_scope(self, creator, priority=100):
   3080     """Scope which defines a variable creation function.
   3081 
   3082     Args:
   3083       creator: A callable taking `next_creator` and `kwargs`. See the
   3084         `tf.variable_creator_scope` docstring.
   3085       priority: Creators with a higher `priority` are called first. Within the
   3086         same priority, creators are called inner-to-outer.
   3087 
   3088     Yields:
   3089       `_variable_creator_scope` is a context manager with a side effect, but
   3090       doesn't return a value.
   3091     """
   3092     # This step makes a copy of the existing stack, and it also initializes
   3093     # self._thread_local._variable_creator_stack if it doesn't exist yet.
   3094     old = list(self._variable_creator_stack)
   3095     stack = self._thread_local._variable_creator_stack  # pylint: disable=protected-access
   3096     stack.append((priority, creator))
   3097     # Sorting is stable, so we'll put higher-priority creators later in the list
   3098     # but otherwise maintain registration order.
   3099     stack.sort(key=lambda item: item[0])
   3100     try:
   3101       yield
   3102     finally:
   3103       self._thread_local._variable_creator_stack = old  # pylint: disable=protected-access
   3104 
   3105   # Note: this method is private because the API of tf.Graph() is public and
   3106   # frozen, and this functionality is still not ready for public visibility.
   3107   @property
   3108   def _variable_creator_stack(self):
   3109     if not hasattr(self._thread_local, "_variable_creator_stack"):
   3110       self._thread_local._variable_creator_stack = []  # pylint: disable=protected-access
   3111     return list(self._thread_local._variable_creator_stack)  # pylint: disable=protected-access
   3112 
   3113   @_variable_creator_stack.setter
   3114   def _variable_creator_stack(self, variable_creator_stack):
   3115     self._thread_local._variable_creator_stack = variable_creator_stack  # pylint: disable=protected-access
   3116 
   3117   def _check_not_finalized(self):
   3118     """Check if the graph is finalized.
   3119 
   3120     Raises:
   3121       RuntimeError: If the graph finalized.
   3122     """
   3123     if self._finalized:
   3124       raise RuntimeError("Graph is finalized and cannot be modified.")
   3125 
   3126   def _add_op(self, op):
   3127     """Adds 'op' to the graph.
   3128 
   3129     Args:
   3130       op: the Operator or Tensor to add.
   3131 
   3132     Raises:
   3133       TypeError: if op is not an Operation or Tensor.
   3134       ValueError: if the op.name or op._id are already used.
   3135     """
   3136     self._check_not_finalized()
   3137     if not isinstance(op, (Tensor, Operation)):
   3138       raise TypeError("op must be a Tensor or Operation: %s" % op)
   3139     with self._lock:
   3140       # pylint: disable=protected-access
   3141       if op._id in self._nodes_by_id:
   3142         raise ValueError("cannot add an op with id %d as it already "
   3143                          "exists in the graph" % op._id)
   3144       if op.name in self._nodes_by_name:
   3145         raise ValueError("cannot add op with name %s as that name "
   3146                          "is already used" % op.name)
   3147       self._nodes_by_id[op._id] = op
   3148       self._nodes_by_name[op.name] = op
   3149       self._version = max(self._version, op._id)
   3150       # pylint: enable=protected-access
   3151 
   3152   @property
   3153   def _c_graph(self):
   3154     if self._scoped_c_graph:
   3155       return self._scoped_c_graph.graph
   3156     return None
   3157 
   3158   @property
   3159   def version(self):
   3160     """Returns a version number that increases as ops are added to the graph.
   3161 
   3162     Note that this is unrelated to the
   3163     `tf.Graph.graph_def_versions`.
   3164 
   3165     Returns:
   3166        An integer version that increases as ops are added to the graph.
   3167     """
   3168     if self._finalized:
   3169       return self._version
   3170 
   3171     with self._lock:
   3172       return self._version
   3173 
   3174   @property
   3175   def graph_def_versions(self):
   3176     # pylint: disable=line-too-long
   3177     """The GraphDef version information of this graph.
   3178 
   3179     For details on the meaning of each version, see
   3180     [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto).
   3181 
   3182     Returns:
   3183       A `VersionDef`.
   3184     """
   3185     # pylint: enable=line-too-long
   3186     with c_api_util.tf_buffer() as buf:
   3187       c_api.TF_GraphVersions(self._c_graph, buf)
   3188       data = c_api.TF_GetBuffer(buf)
   3189     version_def = versions_pb2.VersionDef()
   3190     version_def.ParseFromString(compat.as_bytes(data))
   3191     return version_def
   3192 
   3193   @property
   3194   def seed(self):
   3195     """The graph-level random seed of this graph."""
   3196     return self._seed
   3197 
   3198   @seed.setter
   3199   def seed(self, seed):
   3200     self._seed = seed
   3201 
   3202   @property
   3203   def finalized(self):
   3204     """True if this graph has been finalized."""
   3205     return self._finalized
   3206 
   3207   def finalize(self):
   3208     """Finalizes this graph, making it read-only.
   3209 
   3210     After calling `g.finalize()`, no new operations can be added to
   3211     `g`.  This method is used to ensure that no operations are added
   3212     to a graph when it is shared between multiple threads, for example
   3213     when using a `tf.train.QueueRunner`.
   3214     """
   3215     self._finalized = True
   3216 
   3217   def _unsafe_unfinalize(self):
   3218     """Opposite of `finalize`. Internal interface.
   3219 
   3220     NOTE: Unfinalizing a graph could have negative impact on performance,
   3221     especially in a multi-threaded environment.  Unfinalizing a graph
   3222     when it is in use by a Session may lead to undefined behavior. Ensure
   3223     that all sessions using a graph are closed before calling this method.
   3224     """
   3225     self._finalized = False
   3226 
   3227   def _get_control_flow_context(self):
   3228     """Returns the current control flow context.
   3229 
   3230     Returns:
   3231       A context object.
   3232     """
   3233     return self._control_flow_context
   3234 
   3235   def _set_control_flow_context(self, ctx):
   3236     """Sets the current control flow context.
   3237 
   3238     Args:
   3239       ctx: a context object.
   3240     """
   3241     self._control_flow_context = ctx
   3242 
   3243   def _copy_functions_to_graph_def(self, graph_def, starting_bytesize):
   3244     """If this graph contains functions, copy them to `graph_def`."""
   3245     bytesize = starting_bytesize
   3246     for f in self._functions.values():
   3247       bytesize += f.definition.ByteSize()
   3248       if bytesize >= (1 << 31) or bytesize < 0:
   3249         raise ValueError("GraphDef cannot be larger than 2GB.")
   3250       graph_def.library.function.extend([f.definition])
   3251       if f.grad_func_name:
   3252         grad_def = function_pb2.GradientDef()
   3253         grad_def.function_name = f.name
   3254         grad_def.gradient_func = f.grad_func_name
   3255         graph_def.library.gradient.extend([grad_def])
   3256 
   3257   def _as_graph_def(self, from_version=None, add_shapes=False):
   3258     # pylint: disable=line-too-long
   3259     """Returns a serialized `GraphDef` representation of this graph.
   3260 
   3261     The serialized `GraphDef` can be imported into another `Graph`
   3262     (using `tf.import_graph_def`) or used with the
   3263     [C++ Session API](../../../../api_docs/cc/index.md).
   3264 
   3265     This method is thread-safe.
   3266 
   3267     Args:
   3268       from_version: Optional.  If this is set, returns a `GraphDef`
   3269         containing only the nodes that were added to this graph since
   3270         its `version` property had the given value.
   3271       add_shapes: If true, adds an "_output_shapes" list attr to each
   3272         node with the inferred shapes of each of its outputs.
   3273 
   3274     Returns:
   3275       A tuple containing a
   3276       [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
   3277       protocol buffer, and the version of the graph to which that
   3278       `GraphDef` corresponds.
   3279 
   3280     Raises:
   3281       ValueError: If the `graph_def` would be too large.
   3282 
   3283     """
   3284     # pylint: enable=line-too-long
   3285     with self._lock:
   3286       with c_api_util.tf_buffer() as buf:
   3287         c_api.TF_GraphToGraphDef(self._c_graph, buf)
   3288         data = c_api.TF_GetBuffer(buf)
   3289       graph = graph_pb2.GraphDef()
   3290       graph.ParseFromString(compat.as_bytes(data))
   3291       # Strip the experimental library field iff it's empty.
   3292       if not graph.library.function:
   3293         graph.ClearField("library")
   3294 
   3295       if add_shapes:
   3296         for node in graph.node:
   3297           op = self._nodes_by_name[node.name]
   3298           if op.outputs:
   3299             node.attr["_output_shapes"].list.shape.extend(
   3300                 [output.get_shape().as_proto() for output in op.outputs])
   3301     return graph, self._version
   3302 
   3303   def as_graph_def(self, from_version=None, add_shapes=False):
   3304     # pylint: disable=line-too-long
   3305     """Returns a serialized `GraphDef` representation of this graph.
   3306 
   3307     The serialized `GraphDef` can be imported into another `Graph`
   3308     (using `tf.import_graph_def`) or used with the
   3309     [C++ Session API](../../api_docs/cc/index.md).
   3310 
   3311     This method is thread-safe.
   3312 
   3313     Args:
   3314       from_version: Optional.  If this is set, returns a `GraphDef`
   3315         containing only the nodes that were added to this graph since
   3316         its `version` property had the given value.
   3317       add_shapes: If true, adds an "_output_shapes" list attr to each
   3318         node with the inferred shapes of each of its outputs.
   3319 
   3320     Returns:
   3321       A
   3322       [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
   3323       protocol buffer.
   3324 
   3325     Raises:
   3326       ValueError: If the `graph_def` would be too large.
   3327     """
   3328     # pylint: enable=line-too-long
   3329     result, _ = self._as_graph_def(from_version, add_shapes)
   3330     return result
   3331 
   3332   def _is_function(self, name):
   3333     """Tests whether 'name' is registered in this graph's function library.
   3334 
   3335     Args:
   3336       name: string op name.
   3337     Returns:
   3338       bool indicating whether or not 'name' is registered in function library.
   3339     """
   3340     return compat.as_str(name) in self._functions
   3341 
   3342   def _get_function(self, name):
   3343     """Returns the function definition for 'name'.
   3344 
   3345     Args:
   3346       name: string function name.
   3347     Returns:
   3348       The function def proto.
   3349     """
   3350     return self._functions.get(compat.as_str(name), None)
   3351 
   3352   def _add_function(self, function):
   3353     """Adds a function to the graph.
   3354 
   3355     After the function has been added, you can call to the function by
   3356     passing the function name in place of an op name to
   3357     `Graph.create_op()`.
   3358 
   3359     Args:
   3360       function: A `_DefinedFunction` object.
   3361 
   3362 
   3363     Raises:
   3364       ValueError: if another function is defined with the same name.
   3365     """
   3366     name = function.name
   3367     # Sanity checks on gradient definition.
   3368     if (function.grad_func_name is not None) and (function.python_grad_func is
   3369                                                   not None):
   3370       raise ValueError("Gradient defined twice for function %s" % name)
   3371 
   3372     # Add function to graph
   3373     # pylint: disable=protected-access
   3374     # Handle functions created without using the C API. TODO(apassos,skyewm)
   3375     # remove this when all functions are generated using the C API by default
   3376     # as this will be unnecessary.
   3377     if not function._c_func:
   3378       serialized = function.definition.SerializeToString()
   3379       c_func = c_api.TF_FunctionImportFunctionDef(serialized)
   3380       function._c_func = c_api_util.ScopedTFFunction(c_func)
   3381     gradient = (function._grad_func._c_func.func if function._grad_func
   3382                 else None)
   3383     c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
   3384     # pylint: enable=protected-access
   3385 
   3386     self._functions[compat.as_str(name)] = function
   3387 
   3388     # Need a new-enough consumer to support the functions we add to the graph.
   3389     if self._graph_def_versions.min_consumer < 12:
   3390       self._graph_def_versions.min_consumer = 12
   3391 
   3392   @property
   3393   def building_function(self):
   3394     """Returns True iff this graph represents a function."""
   3395     return self._building_function
   3396 
   3397   # Helper functions to create operations.
   3398   @deprecated_args(None,
   3399                    "Shapes are always computed; don't use the compute_shapes "
   3400                    "as it has no effect.", "compute_shapes")
   3401   def create_op(
   3402       self,
   3403       op_type,
   3404       inputs,
   3405       dtypes=None,  # pylint: disable=redefined-outer-name
   3406       input_types=None,
   3407       name=None,
   3408       attrs=None,
   3409       op_def=None,
   3410       compute_shapes=True,
   3411       compute_device=True):
   3412     """Creates an `Operation` in this graph.
   3413 
   3414     This is a low-level interface for creating an `Operation`. Most
   3415     programs will not call this method directly, and instead use the
   3416     Python op constructors, such as `tf.constant()`, which add ops to
   3417     the default graph.
   3418 
   3419     Args:
   3420       op_type: The `Operation` type to create. This corresponds to the
   3421         `OpDef.name` field for the proto that defines the operation.
   3422       inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
   3423       dtypes: (Optional) A list of `DType` objects that will be the types of the
   3424         tensors that the operation produces.
   3425       input_types: (Optional.) A list of `DType`s that will be the types of
   3426         the tensors that the operation consumes. By default, uses the base
   3427         `DType` of each input in `inputs`. Operations that expect
   3428         reference-typed inputs must specify `input_types` explicitly.
   3429       name: (Optional.) A string name for the operation. If not specified, a
   3430         name is generated based on `op_type`.
   3431       attrs: (Optional.) A dictionary where the key is the attribute name (a
   3432         string) and the value is the respective `attr` attribute of the
   3433         `NodeDef` proto that will represent the operation (an `AttrValue`
   3434         proto).
   3435       op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
   3436         the operation will have.
   3437       compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
   3438         computed).
   3439       compute_device: (Optional.) If True, device functions will be executed
   3440         to compute the device property of the Operation.
   3441 
   3442     Raises:
   3443       TypeError: if any of the inputs is not a `Tensor`.
   3444       ValueError: if colocation conflicts with existing device assignment.
   3445 
   3446     Returns:
   3447       An `Operation` object.
   3448     """
   3449     del compute_shapes
   3450 
   3451     self._check_not_finalized()
   3452     for idx, a in enumerate(inputs):
   3453       if not isinstance(a, Tensor):
   3454         raise TypeError("Input #%d is not a tensor: %s" % (idx, a))
   3455     if name is None:
   3456       name = op_type
   3457     # If a names ends with a '/' it is a "name scope" and we use it as-is,
   3458     # after removing the trailing '/'.
   3459     if name and name[-1] == "/":
   3460       name = _name_from_scope_name(name)
   3461     else:
   3462       name = self.unique_name(name)
   3463 
   3464     node_def = _NodeDef(op_type, name, device=None, attrs=attrs)
   3465 
   3466     input_ops = set([t.op for t in inputs])
   3467     control_inputs = self._control_dependencies_for_inputs(input_ops)
   3468     # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
   3469     # Session.run call cannot occur between creating and mutating the op.
   3470     with self._mutation_lock():
   3471       ret = Operation(
   3472           node_def,
   3473           self,
   3474           inputs=inputs,
   3475           output_types=dtypes,
   3476           control_inputs=control_inputs,
   3477           input_types=input_types,
   3478           original_op=self._default_original_op,
   3479           op_def=op_def)
   3480       self._create_op_helper(ret, compute_device=compute_device)
   3481     return ret
   3482 
   3483   def _create_op_from_tf_operation(self, c_op, compute_device=True):
   3484     """Creates an `Operation` in this graph from the supplied TF_Operation.
   3485 
   3486     This method is like create_op() except the new Operation is constructed
   3487     using `c_op`. The returned Operation will have `c_op` as its _c_op
   3488     field. This is used to create Operation objects around TF_Operations created
   3489     indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile).
   3490 
   3491     This function does not call Operation._control_flow_post_processing or
   3492     Graph._control_dependencies_for_inputs (since the inputs may not be
   3493     available yet). The caller is responsible for calling these methods.
   3494 
   3495     Args:
   3496       c_op: a wrapped TF_Operation
   3497       compute_device: (Optional.) If True, device functions will be executed
   3498         to compute the device property of the Operation.
   3499 
   3500     Returns:
   3501       An `Operation` object.
   3502     """
   3503     self._check_not_finalized()
   3504     ret = Operation(c_op, self)
   3505     # If a name_scope was created with ret.name but no nodes were created in it,
   3506     # the name will still appear in _names_in_use even though the name hasn't
   3507     # been used. This is ok, just leave _names_in_use as-is in this case.
   3508     # TODO(skyewm): make the C API guarantee no name conflicts.
   3509     name_key = ret.name.lower()
   3510     if name_key not in self._names_in_use:
   3511       self._names_in_use[name_key] = 1
   3512     self._create_op_helper(ret, compute_device=compute_device)
   3513     return ret
   3514 
   3515   def _create_op_helper(self, op, compute_device=True):
   3516     """Common logic for creating an op in this graph."""
   3517     # Apply any additional attributes requested. Do not overwrite any existing
   3518     # attributes.
   3519     for key, value in self._attr_scope_map.items():
   3520       try:
   3521         op.get_attr(key)
   3522       except ValueError:
   3523         if callable(value):
   3524           value = value(op.node_def)
   3525           if not isinstance(value, (type(None), attr_value_pb2.AttrValue)):
   3526             raise TypeError(
   3527                 "Callable for scope map key '%s' must return either None or "
   3528                 "an AttrValue protocol buffer; but it returned: %s" % (key,
   3529                                                                        value))
   3530         if value:
   3531           op._set_attr(key, value)  # pylint: disable=protected-access
   3532 
   3533     # Apply a kernel label if one has been specified for this op type.
   3534     try:
   3535       kernel_label = self._op_to_kernel_label_map[op.type]
   3536       op._set_attr("_kernel",  # pylint: disable=protected-access
   3537                    attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label)))
   3538     except KeyError:
   3539       pass
   3540 
   3541     # Apply the overriding op type for gradients if one has been specified for
   3542     # this op type.
   3543     try:
   3544       mapped_op_type = self._gradient_override_map[op.type]
   3545       op._set_attr("_gradient_op_type",  # pylint: disable=protected-access
   3546                    attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type)))
   3547     except KeyError:
   3548       pass
   3549 
   3550     self._record_op_seen_by_control_dependencies(op)
   3551 
   3552     if compute_device:
   3553       self._apply_device_functions(op)
   3554 
   3555     # Snapshot the colocation stack metadata before we might generate error
   3556     # messages using it.  Note that this snapshot depends on the actual stack
   3557     # and is independent of the op's _class attribute.
   3558     # pylint: disable=protected-access
   3559     op._colocation_code_locations = self._snapshot_colocation_stack_metadata()
   3560     # pylint: enable=protected-access
   3561 
   3562     if self._colocation_stack:
   3563       all_colocation_groups = []
   3564       for colocation_op in self._colocation_stack.peek_objs():
   3565         all_colocation_groups.extend(colocation_op.colocation_groups())
   3566         if colocation_op.device:
   3567           # pylint: disable=protected-access
   3568           op._set_device(colocation_op.device)
   3569           # pylint: enable=protected-access
   3570 
   3571       all_colocation_groups = sorted(set(all_colocation_groups))
   3572       # pylint: disable=protected-access
   3573       op._set_attr("_class", attr_value_pb2.AttrValue(
   3574           list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
   3575       # pylint: enable=protected-access
   3576 
   3577     # Sets "container" attribute if
   3578     # (1) self._container is not None
   3579     # (2) "is_stateful" is set in OpDef
   3580     # (3) "container" attribute is in OpDef
   3581     # (4) "container" attribute is None
   3582     if self._container and op.op_def.is_stateful:
   3583       try:
   3584         container_attr = op.get_attr("container")
   3585       except ValueError:
   3586         # "container" attribute is not in OpDef
   3587         pass
   3588       else:
   3589         if not container_attr:
   3590           op._set_attr("container", attr_value_pb2.AttrValue(  # pylint: disable=protected-access
   3591               s=compat.as_bytes(self._container)))
   3592 
   3593   def _add_new_tf_operations(self, compute_devices=True):
   3594     """Creates `Operations` in this graph for any new TF_Operations.
   3595 
   3596     This is useful for when TF_Operations are indirectly created by the C API
   3597     outside of the Operation constructor (e.g. by TF_ImportGraphDef,
   3598     TF_FinishWhile). This ensures there are corresponding Operations for all
   3599     TF_Operations in the underlying TF_Graph.
   3600 
   3601     Args:
   3602       compute_devices: (Optional.) If True, device functions will be executed
   3603         to compute the device properties of each new Operation.
   3604 
   3605     Returns:
   3606       A list of the new `Operation` objects.
   3607     """
   3608     # Create all Operation objects before accessing their inputs since an op may
   3609     # be created before its inputs.
   3610     new_ops = [
   3611         self._create_op_from_tf_operation(c_op, compute_device=compute_devices)
   3612         for c_op in c_api_util.new_tf_operations(self)
   3613     ]
   3614 
   3615     # pylint: disable=protected-access
   3616     for op in new_ops:
   3617       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
   3618       op._add_control_inputs(new_control_inputs)
   3619       op._control_flow_post_processing()
   3620     # pylint: enable=protected-access
   3621 
   3622     return new_ops
   3623 
   3624   def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
   3625     """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
   3626 
   3627     This function validates that `obj` represents an element of this
   3628     graph, and gives an informative error message if it is not.
   3629 
   3630     This function is the canonical way to get/validate an object of
   3631     one of the allowed types from an external argument reference in the
   3632     Session API.
   3633 
   3634     This method may be called concurrently from multiple threads.
   3635 
   3636     Args:
   3637       obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
   3638         Can also be any object with an `_as_graph_element()` method that returns
   3639         a value of one of these types.
   3640       allow_tensor: If true, `obj` may refer to a `Tensor`.
   3641       allow_operation: If true, `obj` may refer to an `Operation`.
   3642 
   3643     Returns:
   3644       The `Tensor` or `Operation` in the Graph corresponding to `obj`.
   3645 
   3646     Raises:
   3647       TypeError: If `obj` is not a type we support attempting to convert
   3648         to types.
   3649       ValueError: If `obj` is of an appropriate type but invalid. For
   3650         example, an invalid string.
   3651       KeyError: If `obj` is not an object in the graph.
   3652     """
   3653     if self._finalized:
   3654       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3655 
   3656     with self._lock:
   3657       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3658 
   3659   def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
   3660     """See `Graph.as_graph_element()` for details."""
   3661     # The vast majority of this function is figuring
   3662     # out what an API user might be doing wrong, so
   3663     # that we can give helpful error messages.
   3664     #
   3665     # Ideally, it would be nice to split it up, but we
   3666     # need context to generate nice error messages.
   3667 
   3668     if allow_tensor and allow_operation:
   3669       types_str = "Tensor or Operation"
   3670     elif allow_tensor:
   3671       types_str = "Tensor"
   3672     elif allow_operation:
   3673       types_str = "Operation"
   3674     else:
   3675       raise ValueError("allow_tensor and allow_operation can't both be False.")
   3676 
   3677     temp_obj = _as_graph_element(obj)
   3678     if temp_obj is not None:
   3679       obj = temp_obj
   3680 
   3681     # If obj appears to be a name...
   3682     if isinstance(obj, compat.bytes_or_text_types):
   3683       name = compat.as_str(obj)
   3684 
   3685       if ":" in name and allow_tensor:
   3686         # Looks like a Tensor name and can be a Tensor.
   3687         try:
   3688           op_name, out_n = name.split(":")
   3689           out_n = int(out_n)
   3690         except:
   3691           raise ValueError("The name %s looks a like a Tensor name, but is "
   3692                            "not a valid one. Tensor names must be of the "
   3693                            "form \"<op_name>:<output_index>\"." % repr(name))
   3694         if op_name in self._nodes_by_name:
   3695           op = self._nodes_by_name[op_name]
   3696         else:
   3697           raise KeyError("The name %s refers to a Tensor which does not "
   3698                          "exist. The operation, %s, does not exist in the "
   3699                          "graph." % (repr(name), repr(op_name)))
   3700         try:
   3701           return op.outputs[out_n]
   3702         except:
   3703           raise KeyError("The name %s refers to a Tensor which does not "
   3704                          "exist. The operation, %s, exists but only has "
   3705                          "%s outputs." % (repr(name), repr(op_name),
   3706                                           len(op.outputs)))
   3707 
   3708       elif ":" in name and not allow_tensor:
   3709         # Looks like a Tensor name but can't be a Tensor.
   3710         raise ValueError("Name %s appears to refer to a Tensor, not a %s." %
   3711                          (repr(name), types_str))
   3712 
   3713       elif ":" not in name and allow_operation:
   3714         # Looks like an Operation name and can be an Operation.
   3715         if name not in self._nodes_by_name:
   3716           raise KeyError("The name %s refers to an Operation not in the "
   3717                          "graph." % repr(name))
   3718         return self._nodes_by_name[name]
   3719 
   3720       elif ":" not in name and not allow_operation:
   3721         # Looks like an Operation name but can't be an Operation.
   3722         if name in self._nodes_by_name:
   3723           # Yep, it's an Operation name
   3724           err_msg = ("The name %s refers to an Operation, not a %s." %
   3725                      (repr(name), types_str))
   3726         else:
   3727           err_msg = ("The name %s looks like an (invalid) Operation name, "
   3728                      "not a %s." % (repr(name), types_str))
   3729         err_msg += (" Tensor names must be of the form "
   3730                     "\"<op_name>:<output_index>\".")
   3731         raise ValueError(err_msg)
   3732 
   3733     elif isinstance(obj, Tensor) and allow_tensor:
   3734       # Actually obj is just the object it's referring to.
   3735       if obj.graph is not self:
   3736         raise ValueError("Tensor %s is not an element of this graph." % obj)
   3737       return obj
   3738     elif isinstance(obj, Operation) and allow_operation:
   3739       # Actually obj is just the object it's referring to.
   3740       if obj.graph is not self:
   3741         raise ValueError("Operation %s is not an element of this graph." % obj)
   3742       return obj
   3743     else:
   3744       # We give up!
   3745       raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
   3746                                                            types_str))
   3747 
   3748   def get_operations(self):
   3749     """Return the list of operations in the graph.
   3750 
   3751     You can modify the operations in place, but modifications
   3752     to the list such as inserts/delete have no effect on the
   3753     list of operations known to the graph.
   3754 
   3755     This method may be called concurrently from multiple threads.
   3756 
   3757     Returns:
   3758       A list of Operations.
   3759     """
   3760     if self._finalized:
   3761       return list(self._nodes_by_id.values())
   3762 
   3763     with self._lock:
   3764       return list(self._nodes_by_id.values())
   3765 
   3766   def get_operation_by_name(self, name):
   3767     """Returns the `Operation` with the given `name`.
   3768 
   3769     This method may be called concurrently from multiple threads.
   3770 
   3771     Args:
   3772       name: The name of the `Operation` to return.
   3773 
   3774     Returns:
   3775       The `Operation` with the given `name`.
   3776 
   3777     Raises:
   3778       TypeError: If `name` is not a string.
   3779       KeyError: If `name` does not correspond to an operation in this graph.
   3780     """
   3781 
   3782     if not isinstance(name, six.string_types):
   3783       raise TypeError("Operation names are strings (or similar), not %s." %
   3784                       type(name).__name__)
   3785     return self.as_graph_element(name, allow_tensor=False, allow_operation=True)
   3786 
   3787   def _get_operation_by_name_unsafe(self, name):
   3788     """Returns the `Operation` with the given `name`.
   3789 
   3790     This is a internal unsafe version of get_operation_by_name. It skips many
   3791     checks and does not have user friedly error messages but runs considerably
   3792     faster. This method may be called concurrently from multiple threads.
   3793 
   3794     Args:
   3795       name: The name of the `Operation` to return.
   3796 
   3797     Returns:
   3798       The `Operation` with the given `name`.
   3799 
   3800     Raises:
   3801       KeyError: If `name` does not correspond to an operation in this graph.
   3802     """
   3803 
   3804     if self._finalized:
   3805       return self._nodes_by_name[name]
   3806 
   3807     with self._lock:
   3808       return self._nodes_by_name[name]
   3809 
   3810   def _get_operation_by_tf_operation(self, tf_oper):
   3811     op_name = c_api.TF_OperationName(tf_oper)
   3812     return self._get_operation_by_name_unsafe(op_name)
   3813 
   3814   def get_tensor_by_name(self, name):
   3815     """Returns the `Tensor` with the given `name`.
   3816 
   3817     This method may be called concurrently from multiple threads.
   3818 
   3819     Args:
   3820       name: The name of the `Tensor` to return.
   3821 
   3822     Returns:
   3823       The `Tensor` with the given `name`.
   3824 
   3825     Raises:
   3826       TypeError: If `name` is not a string.
   3827       KeyError: If `name` does not correspond to a tensor in this graph.
   3828     """
   3829     # Names should be strings.
   3830     if not isinstance(name, six.string_types):
   3831       raise TypeError("Tensor names are strings (or similar), not %s." %
   3832                       type(name).__name__)
   3833     return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
   3834 
   3835   def _get_tensor_by_tf_output(self, tf_output):
   3836     """Returns the `Tensor` representing `tf_output`.
   3837 
   3838     Note that there is only one such `Tensor`, i.e. multiple calls to this
   3839     function with the same TF_Output value will always return the same `Tensor`
   3840     object.
   3841 
   3842     Args:
   3843       tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`).
   3844 
   3845     Returns:
   3846       The `Tensor` that represents `tf_output`.
   3847     """
   3848     op = self._get_operation_by_tf_operation(tf_output.oper)
   3849     return op.outputs[tf_output.index]
   3850 
   3851   def _next_id(self):
   3852     """Id for next Operation instance. Also increments the internal id."""
   3853     self._check_not_finalized()
   3854     with self._lock:
   3855       self._next_id_counter += 1
   3856       return self._next_id_counter
   3857 
   3858   @property
   3859   def _last_id(self):
   3860     return self._next_id_counter
   3861 
   3862   def _get_op_def(self, type):  # pylint: disable=redefined-builtin
   3863     """Returns the `OpDef` proto for `type`. `type` is a string."""
   3864     with c_api_util.tf_buffer() as buf:
   3865       # pylint: disable=protected-access
   3866       c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
   3867       # pylint: enable=protected-access
   3868       data = c_api.TF_GetBuffer(buf)
   3869     op_def = op_def_pb2.OpDef()
   3870     op_def.ParseFromString(compat.as_bytes(data))
   3871     return op_def
   3872 
   3873   def as_default(self):
   3874     """Returns a context manager that makes this `Graph` the default graph.
   3875 
   3876     This method should be used if you want to create multiple graphs
   3877     in the same process. For convenience, a global default graph is
   3878     provided, and all ops will be added to this graph if you do not
   3879     create a new graph explicitly.
   3880 
   3881     Use this method with the `with` keyword to specify that ops created within
   3882     the scope of a block should be added to this graph. In this case, once
   3883     the scope of the `with` is exited, the previous default graph is set again
   3884     as default. There is a stack, so it's ok to have multiple nested levels
   3885     of `as_default` calls.
   3886 
   3887     The default graph is a property of the current thread. If you
   3888     create a new thread, and wish to use the default graph in that
   3889     thread, you must explicitly add a `with g.as_default():` in that
   3890     thread's function.
   3891 
   3892     The following code examples are equivalent:
   3893 
   3894     ```python
   3895     # 1. Using Graph.as_default():
   3896     g = tf.Graph()
   3897     with g.as_default():
   3898       c = tf.constant(5.0)
   3899       assert c.graph is g
   3900 
   3901     # 2. Constructing and making default:
   3902     with tf.Graph().as_default() as g:
   3903       c = tf.constant(5.0)
   3904       assert c.graph is g
   3905     ```
   3906 
   3907     If eager execution is enabled ops created under this context manager will be
   3908     added to the graph instead of executed eagerly.
   3909 
   3910     Returns:
   3911       A context manager for using this graph as the default graph.
   3912     """
   3913     return _default_graph_stack.get_controller(self)
   3914 
   3915   @property
   3916   def collections(self):
   3917     """Returns the names of the collections known to this graph."""
   3918     return list(self._collections)
   3919 
   3920   def add_to_collection(self, name, value):
   3921     """Stores `value` in the collection with the given `name`.
   3922 
   3923     Note that collections are not sets, so it is possible to add a value to
   3924     a collection several times.
   3925 
   3926     Args:
   3927       name: The key for the collection. The `GraphKeys` class
   3928         contains many standard names for collections.
   3929       value: The value to add to the collection.
   3930     """  # pylint: disable=g-doc-exception
   3931     self._check_not_finalized()
   3932     with self._lock:
   3933       if name not in self._collections:
   3934         self._collections[name] = [value]
   3935       else:
   3936         self._collections[name].append(value)
   3937 
   3938   def add_to_collections(self, names, value):
   3939     """Stores `value` in the collections given by `names`.
   3940 
   3941     Note that collections are not sets, so it is possible to add a value to
   3942     a collection several times. This function makes sure that duplicates in
   3943     `names` are ignored, but it will not check for pre-existing membership of
   3944     `value` in any of the collections in `names`.
   3945 
   3946     `names` can be any iterable, but if `names` is a string, it is treated as a
   3947     single collection name.
   3948 
   3949     Args:
   3950       names: The keys for the collections to add to. The `GraphKeys` class
   3951         contains many standard names for collections.
   3952       value: The value to add to the collections.
   3953     """
   3954     # Make sure names are unique, but treat strings as a single collection name
   3955     names = (names,) if isinstance(names, six.string_types) else set(names)
   3956     for name in names:
   3957       self.add_to_collection(name, value)
   3958 
   3959   def get_collection_ref(self, name):
   3960     """Returns a list of values in the collection with the given `name`.
   3961 
   3962     If the collection exists, this returns the list itself, which can
   3963     be modified in place to change the collection.  If the collection does
   3964     not exist, it is created as an empty list and the list is returned.
   3965 
   3966     This is different from `get_collection()` which always returns a copy of
   3967     the collection list if it exists and never creates an empty collection.
   3968 
   3969     Args:
   3970       name: The key for the collection. For example, the `GraphKeys` class
   3971         contains many standard names for collections.
   3972 
   3973     Returns:
   3974       The list of values in the collection with the given `name`, or an empty
   3975       list if no value has been added to that collection.
   3976     """  # pylint: disable=g-doc-exception
   3977     with self._lock:
   3978       coll_list = self._collections.get(name, None)
   3979       if coll_list is None:
   3980         coll_list = []
   3981         self._collections[name] = coll_list
   3982       return coll_list
   3983 
   3984   def get_collection(self, name, scope=None):
   3985     """Returns a list of values in the collection with the given `name`.
   3986 
   3987     This is different from `get_collection_ref()` which always returns the
   3988     actual collection list if it exists in that it returns a new list each time
   3989     it is called.
   3990 
   3991     Args:
   3992       name: The key for the collection. For example, the `GraphKeys` class
   3993         contains many standard names for collections.
   3994       scope: (Optional.) A string. If supplied, the resulting list is filtered
   3995         to include only items whose `name` attribute matches `scope` using
   3996         `re.match`. Items without a `name` attribute are never returned if a
   3997         scope is supplied. The choice of `re.match` means that a `scope` without
   3998         special tokens filters by prefix.
   3999 
   4000     Returns:
   4001       The list of values in the collection with the given `name`, or
   4002       an empty list if no value has been added to that collection. The
   4003       list contains the values in the order under which they were
   4004       collected.
   4005     """  # pylint: disable=g-doc-exception
   4006     with self._lock:
   4007       collection = self._collections.get(name, None)
   4008       if collection is None:
   4009         return []
   4010       if scope is None:
   4011         return list(collection)
   4012       else:
   4013         c = []
   4014         regex = re.compile(scope)
   4015         for item in collection:
   4016           if hasattr(item, "name") and regex.match(item.name):
   4017             c.append(item)
   4018         return c
   4019 
   4020   def get_all_collection_keys(self):
   4021     """Returns a list of collections used in this graph."""
   4022     with self._lock:
   4023       return [x for x in self._collections if isinstance(x, six.string_types)]
   4024 
   4025   def clear_collection(self, name):
   4026     """Clears all values in a collection.
   4027 
   4028     Args:
   4029       name: The key for the collection. The `GraphKeys` class contains many
   4030         standard names for collections.
   4031     """
   4032     self._check_not_finalized()
   4033     with self._lock:
   4034       if name in self._collections:
   4035         del self._collections[name]
   4036 
   4037   @tf_contextlib.contextmanager
   4038   def _original_op(self, op):
   4039     """Python 'with' handler to help annotate ops with their originator.
   4040 
   4041     An op may have an 'original_op' property that indicates the op on which
   4042     it was based. For example a replica op is based on the op that was
   4043     replicated and a gradient op is based on the op that was differentiated.
   4044 
   4045     All ops created in the scope of this 'with' handler will have
   4046     the given 'op' as their original op.
   4047 
   4048     Args:
   4049       op: The Operation that all ops created in this scope will have as their
   4050         original op.
   4051 
   4052     Yields:
   4053       Nothing.
   4054     """
   4055     old_original_op = self._default_original_op
   4056     self._default_original_op = op
   4057     try:
   4058       yield
   4059     finally:
   4060       self._default_original_op = old_original_op
   4061 
   4062   @property
   4063   def _name_stack(self):
   4064     # This may be called from a thread where name_stack doesn't yet exist.
   4065     if not hasattr(self._thread_local, "_name_stack"):
   4066       self._thread_local._name_stack = ""
   4067     return self._thread_local._name_stack
   4068 
   4069   @_name_stack.setter
   4070   def _name_stack(self, name_stack):
   4071     self._thread_local._name_stack = name_stack
   4072 
   4073   # pylint: disable=g-doc-return-or-yield,line-too-long
   4074   @tf_contextlib.contextmanager
   4075   def name_scope(self, name):
   4076     r"""Returns a context manager that creates hierarchical names for operations.
   4077 
   4078     A graph maintains a stack of name scopes. A `with name_scope(...):`
   4079     statement pushes a new name onto the stack for the lifetime of the context.
   4080 
   4081     The `name` argument will be interpreted as follows:
   4082 
   4083     * A string (not ending with '/') will create a new name scope, in which
   4084       `name` is appended to the prefix of all operations created in the
   4085       context. If `name` has been used before, it will be made unique by
   4086       calling `self.unique_name(name)`.
   4087     * A scope previously captured from a `with g.name_scope(...) as
   4088       scope:` statement will be treated as an "absolute" name scope, which
   4089       makes it possible to re-enter existing scopes.
   4090     * A value of `None` or the empty string will reset the current name scope
   4091       to the top-level (empty) name scope.
   4092 
   4093     For example:
   4094 
   4095     ```python
   4096     with tf.Graph().as_default() as g:
   4097       c = tf.constant(5.0, name="c")
   4098       assert c.op.name == "c"
   4099       c_1 = tf.constant(6.0, name="c")
   4100       assert c_1.op.name == "c_1"
   4101 
   4102       # Creates a scope called "nested"
   4103       with g.name_scope("nested") as scope:
   4104         nested_c = tf.constant(10.0, name="c")
   4105         assert nested_c.op.name == "nested/c"
   4106 
   4107         # Creates a nested scope called "inner".
   4108         with g.name_scope("inner"):
   4109           nested_inner_c = tf.constant(20.0, name="c")
   4110           assert nested_inner_c.op.name == "nested/inner/c"
   4111 
   4112         # Create a nested scope called "inner_1".
   4113         with g.name_scope("inner"):
   4114           nested_inner_1_c = tf.constant(30.0, name="c")
   4115           assert nested_inner_1_c.op.name == "nested/inner_1/c"
   4116 
   4117           # Treats `scope` as an absolute name scope, and
   4118           # switches to the "nested/" scope.
   4119           with g.name_scope(scope):
   4120             nested_d = tf.constant(40.0, name="d")
   4121             assert nested_d.op.name == "nested/d"
   4122 
   4123             with g.name_scope(""):
   4124               e = tf.constant(50.0, name="e")
   4125               assert e.op.name == "e"
   4126     ```
   4127 
   4128     The name of the scope itself can be captured by `with
   4129     g.name_scope(...) as scope:`, which stores the name of the scope
   4130     in the variable `scope`. This value can be used to name an
   4131     operation that represents the overall result of executing the ops
   4132     in a scope. For example:
   4133 
   4134     ```python
   4135     inputs = tf.constant(...)
   4136     with g.name_scope('my_layer') as scope:
   4137       weights = tf.Variable(..., name="weights")
   4138       biases = tf.Variable(..., name="biases")
   4139       affine = tf.matmul(inputs, weights) + biases
   4140       output = tf.nn.relu(affine, name=scope)
   4141     ```
   4142 
   4143     NOTE: This constructor validates the given `name`. Valid scope
   4144     names match one of the following regular expressions:
   4145 
   4146         [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root)
   4147         [A-Za-z0-9_.\\-/]* (for other scopes)
   4148 
   4149     Args:
   4150       name: A name for the scope.
   4151 
   4152     Returns:
   4153       A context manager that installs `name` as a new name scope.
   4154 
   4155     Raises:
   4156       ValueError: If `name` is not a valid scope name, according to the rules
   4157         above.
   4158     """
   4159     if name:
   4160       if isinstance(name, compat.bytes_or_text_types):
   4161         name = compat.as_str(name)
   4162 
   4163       if self._name_stack:
   4164         # Scopes created in a nested scope may have initial characters
   4165         # that are illegal as the initial character of an op name
   4166         # (viz. '-', '\', '/', and '_').
   4167         if not _VALID_SCOPE_NAME_REGEX.match(name):
   4168           raise ValueError("'%s' is not a valid scope name" % name)
   4169       else:
   4170         # Scopes created in the root must match the more restrictive
   4171         # op name regex, which constrains the initial character.
   4172         if not _VALID_OP_NAME_REGEX.match(name):
   4173           raise ValueError("'%s' is not a valid scope name" % name)
   4174     old_stack = self._name_stack
   4175     if not name:  # Both for name=None and name="" we re-set to empty scope.
   4176       new_stack = None
   4177     elif name[-1] == "/":
   4178       new_stack = _name_from_scope_name(name)
   4179     else:
   4180       new_stack = self.unique_name(name)
   4181     self._name_stack = new_stack
   4182     try:
   4183       yield "" if new_stack is None else new_stack + "/"
   4184     finally:
   4185       self._name_stack = old_stack
   4186 
   4187   # pylint: enable=g-doc-return-or-yield,line-too-long
   4188 
   4189   def unique_name(self, name, mark_as_used=True):
   4190     """Return a unique operation name for `name`.
   4191 
   4192     Note: You rarely need to call `unique_name()` directly.  Most of
   4193     the time you just need to create `with g.name_scope()` blocks to
   4194     generate structured names.
   4195 
   4196     `unique_name` is used to generate structured names, separated by
   4197     `"/"`, to help identify operations when debugging a graph.
   4198     Operation names are displayed in error messages reported by the
   4199     TensorFlow runtime, and in various visualization tools such as
   4200     TensorBoard.
   4201 
   4202     If `mark_as_used` is set to `True`, which is the default, a new
   4203     unique name is created and marked as in use. If it's set to `False`,
   4204     the unique name is returned without actually being marked as used.
   4205     This is useful when the caller simply wants to know what the name
   4206     to be created will be.
   4207 
   4208     Args:
   4209       name: The name for an operation.
   4210       mark_as_used: Whether to mark this name as being used.
   4211 
   4212     Returns:
   4213       A string to be passed to `create_op()` that will be used
   4214       to name the operation being created.
   4215     """
   4216     if self._name_stack:
   4217       name = self._name_stack + "/" + name
   4218 
   4219     # For the sake of checking for names in use, we treat names as case
   4220     # insensitive (e.g. foo = Foo).
   4221     name_key = name.lower()
   4222     i = self._names_in_use.get(name_key, 0)
   4223     # Increment the number for "name_key".
   4224     if mark_as_used:
   4225       self._names_in_use[name_key] = i + 1
   4226     if i > 0:
   4227       base_name_key = name_key
   4228       # Make sure the composed name key is not already used.
   4229       while name_key in self._names_in_use:
   4230         name_key = "%s_%d" % (base_name_key, i)
   4231         i += 1
   4232       # Mark the composed name_key as used in case someone wants
   4233       # to call unique_name("name_1").
   4234       if mark_as_used:
   4235         self._names_in_use[name_key] = 1
   4236 
   4237       # Return the new name with the original capitalization of the given name.
   4238       name = "%s_%d" % (name, i-1)
   4239     return name
   4240 
   4241   def get_name_scope(self):
   4242     """Returns the current name scope.
   4243 
   4244     For example:
   4245 
   4246     ```python
   4247     with tf.name_scope('scope1'):
   4248       with tf.name_scope('scope2'):
   4249         print(tf.get_default_graph().get_name_scope())
   4250     ```
   4251     would print the string `scope1/scope2`.
   4252 
   4253     Returns:
   4254       A string representing the current name scope.
   4255     """
   4256     return self._name_stack
   4257 
   4258   @tf_contextlib.contextmanager
   4259   def _colocate_with_for_gradient(self, op, gradient_uid,
   4260                                   ignore_existing=False):
   4261     with self.colocate_with(op, ignore_existing):
   4262       if gradient_uid is not None and self._control_flow_context is not None:
   4263         self._control_flow_context.EnterGradientColocation(op, gradient_uid)
   4264         try:
   4265           yield
   4266         finally:
   4267           self._control_flow_context.ExitGradientColocation(op, gradient_uid)
   4268       else:
   4269         yield
   4270 
  @tf_contextlib.contextmanager
  def colocate_with(self, op, ignore_existing=False):
    """Returns a context manager that specifies an op to colocate with.

    Note: this function is not for public use, only for internal libraries.

    For example:

    ```python
    a = tf.Variable([1.0])
    with g.colocate_with(a):
      b = tf.constant(1.0)
      c = tf.add(a, b)
    ```

    `b` and `c` will always be colocated with `a`, no matter where `a`
    is eventually placed.

    **NOTE** Using a colocation scope resets any existing device constraints:
    the device function stack is swapped for an empty one for the duration
    of the scope and restored when the scope exits.

    If `op` is `None` then `ignore_existing` must be `True` and the new
    scope resets all colocation and device constraints.

    Args:
      op: The op to colocate all created ops with, or `None`.
      ignore_existing: If true, only applies colocation of this op within
        the context, rather than applying all colocation properties
        on the stack.  If `op` is `None`, this value must be `True`.
        This is done by swapping in an empty colocation stack for the
        duration of the scope and restoring the saved one on exit.

    Raises:
      ValueError: if op is None but ignore_existing is False.

    Yields:
      A context manager that specifies the op with which to colocate
      newly created ops.
    """
    if op is None and not ignore_existing:
      raise ValueError("Trying to reset colocation (op is None) but "
                       "ignore_existing is not True")
    # Normalize `op` with the module-level helper. The result is what gets
    # pushed onto the colocation stack below (presumably still None when
    # `op` is None -- confirm in `_op_to_colocate_with`).
    op = _op_to_colocate_with(op)

    # By default, colocate_with resets the device function stack,
    # since colocate_with is typically used in specific internal
    # library functions where colocation is intended to be "stronger"
    # than device functions.
    #
    # In the future, a caller may specify that device_functions win
    # over colocation, in which case we can add support.
    device_fn_tmp = self._device_function_stack
    self._device_function_stack = traceable_stack.TraceableStack()

    if ignore_existing:
      # `current_stack` is only bound on this path. The `finally` below
      # reads it under the same condition.
      current_stack = self._colocation_stack
      self._colocation_stack = traceable_stack.TraceableStack()

    if op is not None:
      # offset refers to the stack frame used for storing code location.
      # We use 4, the sum of 1 to use our caller's stack frame and 3
      # to jump over layers of context managers above us.
      # NOTE(review): this push (and the stack swaps above) happen before
      # the `try`. If `push_obj` raised, the device function stack would
      # not be restored.
      self._colocation_stack.push_obj(op, offset=4)

    try:
      yield
    finally:
      # Restore device function stack
      self._device_function_stack = device_fn_tmp
      # Pop before restoring the saved colocation stack. When
      # `ignore_existing` is set, the push above went onto the temporary
      # stack, so the pop must target that same stack.
      if op is not None:
        self._colocation_stack.pop_obj()

      # Reset the colocation stack if requested.
      if ignore_existing:
        self._colocation_stack = current_stack
   4343 
   4344   def _add_device_to_stack(self, device_name_or_function, offset=0):
   4345     """Add device to stack manually, separate from a context manager."""
   4346     total_offset = 1 + offset
   4347     spec = _UserDeviceSpec(device_name_or_function)
   4348     self._device_function_stack.push_obj(spec, offset=total_offset)
   4349     return spec
   4350 
   4351   @tf_contextlib.contextmanager
  def device(self, device_name_or_function):
    # pylint: disable=line-too-long
    """Returns a context manager that specifies the default device to use.

    The `device_name_or_function` argument may either be a device name
    string, a device function, or None:

    * If it is a device name string, all operations constructed in
      this context will be assigned to the device with that name, unless
      overridden by a nested `device()` context.
    * If it is a function, it will be treated as a function from
      Operation objects to device name strings, and invoked each time
      a new Operation is created. The Operation will be assigned to
      the device with the returned name.
    * If it is None, all `device()` invocations from the enclosing context
      will be ignored.

    For information about the valid syntax of device name strings, see
    the documentation in
    [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h).

    For example:

    ```python
    with g.device('/device:GPU:0'):
      # All operations constructed in this context will be placed
      # on GPU 0.
      with g.device(None):
        # All operations constructed in this context will have no
        # assigned device.

    # Defines a function from `Operation` to device string.
    def matmul_on_gpu(n):
      if n.type == "MatMul":
        return "/device:GPU:0"
      else:
        return "/cpu:0"

    with g.device(matmul_on_gpu):
      # All operations of type "MatMul" constructed in this context
      # will be placed on GPU 0; all other operations will be placed
      # on CPU 0.
    ```

    **N.B.** The device scope may be overridden by op wrappers or
    other library code. For example, a variable assignment op
    `v.assign()` must be colocated with the `tf.Variable` `v`, and
    incompatible device scopes will be ignored.

    Args:
      device_name_or_function: The device name or function to use in
        the context.

    Yields:
      Nothing (the yielded value is None). Used through the decorator, this
      method returns a context manager that specifies the default device to
      use for newly created ops.
    """
    # NOTE(review): `offset=2` presumably tells the device stack how many
    # call frames to skip, so the recorded code location points at the user's
    # call site rather than at this helper. Confirm against
    # `_add_device_to_stack` before restructuring this method.
    self._add_device_to_stack(device_name_or_function, offset=2)
    try:
      yield
    finally:
      # Pop the entry pushed above, even if the body of the `with` raised.
      self._device_function_stack.pop_obj()
   4414 
   4415   def _apply_device_functions(self, op):
   4416     """Applies the current device function stack to the given operation."""
   4417     # Apply any device functions in LIFO order, so that the most recently
   4418     # pushed function has the first chance to apply a device to the op.
   4419     # We apply here because the result can depend on the Operation's
   4420     # signature, which is computed in the Operation constructor.
   4421     # pylint: disable=protected-access
   4422     for device_spec in self._device_function_stack.peek_objs():
   4423       if device_spec.function is None:
   4424         break
   4425       op._set_device(device_spec.function(op))
   4426     op._device_code_locations = self._snapshot_device_function_stack_metadata()
   4427     # pylint: enable=protected-access
   4428 
   4429   # pylint: disable=g-doc-return-or-yield
   4430   @tf_contextlib.contextmanager
   4431   def container(self, container_name):
   4432     """Returns a context manager that specifies the resource container to use.
   4433 
   4434     Stateful operations, such as variables and queues, can maintain their
   4435     states on devices so that they can be shared by multiple processes.
   4436     A resource container is a string name under which these stateful
   4437     operations are tracked. These resources can be released or cleared
   4438     with `tf.Session.reset()`.
   4439 
   4440     For example:
   4441 
   4442     ```python
   4443     with g.container('experiment0'):
   4444       # All stateful Operations constructed in this context will be placed
   4445       # in resource container "experiment0".
   4446       v1 = tf.Variable([1.0])
   4447       v2 = tf.Variable([2.0])
   4448       with g.container("experiment1"):
   4449         # All stateful Operations constructed in this context will be
   4450         # placed in resource container "experiment1".
   4451         v3 = tf.Variable([3.0])
   4452         q1 = tf.FIFOQueue(10, tf.float32)
   4453       # All stateful Operations constructed in this context will be
   4454       # be created in the "experiment0".
   4455       v4 = tf.Variable([4.0])
   4456       q1 = tf.FIFOQueue(20, tf.float32)
   4457       with g.container(""):
   4458         # All stateful Operations constructed in this context will be
   4459         # be placed in the default resource container.
   4460         v5 = tf.Variable([5.0])
   4461         q3 = tf.FIFOQueue(30, tf.float32)
   4462 
   4463     # Resets container "experiment0", after which the state of v1, v2, v4, q1
   4464     # will become undefined (such as uninitialized).
   4465     tf.Session.reset(target, ["experiment0"])
   4466     ```
   4467 
   4468     Args:
   4469       container_name: container name string.
   4470 
   4471     Returns:
   4472       A context manager for defining resource containers for stateful ops,
   4473         yields the container name.
   4474     """
   4475     original_container = self._container
   4476     self._container = container_name
   4477     try:
   4478       yield self._container
   4479     finally:
   4480       self._container = original_container
   4481 
   4482   # pylint: enable=g-doc-return-or-yield
   4483 
   4484   class _ControlDependenciesController(object):
   4485     """Context manager for `control_dependencies()`."""
   4486 
   4487     def __init__(self, graph, control_inputs):
   4488       """Create a new `_ControlDependenciesController`.
   4489 
   4490       A `_ControlDependenciesController` is the context manager for
   4491       `with tf.control_dependencies()` blocks.  These normally nest,
   4492       as described in the documentation for `control_dependencies()`.
   4493 
   4494       The `control_inputs` argument list control dependencies that must be
   4495       added to the current set of control dependencies.  Because of
   4496       uniquification the set can be empty even if the caller passed a list of
   4497       ops.  The special value `None` indicates that we want to start a new
   4498       empty set of control dependencies instead of extending the current set.
   4499 
   4500       In that case we also clear the current control flow context, which is an
   4501       additional mechanism to add control dependencies.
   4502 
   4503       Args:
   4504         graph: The graph that this controller is managing.
   4505         control_inputs: List of ops to use as control inputs in addition
   4506           to the current control dependencies.  None to indicate that
   4507           the dependencies should be cleared.
   4508       """
   4509       self._graph = graph
   4510       if control_inputs is None:
   4511         self._control_inputs_val = []
   4512         self._new_stack = True
   4513       else:
   4514         self._control_inputs_val = control_inputs
   4515         self._new_stack = False
   4516       self._seen_nodes = set()
   4517       self._old_stack = None
   4518       self._old_control_flow_context = None
   4519 
   4520 # pylint: disable=protected-access
   4521 
   4522     def __enter__(self):
   4523       if self._new_stack:
   4524         # Clear the control_dependencies graph.
   4525         self._old_stack = self._graph._control_dependencies_stack
   4526         self._graph._control_dependencies_stack = []
   4527         # Clear the control_flow_context too.
   4528         self._old_control_flow_context = self._graph._get_control_flow_context()
   4529         self._graph._set_control_flow_context(None)
   4530       self._graph._push_control_dependencies_controller(self)
   4531 
   4532     def __exit__(self, unused_type, unused_value, unused_traceback):
   4533       self._graph._pop_control_dependencies_controller(self)
   4534       if self._new_stack:
   4535         self._graph._control_dependencies_stack = self._old_stack
   4536         self._graph._set_control_flow_context(self._old_control_flow_context)
   4537 
   4538 # pylint: enable=protected-access
   4539 
   4540     @property
   4541     def control_inputs(self):
   4542       return self._control_inputs_val
   4543 
   4544     def add_op(self, op):
   4545       self._seen_nodes.add(op)
   4546 
   4547     def op_in_group(self, op):
   4548       return op in self._seen_nodes
   4549 
   4550   def _push_control_dependencies_controller(self, controller):
   4551     self._control_dependencies_stack.append(controller)
   4552 
   4553   def _pop_control_dependencies_controller(self, controller):
   4554     assert self._control_dependencies_stack[-1] is controller
   4555     self._control_dependencies_stack.pop()
   4556 
   4557   def _current_control_dependencies(self):
   4558     ret = set()
   4559     for controller in self._control_dependencies_stack:
   4560       for op in controller.control_inputs:
   4561         ret.add(op)
   4562     return ret
   4563 
   4564   def _control_dependencies_for_inputs(self, input_ops):
   4565     """For an op that takes `input_ops` as inputs, compute control inputs.
   4566 
   4567     The returned control dependencies should yield an execution that
   4568     is equivalent to adding all control inputs in
   4569     self._control_dependencies_stack to a newly created op. However,
   4570     this function attempts to prune the returned control dependencies
   4571     by observing that nodes created within the same `with
   4572     control_dependencies(...):` block may have data dependencies that make
   4573     the explicit approach redundant.
   4574 
   4575     Args:
   4576       input_ops: The data input ops for an op to be created.
   4577 
   4578     Returns:
   4579       A list of control inputs for the op to be created.
   4580     """
   4581     ret = []
   4582     for controller in self._control_dependencies_stack:
   4583       # If any of the input_ops already depends on the inputs from controller,
   4584       # we say that the new op is dominated (by that input), and we therefore
   4585       # do not need to add control dependencies for this controller's inputs.
   4586       dominated = False
   4587       for op in input_ops:
   4588         if controller.op_in_group(op):
   4589           dominated = True
   4590           break
   4591       if not dominated:
   4592         # Don't add a control input if we already have a data dependency on i.
   4593         # NOTE(mrry): We do not currently track transitive data dependencies,
   4594         #   so we may add redundant control inputs.
   4595         ret.extend([c for c in controller.control_inputs if c not in input_ops])
   4596     return ret
   4597 
   4598   def _record_op_seen_by_control_dependencies(self, op):
   4599     """Record that the given op depends on all registered control dependencies.
   4600 
   4601     Args:
   4602       op: An Operation.
   4603     """
   4604     for controller in self._control_dependencies_stack:
   4605       controller.add_op(op)
   4606 
   4607   def control_dependencies(self, control_inputs):
   4608     """Returns a context manager that specifies control dependencies.
   4609 
   4610     Use with the `with` keyword to specify that all operations constructed
   4611     within the context should have control dependencies on
   4612     `control_inputs`. For example:
   4613 
   4614     ```python
   4615     with g.control_dependencies([a, b, c]):
   4616       # `d` and `e` will only run after `a`, `b`, and `c` have executed.
   4617       d = ...
   4618       e = ...
   4619     ```
   4620 
   4621     Multiple calls to `control_dependencies()` can be nested, and in
   4622     that case a new `Operation` will have control dependencies on the union
   4623     of `control_inputs` from all active contexts.
   4624 
   4625     ```python
   4626     with g.control_dependencies([a, b]):
   4627       # Ops constructed here run after `a` and `b`.
   4628       with g.control_dependencies([c, d]):
   4629         # Ops constructed here run after `a`, `b`, `c`, and `d`.
   4630     ```
   4631 
   4632     You can pass None to clear the control dependencies:
   4633 
   4634     ```python
   4635     with g.control_dependencies([a, b]):
   4636       # Ops constructed here run after `a` and `b`.
   4637       with g.control_dependencies(None):
   4638         # Ops constructed here run normally, not waiting for either `a` or `b`.
   4639         with g.control_dependencies([c, d]):
   4640           # Ops constructed here run after `c` and `d`, also not waiting
   4641           # for either `a` or `b`.
   4642     ```
   4643 
   4644     *N.B.* The control dependencies context applies *only* to ops that
   4645     are constructed within the context. Merely using an op or tensor
   4646     in the context does not add a control dependency. The following
   4647     example illustrates this point:
   4648 
   4649     ```python
   4650     # WRONG
   4651     def my_func(pred, tensor):
   4652       t = tf.matmul(tensor, tensor)
   4653       with tf.control_dependencies([pred]):
   4654         # The matmul op is created outside the context, so no control
   4655         # dependency will be added.
   4656         return t
   4657 
   4658     # RIGHT
   4659     def my_func(pred, tensor):
   4660       with tf.control_dependencies([pred]):
   4661         # The matmul op is created in the context, so a control dependency
   4662         # will be added.
   4663         return tf.matmul(tensor, tensor)
   4664     ```
   4665 
   4666     Also note that though execution of ops created under this scope will trigger
   4667     execution of the dependencies, the ops created under this scope might still
   4668     be pruned from a normal tensorflow graph. For example, in the following
   4669     snippet of code the dependencies are never executed:
   4670 
   4671     ```python
   4672       loss = model.loss()
   4673       with tf.control_dependencies(dependencies):
   4674         loss = loss + tf.constant(1)  # note: dependencies ignored in the
   4675                                       # backward pass
   4676       return tf.gradients(loss, model.variables)
   4677     ```
   4678 
   4679     This is because evaluating the gradient graph does not require evaluating
   4680     the constant(1) op created in the forward pass.
   4681 
   4682     Args:
   4683       control_inputs: A list of `Operation` or `Tensor` objects which
   4684         must be executed or computed before running the operations
   4685         defined in the context.  Can also be `None` to clear the control
   4686         dependencies.
   4687 
   4688     Returns:
   4689      A context manager that specifies control dependencies for all
   4690      operations constructed within the context.
   4691 
   4692     Raises:
   4693       TypeError: If `control_inputs` is not a list of `Operation` or
   4694         `Tensor` objects.
   4695     """
   4696     if control_inputs is None:
   4697       return self._ControlDependenciesController(self, None)
   4698     # First convert the inputs to ops, and deduplicate them.
   4699     # NOTE(mrry): Other than deduplication, we do not currently track direct
   4700     #   or indirect dependencies between control_inputs, which may result in
   4701     #   redundant control inputs.
   4702     control_ops = []
   4703     current = self._current_control_dependencies()
   4704     for c in control_inputs:
   4705       # The hasattr(handle) is designed to match ResourceVariables. This is so
   4706       # control dependencies on a variable or on an unread variable don't
   4707       # trigger reads.
   4708       if (isinstance(c, IndexedSlices) or
   4709           (hasattr(c, "_handle") and hasattr(c, "op"))):
   4710         c = c.op
   4711       c = self.as_graph_element(c)
   4712       if isinstance(c, Tensor):
   4713         c = c.op
   4714       elif not isinstance(c, Operation):
   4715         raise TypeError("Control input must be Operation or Tensor: %s" % c)
   4716       if c not in current:
   4717         control_ops.append(c)
   4718         current.add(c)
   4719     return self._ControlDependenciesController(self, control_ops)
   4720 
   4721   # pylint: disable=g-doc-return-or-yield
   4722   @tf_contextlib.contextmanager
   4723   def _attr_scope(self, attr_map):
   4724     """EXPERIMENTAL: A context manager for setting attributes on operators.
   4725 
   4726     This context manager can be used to add additional
   4727     attributes to operators within the scope of the context.
   4728 
   4729     For example:
   4730 
   4731        with ops.Graph().as_default() as g:
   4732          f_1 = Foo()  # No extra attributes
   4733          with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}):
   4734            f_2 = Foo()  # Additional attribute _a=False
   4735            with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}):
   4736              f_3 = Foo()  # Additional attribute _a=False
   4737              with g._attr_scope({"_a": None}):
   4738                f_4 = Foo()  # No additional attributes.
   4739 
   4740     Args:
   4741       attr_map: A dictionary mapping attr name strings to
   4742         AttrValue protocol buffers or None.
   4743 
   4744     Returns:
   4745       A context manager that sets the kernel label to be used for one or more
   4746       ops created in that context.
   4747 
   4748     Raises:
   4749       TypeError: If attr_map is not a dictionary mapping
   4750         strings to AttrValue protobufs.
   4751     """
   4752     if not isinstance(attr_map, dict):
   4753       raise TypeError("attr_map must be a dictionary mapping "
   4754                       "strings to AttrValue protocol buffers")
   4755     # The saved_attrs dictionary stores any currently-set labels that
   4756     # will be overridden by this context manager.
   4757     saved_attrs = {}
   4758     # Install the given attribute
   4759     for name, attr in attr_map.items():
   4760       if not (isinstance(name, six.string_types) and
   4761               (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or
   4762                callable(attr))):
   4763         raise TypeError("attr_map must be a dictionary mapping "
   4764                         "strings to AttrValue protocol buffers or "
   4765                         "callables that emit AttrValue protocol buffers")
   4766       try:
   4767         saved_attrs[name] = self._attr_scope_map[name]
   4768       except KeyError:
   4769         pass
   4770       if attr is None:
   4771         del self._attr_scope_map[name]
   4772       else:
   4773         self._attr_scope_map[name] = attr
   4774     try:
   4775       yield  # The code within the context runs here.
   4776     finally:
   4777       # Remove the attributes set for this context, and restore any saved
   4778       # attributes.
   4779       for name, attr in attr_map.items():
   4780         try:
   4781           self._attr_scope_map[name] = saved_attrs[name]
   4782         except KeyError:
   4783           del self._attr_scope_map[name]
   4784 
   4785   # pylint: enable=g-doc-return-or-yield
   4786 
   4787   # pylint: disable=g-doc-return-or-yield
   4788   @tf_contextlib.contextmanager
   4789   def _kernel_label_map(self, op_to_kernel_label_map):
   4790     """EXPERIMENTAL: A context manager for setting kernel labels.
   4791 
   4792     This context manager can be used to select particular
   4793     implementations of kernels within the scope of the context.
   4794 
   4795     For example:
   4796 
   4797         with ops.Graph().as_default() as g:
   4798           f_1 = Foo()  # Uses the default registered kernel for the Foo op.
   4799           with g.kernel_label_map({"Foo": "v_2"}):
   4800             f_2 = Foo()  # Uses the registered kernel with label "v_2"
   4801                          # for the Foo op.
   4802             with g.kernel_label_map({"Foo": "v_3"}):
   4803               f_3 = Foo()  # Uses the registered kernel with label "v_3"
   4804                            # for the Foo op.
   4805               with g.kernel_label_map({"Foo": ""}):
   4806                 f_4 = Foo()  # Uses the default registered kernel
   4807                              # for the Foo op.
   4808 
   4809     Args:
   4810       op_to_kernel_label_map: A dictionary mapping op type strings to
   4811         kernel label strings.
   4812 
   4813     Returns:
   4814       A context manager that sets the kernel label to be used for one or more
   4815       ops created in that context.
   4816 
   4817     Raises:
   4818       TypeError: If op_to_kernel_label_map is not a dictionary mapping
   4819         strings to strings.
   4820     """
   4821     if not isinstance(op_to_kernel_label_map, dict):
   4822       raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
   4823                       "strings to strings")
   4824     # The saved_labels dictionary stores any currently-set labels that
   4825     # will be overridden by this context manager.
   4826     saved_labels = {}
   4827     # Install the given label
   4828     for op_type, label in op_to_kernel_label_map.items():
   4829       if not (isinstance(op_type, six.string_types) and
   4830               isinstance(label, six.string_types)):
   4831         raise TypeError("op_to_kernel_label_map must be a dictionary mapping "
   4832                         "strings to strings")
   4833       try:
   4834         saved_labels[op_type] = self._op_to_kernel_label_map[op_type]
   4835       except KeyError:
   4836         pass
   4837       self._op_to_kernel_label_map[op_type] = label
   4838     try:
   4839       yield  # The code within the context runs here.
   4840     finally:
   4841       # Remove the labels set for this context, and restore any saved labels.
   4842       for op_type, label in op_to_kernel_label_map.items():
   4843         try:
   4844           self._op_to_kernel_label_map[op_type] = saved_labels[op_type]
   4845         except KeyError:
   4846           del self._op_to_kernel_label_map[op_type]
   4847 
   4848   # pylint: enable=g-doc-return-or-yield
   4849 
   4850   # pylint: disable=g-doc-return-or-yield
   4851   @tf_contextlib.contextmanager
   4852   def gradient_override_map(self, op_type_map):
   4853     """EXPERIMENTAL: A context manager for overriding gradient functions.
   4854 
   4855     This context manager can be used to override the gradient function
   4856     that will be used for ops within the scope of the context.
   4857 
   4858     For example:
   4859 
   4860     ```python
   4861     @tf.RegisterGradient("CustomSquare")
   4862     def _custom_square_grad(op, grad):
   4863       # ...
   4864 
   4865     with tf.Graph().as_default() as g:
   4866       c = tf.constant(5.0)
   4867       s_1 = tf.square(c)  # Uses the default gradient for tf.square.
   4868       with g.gradient_override_map({"Square": "CustomSquare"}):
   4869         s_2 = tf.square(s_2)  # Uses _custom_square_grad to compute the
   4870                               # gradient of s_2.
   4871     ```
   4872 
   4873     Args:
   4874       op_type_map: A dictionary mapping op type strings to alternative op
   4875         type strings.
   4876 
   4877     Returns:
   4878       A context manager that sets the alternative op type to be used for one
   4879       or more ops created in that context.
   4880 
   4881     Raises:
   4882       TypeError: If `op_type_map` is not a dictionary mapping strings to
   4883         strings.
   4884     """
   4885     if not isinstance(op_type_map, dict):
   4886       raise TypeError("op_type_map must be a dictionary mapping "
   4887                       "strings to strings")
   4888     # The saved_mappings dictionary stores any currently-set mappings that
   4889     # will be overridden by this context manager.
   4890     saved_mappings = {}
   4891     # Install the given label
   4892     for op_type, mapped_op_type in op_type_map.items():
   4893       if not (isinstance(op_type, six.string_types) and
   4894               isinstance(mapped_op_type, six.string_types)):
   4895         raise TypeError("op_type_map must be a dictionary mapping "
   4896                         "strings to strings")
   4897       try:
   4898         saved_mappings[op_type] = self._gradient_override_map[op_type]
   4899       except KeyError:
   4900         pass
   4901       self._gradient_override_map[op_type] = mapped_op_type
   4902     try:
   4903       yield  # The code within the context runs here.
   4904     finally:
   4905       # Remove the labels set for this context, and restore any saved labels.
   4906       for op_type, mapped_op_type in op_type_map.items():
   4907         try:
   4908           self._gradient_override_map[op_type] = saved_mappings[op_type]
   4909         except KeyError:
   4910           del self._gradient_override_map[op_type]
   4911 
   4912   # pylint: enable=g-doc-return-or-yield
   4913 
   4914   def prevent_feeding(self, tensor):
   4915     """Marks the given `tensor` as unfeedable in this graph."""
   4916     self._unfeedable_tensors.add(tensor)
   4917 
   4918   def is_feedable(self, tensor):
   4919     """Returns `True` if and only if `tensor` is feedable."""
   4920     return tensor not in self._unfeedable_tensors
   4921 
   4922   def prevent_fetching(self, op):
   4923     """Marks the given `op` as unfetchable in this graph."""
   4924     self._unfetchable_ops.add(op)
   4925 
   4926   def is_fetchable(self, tensor_or_op):
   4927     """Returns `True` if and only if `tensor_or_op` is fetchable."""
   4928     if isinstance(tensor_or_op, Tensor):
   4929       return tensor_or_op.op not in self._unfetchable_ops
   4930     else:
   4931       return tensor_or_op not in self._unfetchable_ops
   4932 
   4933   def switch_to_thread_local(self):
   4934     """Make device, colocation and dependencies stacks thread-local.
   4935 
   4936     Device, colocation and dependencies stacks are not thread-local be default.
   4937     If multiple threads access them, then the state is shared.  This means that
   4938     one thread may affect the behavior of another thread.
   4939 
   4940     After this method is called, the stacks become thread-local.  If multiple
   4941     threads access them, then the state is not shared.  Each thread uses its own
   4942     value; a thread doesn't affect other threads by mutating such a stack.
   4943 
   4944     The initial value for every thread's stack is set to the current value
   4945     of the stack when `switch_to_thread_local()` was first called.
   4946     """
   4947     if not self._stack_state_is_thread_local:
   4948       self._stack_state_is_thread_local = True
   4949 
   4950   @property
   4951   def _device_function_stack(self):
   4952     if self._stack_state_is_thread_local:
   4953       # This may be called from a thread where device_function_stack doesn't yet
   4954       # exist.
   4955       # pylint: disable=protected-access
   4956       if not hasattr(self._thread_local, "_device_function_stack"):
   4957         stack_copy_for_this_thread = self._graph_device_function_stack.copy()
   4958         self._thread_local._device_function_stack = stack_copy_for_this_thread
   4959       return self._thread_local._device_function_stack
   4960       # pylint: enable=protected-access
   4961     else:
   4962       return self._graph_device_function_stack
   4963 
   4964   @property
   4965   def _device_functions_outer_to_inner(self):
   4966     user_device_specs = self._device_function_stack.peek_objs()
   4967     device_functions = [spec.function for spec in user_device_specs]
   4968     device_functions_outer_to_inner = list(reversed(device_functions))
   4969     return device_functions_outer_to_inner
   4970 
   4971   def _snapshot_device_function_stack_metadata(self):
   4972     """Return device function stack as a list of TraceableObjects.
   4973 
   4974     Returns:
   4975       [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj
   4976       member is a displayable name for the user's argument to Graph.device, and
   4977       the filename and lineno members point to the code location where
   4978       Graph.device was called directly or indirectly by the user.
   4979     """
   4980     traceable_objects = self._device_function_stack.peek_traceable_objs()
   4981     snapshot = []
   4982     for obj in traceable_objects:
   4983       obj_copy = obj.copy_metadata()
   4984       obj_copy.obj = obj.obj.display_name
   4985       snapshot.append(obj_copy)
   4986     return snapshot
   4987 
   4988   @_device_function_stack.setter
   4989   def _device_function_stack(self, device_function_stack):
   4990     if self._stack_state_is_thread_local:
   4991       # pylint: disable=protected-access
   4992       self._thread_local._device_function_stack = device_function_stack
   4993       # pylint: enable=protected-access
   4994     else:
   4995       self._graph_device_function_stack = device_function_stack
   4996 
   4997   @property
   4998   def _colocation_stack(self):
   4999     """Return thread-local copy of colocation stack."""
   5000     if self._stack_state_is_thread_local:
   5001       # This may be called from a thread where colocation_stack doesn't yet
   5002       # exist.
   5003       # pylint: disable=protected-access
   5004       if not hasattr(self._thread_local, "_colocation_stack"):
   5005         stack_copy_for_this_thread = self._graph_colocation_stack.copy()
   5006         self._thread_local._colocation_stack = stack_copy_for_this_thread
   5007       return self._thread_local._colocation_stack
   5008       # pylint: enable=protected-access
   5009     else:
   5010       return self._graph_colocation_stack
   5011 
   5012   def _snapshot_colocation_stack_metadata(self):
   5013     """Return colocation stack metadata as a dictionary."""
   5014     traceable_objects = self._colocation_stack.peek_traceable_objs()
   5015     return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
   5016 
   5017   @_colocation_stack.setter
   5018   def _colocation_stack(self, colocation_stack):
   5019     if self._stack_state_is_thread_local:
   5020       # pylint: disable=protected-access
   5021       self._thread_local._colocation_stack = colocation_stack
   5022       # pylint: enable=protected-access
   5023     else:
   5024       self._graph_colocation_stack = colocation_stack
   5025 
   5026   @property
   5027   def _control_dependencies_stack(self):
   5028     if self._stack_state_is_thread_local:
   5029       # This may be called from a thread where control_dependencies_stack
   5030       # doesn't yet exist.
   5031       if not hasattr(self._thread_local, "_control_dependencies_stack"):
   5032         self._thread_local._control_dependencies_stack = (
   5033             self._graph_control_dependencies_stack[:])
   5034       return self._thread_local._control_dependencies_stack
   5035     else:
   5036       return self._graph_control_dependencies_stack
   5037 
   5038   @_control_dependencies_stack.setter
   5039   def _control_dependencies_stack(self, control_dependencies):
   5040     if self._stack_state_is_thread_local:
   5041       self._thread_local._control_dependencies_stack = control_dependencies
   5042     else:
   5043       self._graph_control_dependencies_stack = control_dependencies
   5044 
   5045   @property
   5046   def _distribution_strategy_stack(self):
   5047     """A stack to maintain distribution strategy context for each thread."""
   5048     if not hasattr(self._thread_local, "_distribution_strategy_stack"):
   5049       self._thread_local._distribution_strategy_stack = []  # pylint: disable=protected-access
   5050     return self._thread_local._distribution_strategy_stack  # pylint: disable=protected-access
   5051 
   5052   @_distribution_strategy_stack.setter
   5053   def _distribution_strategy_stack(self, _distribution_strategy_stack):
   5054     self._thread_local._distribution_strategy_stack = (  # pylint: disable=protected-access
   5055         _distribution_strategy_stack)
   5056 
   5057   @property
   5058   def _auto_cast_variable_read_dtype(self):
   5059     """The dtype that instances of `AutoCastVariable` will be casted to.
   5060 
   5061     This is None if `AutoCastVariables` should not be casted.
   5062 
   5063     See `AutoCastVariable` for more information.
   5064 
   5065     Returns:
   5066       The dtype that instances of `AutoCastVariable` will be casted to.
   5067     """
   5068     if not hasattr(self._thread_local, "_auto_cast_variable_read_dtype"):
   5069       self._thread_local._auto_cast_variable_read_dtype = None  # pylint: disable=protected-access
   5070     return self._thread_local._auto_cast_variable_read_dtype  # pylint: disable=protected-access
   5071 
   5072   @_auto_cast_variable_read_dtype.setter
   5073   def _auto_cast_variable_read_dtype(self, _auto_cast_variable_read_dtype):
   5074     self._thread_local._auto_cast_variable_read_dtype = (  # pylint: disable=protected-access
   5075         _auto_cast_variable_read_dtype)
   5076 
   5077   @tf_contextlib.contextmanager
   5078   def _enable_auto_casting_variables(self, dtype):
   5079     """Context manager to automatically cast AutoCastVariables.
   5080 
   5081     If an AutoCastVariable `var` is used under this context manager, it will be
   5082     casted to `dtype` before being used.
   5083 
   5084     See `AutoCastVariable` for more information.
   5085 
   5086     Args:
   5087       dtype: The dtype that AutoCastVariables should be casted to.
   5088 
   5089     Yields:
   5090       Nothing.
   5091     """
   5092     prev_read_dtype = self._auto_cast_variable_read_dtype
   5093     try:
   5094       self._auto_cast_variable_read_dtype = dtype
   5095       yield
   5096     finally:
   5097       self._auto_cast_variable_read_dtype = prev_read_dtype
   5098 
   5099   def _mutation_lock(self):
   5100     """Returns a lock to guard code that creates & mutates ops.
   5101 
   5102     See the comment for self._group_lock for more info.
   5103     """
   5104     return self._group_lock.group(_MUTATION_LOCK_GROUP)
   5105 
   5106   def _session_run_lock(self):
   5107     """Returns a lock to guard code for Session.run.
   5108 
   5109     See the comment for self._group_lock for more info.
   5110     """
   5111     return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
   5112 
   5113 
   5114 # TODO(agarwal): currently device directives in an outer eager scope will not
   5115 # apply to inner graph mode code. Fix that.
   5116 
   5117 
   5118 @tf_export(v1=["device"])
   5119 def device(device_name_or_function):
   5120   """Wrapper for `Graph.device()` using the default graph.
   5121 
   5122   See
   5123   `tf.Graph.device`
   5124   for more details.
   5125 
   5126   Args:
   5127     device_name_or_function: The device name or function to use in
   5128       the context.
   5129 
   5130   Returns:
   5131     A context manager that specifies the default device to use for newly
   5132     created ops.
   5133 
   5134   Raises:
   5135     RuntimeError: If eager execution is enabled and a function is passed in.
   5136   """
   5137   if context.executing_eagerly():
   5138     # TODO(agarwal): support device functions in EAGER mode.
   5139     if callable(device_name_or_function):
   5140       raise RuntimeError(
   5141           "tf.device does not support functions when eager execution "
   5142           "is enabled.")
   5143     return context.device(device_name_or_function)
   5144   else:
   5145     return get_default_graph().device(device_name_or_function)
   5146 
   5147 
   5148 @tf_export("device", v1=[])
   5149 def device_v2(device_name):
   5150   """Specifies the device for ops created/executed in this context.
   5151 
   5152   `device_name` can be fully specified, as in "/job:worker/task:1/device:cpu:0",
   5153   or partially specified, containing only a subset of the "/"-separated
   5154   fields. Any fields which are specified override device annotations from outer
   5155   scopes. For example:
   5156 
   5157   with tf.device('/job:foo'):
   5158     # ops created here have devices with /job:foo
   5159     with tf.device('/job:bar/task:0/device:gpu:2'):
   5160       # ops created here have the fully specified device above
   5161     with tf.device('/device:gpu:1'):
   5162       # ops created here have the device '/job:foo/device:gpu:1'
   5163 
   5164   Args:
   5165     device_name: The device name to use in the context.
   5166 
   5167   Returns:
   5168     A context manager that specifies the default device to use for newly
   5169     created ops.
   5170 
   5171   Raises:
   5172     RuntimeError: If a function is passed in.
   5173   """
   5174   if callable(device_name):
   5175     raise RuntimeError("tf.device does not support functions.")
   5176   if context.executing_eagerly():
   5177     return context.device(device_name)
   5178   else:
   5179     return get_default_graph().device(device_name)
   5180 
   5181 
   5182 @tf_export(v1=["container"])
   5183 def container(container_name):
   5184   """Wrapper for `Graph.container()` using the default graph.
   5185 
   5186   Args:
   5187     container_name: The container string to use in the context.
   5188 
   5189   Returns:
   5190     A context manager that specifies the default container to use for newly
   5191     created stateful ops.
   5192   """
   5193   return get_default_graph().container(container_name)
   5194 
   5195 
   5196 def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
   5197   if context.executing_eagerly():
   5198     if op is not None:
   5199       if not hasattr(op, "device"):
   5200         op = internal_convert_to_tensor_or_indexed_slices(op)
   5201       return device(op.device)
   5202     else:
   5203       return NullContextmanager()
   5204   else:
   5205     default_graph = get_default_graph()
   5206     if isinstance(op, EagerTensor):
   5207       if default_graph.building_function:
   5208         return default_graph.device(op.device)
   5209       else:
   5210         raise ValueError("Encountered an Eager-defined Tensor during graph "
   5211                          "construction, but a function was not being built.")
   5212     return default_graph._colocate_with_for_gradient(
   5213         op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
   5214 
   5215 
# Internal interface to colocate_with. colocate_with has been deprecated from
# the public API, but there are still a few internal uses of it; this
# internal-only entry point lets them avoid the deprecation warning.
   5219 def colocate_with(op, ignore_existing=False):
   5220   return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
   5221 
   5222 
   5223 @deprecation.deprecated(
   5224     date=None,
   5225     instructions="Colocations handled automatically by placer.")
   5226 @tf_export(v1=["colocate_with"])
   5227 def _colocate_with(op, ignore_existing=False):
   5228   return colocate_with(op, ignore_existing)
   5229 
   5230 
   5231 @tf_export("control_dependencies")
   5232 def control_dependencies(control_inputs):
   5233   """Wrapper for `Graph.control_dependencies()` using the default graph.
   5234 
   5235   See `tf.Graph.control_dependencies`
   5236   for more details.
   5237 
   5238   When eager execution is enabled, any callable object in the `control_inputs`
   5239   list will be called.
   5240 
   5241   Args:
   5242     control_inputs: A list of `Operation` or `Tensor` objects which
   5243       must be executed or computed before running the operations
   5244       defined in the context.  Can also be `None` to clear the control
   5245       dependencies. If eager execution is enabled, any callable object in the
   5246       `control_inputs` list will be called.
   5247 
   5248   Returns:
   5249    A context manager that specifies control dependencies for all
   5250    operations constructed within the context.
   5251   """
   5252   if context.executing_eagerly():
   5253     if control_inputs:
   5254       # Excute any pending callables.
   5255       for control in control_inputs:
   5256         if callable(control):
   5257           control()
   5258     return NullContextmanager()
   5259   else:
   5260     return get_default_graph().control_dependencies(control_inputs)
   5261 
   5262 
   5263 class _DefaultStack(threading.local):
   5264   """A thread-local stack of objects for providing implicit defaults."""
   5265 
   5266   def __init__(self):
   5267     super(_DefaultStack, self).__init__()
   5268     self._enforce_nesting = True
   5269     self.stack = []
   5270 
   5271   def get_default(self):
   5272     return self.stack[-1] if len(self.stack) >= 1 else None
   5273 
   5274   def reset(self):
   5275     self.stack = []
   5276 
   5277   def is_cleared(self):
   5278     return not self.stack
   5279 
   5280   @property
   5281   def enforce_nesting(self):
   5282     return self._enforce_nesting
   5283 
   5284   @enforce_nesting.setter
   5285   def enforce_nesting(self, value):
   5286     self._enforce_nesting = value
   5287 
   5288   @tf_contextlib.contextmanager
   5289   def get_controller(self, default):
   5290     """A context manager for manipulating a default stack."""
   5291     self.stack.append(default)
   5292     try:
   5293       yield default
   5294     finally:
   5295       # stack may be empty if reset() was called
   5296       if self.stack:
   5297         if self._enforce_nesting:
   5298           if self.stack[-1] is not default:
   5299             raise AssertionError(
   5300                 "Nesting violated for default stack of %s objects" %
   5301                 type(default))
   5302           self.stack.pop()
   5303         else:
   5304           self.stack.remove(default)
   5305 
   5306 
# Thread-local stack backing default_session() and get_default_session() below.
_default_session_stack = _DefaultStack()  # pylint: disable=protected-access
   5308 
   5309 
   5310 def default_session(session):
   5311   """Python "with" handler for defining a default session.
   5312 
   5313   This function provides a means of registering a session for handling
   5314   Tensor.eval() and Operation.run() calls. It is primarily intended for use
   5315   by session.Session, but can be used with any object that implements
   5316   the Session.run() interface.
   5317 
   5318   Use with the "with" keyword to specify that Tensor.eval() and Operation.run()
   5319   invocations within the scope of a block should be executed by a particular
   5320   session.
   5321 
   5322   The default session applies to the current thread only, so it is always
   5323   possible to inspect the call stack and determine the scope of a default
   5324   session. If you create a new thread, and wish to use the default session
   5325   in that thread, you must explicitly add a "with ops.default_session(sess):"
   5326   block in that thread's function.
   5327 
   5328   Example:
   5329     The following code examples are equivalent:
   5330 
   5331     # 1. Using the Session object directly:
   5332     sess = ...
   5333     c = tf.constant(5.0)
   5334     sess.run(c)
   5335 
   5336     # 2. Using default_session():
   5337     sess = ...
   5338     with ops.default_session(sess):
   5339       c = tf.constant(5.0)
   5340       result = c.eval()
   5341 
   5342     # 3. Overriding default_session():
   5343     sess = ...
   5344     with ops.default_session(sess):
   5345       c = tf.constant(5.0)
   5346       with ops.default_session(...):
   5347         c.eval(session=sess)
   5348 
   5349   Args:
   5350     session: The session to be installed as the default session.
   5351 
   5352   Returns:
   5353     A context manager for the default session.
   5354   """
   5355   return _default_session_stack.get_controller(session)
   5356 
   5357 
   5358 @tf_export(v1=["get_default_session"])
   5359 def get_default_session():
   5360   """Returns the default session for the current thread.
   5361 
   5362   The returned `Session` will be the innermost session on which a
   5363   `Session` or `Session.as_default()` context has been entered.
   5364 
   5365   NOTE: The default session is a property of the current thread. If you
   5366   create a new thread, and wish to use the default session in that
   5367   thread, you must explicitly add a `with sess.as_default():` in that
   5368   thread's function.
   5369 
   5370   Returns:
   5371     The default `Session` being used in the current thread.
   5372   """
   5373   return _default_session_stack.get_default()
   5374 
   5375 
   5376 def _eval_using_default_session(tensors, feed_dict, graph, session=None):
   5377   """Uses the default session to evaluate one or more tensors.
   5378 
   5379   Args:
   5380     tensors: A single Tensor, or a list of Tensor objects.
   5381     feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
   5382       numpy ndarrays, TensorProtos, or strings.
   5383     graph: The graph in which the tensors are defined.
   5384     session: (Optional) A different session to use to evaluate "tensors".
   5385 
   5386   Returns:
   5387     Either a single numpy ndarray if "tensors" is a single tensor; or a list
   5388     of numpy ndarrays that each correspond to the respective element in
   5389     "tensors".
   5390 
   5391   Raises:
   5392     ValueError: If no default session is available; the default session
   5393       does not have "graph" as its graph; or if "session" is specified,
   5394       and it does not have "graph" as its graph.
   5395   """
   5396   if session is None:
   5397     session = get_default_session()
   5398     if session is None:
   5399       raise ValueError("Cannot evaluate tensor using `eval()`: No default "
   5400                        "session is registered. Use `with "
   5401                        "sess.as_default()` or pass an explicit session to "
   5402                        "`eval(session=sess)`")
   5403     if session.graph is not graph:
   5404       raise ValueError("Cannot use the default session to evaluate tensor: "
   5405                        "the tensor's graph is different from the session's "
   5406                        "graph. Pass an explicit session to "
   5407                        "`eval(session=sess)`.")
   5408   else:
   5409     if session.graph is not graph:
   5410       raise ValueError("Cannot use the given session to evaluate tensor: "
   5411                        "the tensor's graph is different from the session's "
   5412                        "graph.")
   5413   return session.run(tensors, feed_dict)
   5414 
   5415 
   5416 def _run_using_default_session(operation, feed_dict, graph, session=None):
   5417   """Uses the default session to run "operation".
   5418 
   5419   Args:
   5420     operation: The Operation to be run.
   5421     feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists,
   5422       numpy ndarrays, TensorProtos, or strings.
   5423     graph: The graph in which "operation" is defined.
   5424     session: (Optional) A different session to use to run "operation".
   5425 
   5426   Raises:
   5427     ValueError: If no default session is available; the default session
   5428       does not have "graph" as its graph; or if "session" is specified,
   5429       and it does not have "graph" as its graph.
   5430   """
   5431   if session is None:
   5432     session = get_default_session()
   5433     if session is None:
   5434       raise ValueError("Cannot execute operation using `run()`: No default "
   5435                        "session is registered. Use `with "
   5436                        "sess.as_default():` or pass an explicit session to "
   5437                        "`run(session=sess)`")
   5438     if session.graph is not graph:
   5439       raise ValueError("Cannot use the default session to execute operation: "
   5440                        "the operation's graph is different from the "
   5441                        "session's graph. Pass an explicit session to "
   5442                        "run(session=sess).")
   5443   else:
   5444     if session.graph is not graph:
   5445       raise ValueError("Cannot use the given session to execute operation: "
   5446                        "the operation's graph is different from the session's "
   5447                        "graph.")
   5448   session.run(operation, feed_dict)
   5449 
   5450 
   5451 class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
   5452   """A thread-local stack of objects for providing an implicit default graph."""
   5453 
   5454   def __init__(self):
   5455     super(_DefaultGraphStack, self).__init__()
   5456     self._global_default_graph = None
   5457 
   5458   def get_default(self):
   5459     """Override that returns a global default if the stack is empty."""
   5460     ret = super(_DefaultGraphStack, self).get_default()
   5461     if ret is None:
   5462       ret = self._GetGlobalDefaultGraph()
   5463     return ret
   5464 
   5465   def _GetGlobalDefaultGraph(self):
   5466     if self._global_default_graph is None:
   5467       # TODO(mrry): Perhaps log that the default graph is being used, or set
   5468       #   provide some other feedback to prevent confusion when a mixture of
   5469       #   the global default graph and an explicit graph are combined in the
   5470       #   same process.
   5471       self._global_default_graph = Graph()
   5472     return self._global_default_graph
   5473 
   5474   def reset(self):
   5475     super(_DefaultGraphStack, self).reset()
   5476     self._global_default_graph = None
   5477 
   5478   @tf_contextlib.contextmanager
   5479   def get_controller(self, default):
   5480     context.context().context_switches.push(
   5481         default.building_function, default.as_default,
   5482         default._device_function_stack)
   5483     try:
   5484       with super(_DefaultGraphStack, self).get_controller(
   5485           default) as g, context.graph_mode():
   5486         yield g
   5487     finally:
   5488       # If an exception is raised here it may be hiding a related exception in
   5489       # the try-block (just above).
   5490       context.context().context_switches.pop()
   5491 
   5492 
# Thread-local stack of default graphs (see _DefaultGraphStack). Read directly
# by init_scope() and enable_eager_execution_internal(); presumably it also
# backs get_default_graph(), which is defined elsewhere in this file.
_default_graph_stack = _DefaultGraphStack()
   5494 
   5495 
   5496 # pylint: disable=g-doc-return-or-yield,line-too-long
@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
  """A context manager that lifts ops out of control-flow scopes and function-building graphs.

  There is often a need to lift variable initialization ops out of control-flow
  scopes, function-building graphs, and gradient tapes. Entering an
  `init_scope` is a mechanism for satisfying these desiderata. In particular,
  entering an `init_scope` has three effects:

    (1) All control dependencies are cleared the moment the scope is entered;
        this is equivalent to entering the context manager returned from
        `control_dependencies(None)`, which has the side-effect of exiting
        control-flow scopes like `tf.cond` and `tf.while_loop`.

    (2) All operations that are created while the scope is active are lifted
        into the innermost (most recently entered) context recorded in
        `context_switches` that is not building a graph function. Here, a
        context is defined as either a graph or an eager context. Every
        context switch, i.e., every installation of a graph as the default
        graph and every switch into eager mode, is logged in a thread-local
        stack called `context_switches`; the log entry for a context switch is
        popped from the stack when the context is exited. Entering an
        `init_scope` is equivalent to crawling up `context_switches`, finding
        the first context that is not building a graph function, and entering
        it. Caveats: if graph mode is enabled but the default graph stack is
        empty, then entering an `init_scope` installs the current default
        (global) graph as the outer context; and if no suitable entry is found
        on `context_switches`, the global default graph is used as a last
        resort.

    (3) The gradient tape is paused while the scope is active.

  The current name scope is preserved across the lift. When the lifted-into
  context is a graph, the innermost non-empty device function stack is
  installed on that graph while the scope is active and restored on exit. When
  lifting into an eager context, at most one device string from that stack is
  applied via `context.device`, because device functions cannot be applied
  eagerly.

  When eager execution is enabled, code inside an init_scope block runs with
  eager execution enabled even when defining graph functions via
  tf.contrib.eager.defun. For example:

  ```python
  tf.enable_eager_execution()

  @tf.contrib.eager.defun
  def func():
    # A defun-decorated function constructs TensorFlow graphs,
    # it does not execute eagerly.
    assert not tf.executing_eagerly()
    with tf.init_scope():
      # Initialization runs with eager execution enabled
      assert tf.executing_eagerly()
  ```

  Raises:
    RuntimeError: if graph state is incompatible with this initialization.
  """
  # pylint: enable=g-doc-return-or-yield,line-too-long

  if context.executing_eagerly():
    # Fastpath.
    with tape.stop_recording():
      yield
  else:
    # Retrieve the active name scope: entering an `init_scope` preserves
    # the name scope of the current context.
    default_graph = get_default_graph()
    scope = default_graph.get_name_scope()
    if scope and scope[-1] != "/":
      # Names that end with trailing slashes are treated by `name_scope` as
      # absolute.
      scope = scope + "/"
    # Start from the current graph's device stack. If it is empty, the search
    # below replaces it with the first non-empty device stack found while
    # walking outward through `context_switches`.
    innermost_nonempty_device_stack = default_graph._device_function_stack  # pylint: disable=protected-access

    # Zero-argument callable (e.g. a graph's `as_default`) that enters the
    # context the ops will be lifted into.
    outer_context = None
    if not _default_graph_stack.stack:
      # If the default graph stack is empty, then we cannot be building a
      # function. Install the global graph (which, in this case, is also the
      # default graph) as the outer context.
      if default_graph.building_function:
        raise RuntimeError("The global graph is building a function.")
      outer_context = default_graph.as_default
    else:
      # Find a context that is not building a function, walking from the most
      # recently entered context outward.
      for stack_entry in reversed(context.context().context_switches.stack):
        if not innermost_nonempty_device_stack:
          innermost_nonempty_device_stack = stack_entry.device_stack
        if not stack_entry.is_building_function:
          outer_context = stack_entry.enter_context_fn
          break

      if outer_context is None:
        # As a last resort, obtain the global default graph; this graph doesn't
        # necessarily live on the graph stack (and hence it doesn't necessarily
        # live on the context stack), but it is stored in the graph stack's
        # encapsulating object.
        outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default  # pylint: disable=protected-access

    if outer_context is None:
      # Sanity check; this shouldn't be triggered.
      raise RuntimeError("All graphs are building functions, and no "
                         "eager context was previously active.")

    # Record the outer graph and its own device stack so the `finally` clause
    # can undo the temporary device-stack swap performed below.
    outer_graph = None
    outer_device_stack = None
    try:
      with outer_context(), name_scope(scope), control_dependencies(
          None), tape.stop_recording():
        context_manager = NullContextmanager
        context_manager_input = None
        if not context.executing_eagerly():
          # The device stack is preserved when lifting into a graph. Eager
          # execution doesn't implement device stacks and in particular it
          # doesn't support device functions, so in general it's not possible
          # to do the same when lifting into the eager context.
          outer_graph = get_default_graph()
          outer_device_stack = outer_graph._device_function_stack  # pylint: disable=protected-access
          outer_graph._device_function_stack = innermost_nonempty_device_stack  # pylint: disable=protected-access
        elif innermost_nonempty_device_stack is not None:
          # Scan the device stack: stop at the first spec with no device
          # function (nothing applied), or at the first spec carrying a raw
          # device string (that string is then entered via `context.device`).
          for device_spec in innermost_nonempty_device_stack.peek_objs():
            if device_spec.function is None:
              break
            if device_spec.raw_string:
              context_manager = context.device
              context_manager_input = device_spec.raw_string
              break
            # It is currently not possible to have a device function in V2,
            # but in V1 we are unable to apply device functions in eager mode.
            # This means that we will silently skip some of the entries on the
            # device stack in V1 + eager mode.

        with context_manager(context_manager_input):
          yield
    finally:
      # If an exception is raised here it may be hiding a related exception in
      # try-block (just above).
      if outer_graph is not None:
        outer_graph._device_function_stack = outer_device_stack  # pylint: disable=protected-access
   5628 
   5629 
   5630 def executing_eagerly_outside_functions():
   5631   """Returns True if executing eagerly, even if inside a graph function."""
   5632   # Fastpath for when this is called eagerly (its not necessary to init_scope).
   5633   if context.executing_eagerly():
   5634     return True
   5635 
   5636   with init_scope():
   5637     return context.executing_eagerly()
   5638 
   5639 
   5640 def inside_function():
   5641   return get_default_graph().building_function
   5642 
   5643 
   5644 @tf_export(v1=["enable_eager_execution"])
   5645 def enable_eager_execution(config=None,
   5646                            device_policy=None,
   5647                            execution_mode=None):
   5648   """Enables eager execution for the lifetime of this program.
   5649 
   5650   Eager execution provides an imperative interface to TensorFlow. With eager
   5651   execution enabled, TensorFlow functions execute operations immediately (as
   5652   opposed to adding to a graph to be executed later in a `tf.Session`) and
   5653   return concrete values (as opposed to symbolic references to a node in a
   5654   computational graph).
   5655 
   5656   For example:
   5657 
   5658   ```python
   5659   tf.enable_eager_execution()
   5660 
   5661   # After eager execution is enabled, operations are executed as they are
   5662   # defined and Tensor objects hold concrete values, which can be accessed as
   5663   # numpy.ndarray`s through the numpy() method.
   5664   assert tf.multiply(6, 7).numpy() == 42
   5665   ```
   5666 
   5667   Eager execution cannot be enabled after TensorFlow APIs have been used to
   5668   create or execute graphs. It is typically recommended to invoke this function
   5669   at program startup and not in a library (as most libraries should be usable
   5670   both with and without eager execution).
   5671 
   5672   Args:
   5673     config: (Optional.) A `tf.ConfigProto` to use to configure the environment
   5674       in which operations are executed. Note that `tf.ConfigProto` is also
   5675       used to configure graph execution (via `tf.Session`) and many options
   5676       within `tf.ConfigProto` are not implemented (or are irrelevant) when
   5677       eager execution is enabled.
   5678     device_policy: (Optional.) Policy controlling how operations requiring
   5679       inputs on a specific device (e.g., a GPU 0) handle inputs on a different
   5680       device  (e.g. GPU 1 or CPU). When set to None, an appropriate value will be
   5681       picked automatically. The value picked may change between TensorFlow
   5682       releases.
   5683       Valid values:
   5684       - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
   5685         placement is not correct.
   5686       - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
   5687         on the right device but logs a warning.
   5688       - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
   5689         Note that this may hide performance problems as there is no notification
   5690         provided when operations are blocked on the tensor being copied between
   5691         devices.
   5692       - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
   5693         int32 tensors, raising errors on the other ones.
   5694     execution_mode: (Optional.) Policy controlling how operations dispatched are
   5695       actually executed. When set to None, an appropriate value will be picked
   5696       automatically. The value picked may change between TensorFlow releases.
   5697       Valid values:
   5698       - tf.contrib.eager.SYNC: executes each operation synchronously.
   5699       - tf.contrib.eager.ASYNC: executes each operation asynchronously. These
   5700         operations may return "non-ready" handles.
   5701 
   5702   Raises:
   5703     ValueError: If eager execution is enabled after creating/executing a
   5704      TensorFlow graph, or if options provided conflict with a previous call
   5705      to this function.
   5706   """
   5707   if context.default_execution_mode != context.EAGER_MODE:
   5708     return enable_eager_execution_internal(
   5709         config=config,
   5710         device_policy=device_policy,
   5711         execution_mode=execution_mode,
   5712         server_def=None)
   5713 
   5714 
   5715 @tf_export(v1=["disable_eager_execution"])
   5716 def disable_eager_execution():
   5717   """Disables eager execution.
   5718 
   5719   This function can only be called before any Graphs, Ops, or Tensors have been
   5720   created. It can be used at the beginning of the program for complex migration
   5721   projects from TensorFlow 1.x to 2.x.
   5722   """
   5723   context.default_execution_mode = context.GRAPH_MODE
   5724   c = context.context_safe()
   5725   if c is not None:
   5726     c._thread_local_data.is_eager = False  # pylint: disable=protected-access
   5727 
   5728 
   5729 def enable_eager_execution_internal(config=None,
   5730                                     device_policy=None,
   5731                                     execution_mode=None,
   5732                                     server_def=None):
   5733   """Enables eager execution for the lifetime of this program.
   5734 
   5735   Most of the doc string for enable_eager_execution is relevant here as well.
   5736 
   5737   Args:
   5738     config: See enable_eager_execution doc string
   5739     device_policy: See enable_eager_execution doc string
   5740     execution_mode: See enable_eager_execution doc string
   5741     server_def: (Optional.) A tensorflow::ServerDef proto.
   5742       Enables execution on remote devices. GrpcServers need to be started by
   5743       creating an identical server_def to this, and setting the appropriate
   5744       task_indexes, so that the servers can communicate. It will then be
   5745       possible to execute operations on remote devices.
   5746 
   5747   Raises:
   5748     ValueError
   5749 
   5750   """
   5751   if config is not None and not isinstance(config, config_pb2.ConfigProto):
   5752     raise TypeError(
   5753         "config must be a tf.ConfigProto, but got %s" % type(config))
   5754   if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
   5755                            context.DEVICE_PLACEMENT_WARN,
   5756                            context.DEVICE_PLACEMENT_SILENT,
   5757                            context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
   5758     raise ValueError(
   5759         "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
   5760     )
   5761   if execution_mode not in (None, context.SYNC, context.ASYNC):
   5762     raise ValueError(
   5763         "execution_mode must be one of None, tf.contrib.eager.SYNC, "
   5764         "tf.contrib.eager.ASYNC")
   5765   if context.default_execution_mode == context.GRAPH_MODE:
   5766     graph_mode_has_been_used = (
   5767         _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access
   5768     if graph_mode_has_been_used:
   5769       raise ValueError(
   5770           "tf.enable_eager_execution must be called at program startup.")
   5771   context.default_execution_mode = context.EAGER_MODE
   5772   # pylint: disable=protected-access
   5773   if context._context is None:
   5774     context._context = context.Context(
   5775         config=config,
   5776         device_policy=device_policy,
   5777         execution_mode=execution_mode,
   5778         server_def=server_def)
   5779   elif ((config is not None and config is not context._context._config) or
   5780         (device_policy is not None and
   5781          device_policy is not context._context._device_policy) or
   5782         (execution_mode is not None and
   5783          execution_mode is not context._context._execution_mode)):
   5784     raise ValueError("Trying to change the options of an active eager"
   5785                      " execution. Context config: %s, specified config:"
   5786                      " %s. Context device policy: %s, specified device"
   5787                      " policy: %s. Context execution mode: %s, "
   5788                      " specified execution mode %s." %
   5789                      (context._context._config, config,
   5790                       context._context._device_policy, device_policy,
   5791                       context._context._execution_mode, execution_mode))
   5792   else:
   5793     raise ValueError(
   5794         "tf.enable_eager_execution must be called at program startup.")
   5795 
   5796   # Monkey patch to get rid of an unnecessary conditional since the context is
   5797   # now initialized.
   5798   context.context = context.context_safe
   5799 
   5800 
   5801 def eager_run(main=None, argv=None):
   5802   """Runs the program with an optional main function and argv list.
   5803 
   5804   The program will run with eager execution enabled.
   5805 
   5806   Example:
   5807   ```python
   5808   import tensorflow as tf
   5809   # Import subject to future changes:
   5810   from tensorflow.contrib.eager.python import tfe
   5811 
   5812   def main(_):
   5813     u = tf.constant(6.0)
   5814     v = tf.constant(7.0)
   5815     print(u * v)
   5816 
   5817   if __name__ == "__main__":
   5818     tfe.run()
   5819   ```
   5820 
   5821   Args:
   5822     main: the main function to run.
   5823     argv: the arguments to pass to it.
   5824   """
   5825   enable_eager_execution()
   5826   app.run(main, argv)
   5827 
   5828 
   5829 @tf_export(v1=["reset_default_graph"])
   5830 def reset_default_graph():
   5831   """Clears the default graph stack and resets the global default graph.
   5832 
   5833   NOTE: The default graph is a property of the current thread. This
   5834   function applies only to the current thread.  Calling this function while
   5835   a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
   5836   behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
   5837   after calling this function will result in undefined behavior.
   5838   Raises:
   5839     AssertionError: If this function is called within a nested graph.
   5840   """
   5841   if not _default_graph_stack.is_cleared():
   5842     raise AssertionError("Do not use tf.reset_default_graph() to clear "
   5843                          "nested graphs. If you need a cleared graph, "
   5844                          "exit the nesting and create a new graph.")
   5845   _default_graph_stack.reset()
   5846 
   5847 
   5848 @tf_export(v1=["get_default_graph"])
   5849 def get_default_graph():
   5850   """Returns the default graph for the current thread.
   5851 
   5852   The returned graph will be the innermost graph on which a
   5853   `Graph.as_default()` context has been entered, or a global default
   5854   graph if none has been explicitly created.
   5855 
   5856   NOTE: The default graph is a property of the current thread. If you
   5857   create a new thread, and wish to use the default graph in that
   5858   thread, you must explicitly add a `with g.as_default():` in that
   5859   thread's function.
   5860 
   5861   Returns:
   5862     The default `Graph` being used in the current thread.
   5863   """
   5864   return _default_graph_stack.get_default()
   5865 
   5866 def has_default_graph():
   5867   """Returns True if there is a default graph."""
   5868   return len(_default_graph_stack.stack) >= 1
   5869 
   5870 
   5871 def get_name_scope():
   5872   """Returns the current name scope in the default_graph.
   5873 
   5874   For example:
   5875 
   5876   ```python
   5877   with tf.name_scope('scope1'):
   5878     with tf.name_scope('scope2'):
   5879       print(tf.get_name_scope())
   5880   ```
   5881   would print the string `scope1/scope2`.
   5882 
   5883   Returns:
   5884     A string representing the current name scope.
   5885   """
   5886   if context.executing_eagerly():
   5887     return context.context().scope_name.rstrip("/")
   5888   return get_default_graph().get_name_scope()
   5889 
   5890 
   5891 def _assert_same_graph(original_item, item):
   5892   """Fail if the 2 items are from different graphs.
   5893 
   5894   Args:
   5895     original_item: Original item to check against.
   5896     item: Item to check.
   5897 
   5898   Raises:
   5899     ValueError: if graphs do not match.
   5900   """
   5901   if original_item.graph is not item.graph:
   5902     raise ValueError("%s must be from the same graph as %s." % (item,
   5903                                                                 original_item))
   5904 
   5905 
def _get_graph_from_inputs(op_input_list, graph=None):
  """Returns the appropriate graph to use for the given inputs.

  This library method provides a consistent algorithm for choosing the graph
  in which an Operation should be constructed:

  1. If the default graph is being used to construct a function, we
     use the default graph. In this case no validation of `graph` or of the
     inputs is performed.
  2. If the "graph" is specified explicitly, we validate that all of the inputs
     in "op_input_list" are compatible with that graph.
  3. Otherwise, we attempt to select a graph from the first Operation-
     or Tensor-valued input in "op_input_list", and validate that all other
     such inputs are in the same graph.
  4. If the graph was not specified and it could not be inferred from
     "op_input_list", we attempt to use the default graph.

  Args:
    op_input_list: An iterable of inputs to an operation, which may include
      `Tensor`, `Operation`, and other objects that may be converted to a graph
      element. Generators are accepted; the iterable is consumed once.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If op_input_list is not iterable, or if graph is truthy and not
      a Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs.

  Returns:
    The appropriate graph to use for the given inputs.

  """
  # Step 1: inside a function-building graph, always use that graph.
  if get_default_graph().building_function:
    return get_default_graph()

  op_input_list = tuple(op_input_list)  # Handle generators correctly
  if graph and not isinstance(graph, Graph):
    raise TypeError("Input graph needs to be a Graph: %s" % graph)

  # Steps 2 and 3: we validate that all of the inputs are from the same graph.
  #    This is either the supplied graph parameter, or the first one selected
  #    from the graph-element-valued inputs. In the latter case, we hold onto
  #    that input in original_graph_element so we can provide a more
  #    informative error if a mismatch is found.
  original_graph_element = None
  for op_input in op_input_list:
    # Determine if this is a valid graph_element.
    # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this
    # up.
    # Operations, _TensorLike objects and exact `Tensor` instances are used
    # directly; everything else (including Tensor subclasses) goes through
    # _as_graph_element, whose result may be None.
    graph_element = None
    if (isinstance(op_input, (Operation, _TensorLike)) and
        ((not isinstance(op_input, Tensor)) or type(op_input) == Tensor)):  # pylint: disable=unidiomatic-typecheck
      graph_element = op_input
    else:
      graph_element = _as_graph_element(op_input)

    if graph_element is not None:
      if not graph:
        # First graph element seen and no explicit graph: adopt its graph.
        original_graph_element = graph_element
        graph = graph_element.graph
      elif original_graph_element is not None:
        # The graph was inferred from an earlier input; compare against it.
        _assert_same_graph(original_graph_element, graph_element)
      elif graph_element.graph is not graph:
        # The graph was passed explicitly by the caller.
        raise ValueError("%s is not from the passed-in graph." % graph_element)

  # Step 4: if all else fails, we use the default graph, which is always there.
  return graph or get_default_graph()
   5973 
   5974 
@tf_export(v1=["GraphKeys"])
class GraphKeys(object):
  """Standard names to use for graph collections.

  The standard library uses various well-known names to collect and
  retrieve values associated with a graph. For example, the
  `tf.Optimizer` subclasses default to optimizing the variables
  collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
  specified, but it is also possible to pass an explicit list of
  variables.

  The following standard keys are defined:

  * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared
    across a distributed environment (model variables are a subset of these).
    See `tf.global_variables`
    for more details.
    Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`,
    and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`.
  * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each
    machine. Usually used for temporary variables, like counters.
    Note: use `tf.contrib.framework.local_variable` to add to this collection.
  * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the
    model for inference (feed forward). Note: use
    `tf.contrib.framework.model_variable` to add to this collection.
  * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will
    be trained by an optimizer. See
    `tf.trainable_variables`
    for more details.
  * `SUMMARIES`: the summary `Tensor` objects that have been created in the
    graph. See
    `tf.summary.merge_all`
    for more details.
  * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to
    produce input for a computation. See
    `tf.train.start_queue_runners`
    for more details.
  * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also
    keep moving averages.  See
    `tf.moving_average_variables`
    for more details.
  * `REGULARIZATION_LOSSES`: regularization losses collected during graph
    construction.

  The following standard keys are _defined_, but their collections are **not**
  automatically populated as many of the others are:

  * `WEIGHTS`
  * `BIASES`
  * `ACTIVATIONS`
  """

  # Key to collect Variable objects that are global (shared across machines).
  # Default collection for all variables, except local ones.
  GLOBAL_VARIABLES = "variables"
  # Key to collect local variables that are local to the machine and are not
  # saved/restored.
  LOCAL_VARIABLES = "local_variables"
  # Key to collect local variables which are used to accumulate internal state
  # to be used in tf.metrics.*.
  METRIC_VARIABLES = "metric_variables"
  # Key to collect model variables defined by layers.
  MODEL_VARIABLES = "model_variables"
  # Key to collect Variable objects that will be trained by the
  # optimizers.
  TRAINABLE_VARIABLES = "trainable_variables"
  # Key to collect summaries.
  SUMMARIES = "summaries"
  # Key to collect QueueRunners.
  QUEUE_RUNNERS = "queue_runners"
  # Key to collect table initializers.
  TABLE_INITIALIZERS = "table_initializer"
  # Key to collect asset filepaths. An asset represents an external resource
  # like a vocabulary file.
  ASSET_FILEPATHS = "asset_filepaths"
  # Key to collect Variable objects that keep moving averages.
  MOVING_AVERAGE_VARIABLES = "moving_average_variables"
  # Key to collect regularization losses at graph construction.
  REGULARIZATION_LOSSES = "regularization_losses"
  # Key to collect concatenated sharded variables.
  CONCATENATED_VARIABLES = "concatenated_variables"
  # Key to collect savers.
  SAVERS = "savers"
  # Key to collect weights (not populated automatically).
  WEIGHTS = "weights"
  # Key to collect biases (not populated automatically).
  BIASES = "biases"
  # Key to collect activations (not populated automatically).
  ACTIVATIONS = "activations"
  # Key to collect update_ops
  UPDATE_OPS = "update_ops"
  # Key to collect losses
  LOSSES = "losses"
  # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
  SAVEABLE_OBJECTS = "saveable_objects"
  # Key to collect all shared resources used by the graph which need to be
  # initialized once per cluster.
  RESOURCES = "resources"
  # Key to collect all shared resources used in this graph which need to be
  # initialized once per session.
  LOCAL_RESOURCES = "local_resources"
  # Trainable resource-style variables.
  TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"

  # Key to indicate various ops.
  INIT_OP = "init_op"
  LOCAL_INIT_OP = "local_init_op"
  READY_OP = "ready_op"
  READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op"
  SUMMARY_OP = "summary_op"
  GLOBAL_STEP = "global_step"

  # Used to count the number of evaluations performed during a single evaluation
  # run.
  EVAL_STEP = "eval_step"
  TRAIN_OP = "train_op"

  # Key for control flow context.
  COND_CONTEXT = "cond_context"
  WHILE_CONTEXT = "while_context"

  # Used to store v2 summary names.
  _SUMMARY_COLLECTION = "_SUMMARY_V2"

  # List of all collections that keep track of variables. It is built from the
  # key constants above, so it must stay after their definitions.
  _VARIABLE_COLLECTIONS = [
      GLOBAL_VARIABLES,
      LOCAL_VARIABLES,
      METRIC_VARIABLES,
      MODEL_VARIABLES,
      TRAINABLE_VARIABLES,
      MOVING_AVERAGE_VARIABLES,
      CONCATENATED_VARIABLES,
      TRAINABLE_RESOURCE_VARIABLES,
  ]

  # Key for streaming model ports.
  # NOTE(yuanbyu): internal and experimental.
  _STREAMING_MODEL_PORTS = "streaming_model_ports"

  # Deprecated class-level alias for GLOBAL_VARIABLES; reading it goes through
  # the deprecation decorator.
  @decorator_utils.classproperty
  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
  def VARIABLES(cls):  # pylint: disable=no-self-argument
    return cls.GLOBAL_VARIABLES
   6119 
   6120 
   6121 def dismantle_graph(graph):
   6122   """Cleans up reference cycles from a `Graph`.
   6123 
   6124   Helpful for making sure the garbage collector doesn't need to run after a
   6125   temporary `Graph` is no longer needed.
   6126 
   6127   Args:
   6128     graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
   6129       after this function runs.
   6130   """
   6131   memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
   6132 
   6133   # Now clean up Operation<->Graph reference cycles by clearing all of the
   6134   # attributes for the Graph and its ops.
   6135   graph_operations = graph.get_operations()
   6136   for op in graph_operations:
   6137     op.__dict__ = {}
   6138   graph.__dict__ = {}
   6139 
   6140 
   6141 @tf_export(v1=["add_to_collection"])
   6142 def add_to_collection(name, value):
   6143   """Wrapper for `Graph.add_to_collection()` using the default graph.
   6144 
   6145   See `tf.Graph.add_to_collection`
   6146   for more details.
   6147 
   6148   Args:
   6149     name: The key for the collection. For example, the `GraphKeys` class
   6150       contains many standard names for collections.
   6151     value: The value to add to the collection.
   6152 
   6153   @compatibility(eager)
   6154   Collections are only supported in eager when variables are created inside an
   6155   EagerVariableStore (e.g. as part of a layer or template).
   6156   @end_compatibility
   6157   """
   6158   get_default_graph().add_to_collection(name, value)
   6159 
   6160 
   6161 @tf_export(v1=["add_to_collections"])
   6162 def add_to_collections(names, value):
   6163   """Wrapper for `Graph.add_to_collections()` using the default graph.
   6164 
   6165   See `tf.Graph.add_to_collections`
   6166   for more details.
   6167 
   6168   Args:
   6169     names: The key for the collections. The `GraphKeys` class
   6170       contains many standard names for collections.
   6171     value: The value to add to the collections.
   6172 
   6173   @compatibility(eager)
   6174   Collections are only supported in eager when variables are created inside an
   6175   EagerVariableStore (e.g. as part of a layer or template).
   6176   @end_compatibility
   6177   """
   6178   get_default_graph().add_to_collections(names, value)
   6179 
   6180 
   6181 @tf_export(v1=["get_collection_ref"])
   6182 def get_collection_ref(key):
   6183   """Wrapper for `Graph.get_collection_ref()` using the default graph.
   6184 
   6185   See `tf.Graph.get_collection_ref`
   6186   for more details.
   6187 
   6188   Args:
   6189     key: The key for the collection. For example, the `GraphKeys` class
   6190       contains many standard names for collections.
   6191 
   6192   Returns:
   6193     The list of values in the collection with the given `name`, or an empty
   6194     list if no value has been added to that collection.  Note that this returns
   6195     the collection list itself, which can be modified in place to change the
   6196     collection.
   6197 
   6198   @compatibility(eager)
   6199   Collections are not supported when eager execution is enabled.
   6200   @end_compatibility
   6201   """
   6202   return get_default_graph().get_collection_ref(key)
   6203 
   6204 
   6205 @tf_export(v1=["get_collection"])
   6206 def get_collection(key, scope=None):
   6207   """Wrapper for `Graph.get_collection()` using the default graph.
   6208 
   6209   See `tf.Graph.get_collection`
   6210   for more details.
   6211 
   6212   Args:
   6213     key: The key for the collection. For example, the `GraphKeys` class
   6214       contains many standard names for collections.
   6215     scope: (Optional.) If supplied, the resulting list is filtered to include
   6216       only items whose `name` attribute matches using `re.match`. Items
   6217       without a `name` attribute are never returned if a scope is supplied and
   6218       the choice or `re.match` means that a `scope` without special tokens
   6219       filters by prefix.
   6220 
   6221   Returns:
   6222     The list of values in the collection with the given `name`, or
   6223     an empty list if no value has been added to that collection. The
   6224     list contains the values in the order under which they were
   6225     collected.
   6226 
   6227   @compatibility(eager)
   6228   Collections are not supported when eager execution is enabled.
   6229   @end_compatibility
   6230   """
   6231   return get_default_graph().get_collection(key, scope)
   6232 
   6233 
   6234 def get_all_collection_keys():
   6235   """Returns a list of collections used in the default graph."""
   6236   return get_default_graph().get_all_collection_keys()
   6237 
   6238 
   6239 name_scope_cache = {}
   6240 
   6241 
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export(v1=["name_scope"])
class name_scope(object):  # pylint: disable=invalid-name
  """A context manager for use when defining a Python op.

  This context manager validates that the given `values` are from the
  same graph, makes that graph the default graph, and pushes a
  name scope in that graph (see
  `tf.Graph.name_scope`
  for more details on that).

  When executing eagerly and none of `values` has a `graph` attribute, no graph
  is involved: the scope is tracked as a string on the eager context and
  restored on exit.

  For example, to define a new Python op called `my_op`:

  ```python
  def my_op(a, b, c, name=None):
    with tf.name_scope(name, "MyOp", [a, b, c]) as scope:
      a = tf.convert_to_tensor(a, name="a")
      b = tf.convert_to_tensor(b, name="b")
      c = tf.convert_to_tensor(c, name="c")
      # Define some computation that uses `a`, `b`, and `c`.
      return foo_op(..., name=scope)
  ```
  """

  @property
  def name(self):
    """The effective name: `name` if it was given, else `default_name`."""
    return self._name

  def __init__(self, name, default_name=None, values=None):
    """Initialize the context manager.

    Args:
      name: The name argument that is passed to the op function.
      default_name: The default name to use if the `name` argument is `None`.
      values: The list of `Tensor` arguments that are passed to the op function.

    Raises:
      TypeError: if `default_name` is passed in but not a string.
    """
    if not (default_name is None or isinstance(default_name, six.string_types)):
      raise TypeError(
          "`default_name` type (%s) is not a string type. You likely meant to "
          "pass this into the `values` kwarg."
          % type(default_name))
    # An explicit `name` always wins over `default_name`.
    self._name = default_name if name is None else name
    self._default_name = default_name
    self._values = values
    self._ctx = context.context()
    # Snapshot of the execution mode at construction time; __enter__ and
    # __exit__ branch on this value rather than re-querying the context.
    self._in_eager_mode = self._ctx.executing_eagerly()
    self._has_symbolic_input_in_eager = False
    if self._values and self._in_eager_mode:
      # The presence of a graph tensor in `self._values` overrides the context.
      # NOTE(review): every value with a `graph` attribute reassigns
      # `self._name_scope`, so the last such value wins; confirm this is
      # intended when values come from different graphs.
      for value in self._values:
        if hasattr(value, "graph"):
          self._has_symbolic_input_in_eager = True
          self._name_scope = value.graph.name_scope(self._name)

  def __enter__(self):
    """Start the scope block.

    Returns:
      The scope name.

    Raises:
      ValueError: if neither `name` nor `default_name` is provided
        but `values` are. This check is only made on the graph-mode path.
    """
    if self._has_symbolic_input_in_eager:
      # Eager mode with a graph tensor among `values`: delegate entirely to
      # that graph's own name scope.
      return self._name_scope.__enter__()

    if self._in_eager_mode:
      # Eager mode: the scope is just a string on the context. Remember the
      # previous value so that __exit__ can restore it.
      self._old_name = self._ctx.scope_name
      if not self._name:
        # A None/empty name resets to the top-level (empty) scope.
        scope_name = ""
      else:
        # Resolved scope strings are memoized in the module-level
        # `name_scope_cache`, keyed by name, enclosing scope and default name.
        cache_key = self._name, self._old_name, self._default_name
        if cache_key in name_scope_cache:
          self._ctx.scope_name = name_scope_cache[cache_key]
          return self._ctx.scope_name
        elif self._name[-1] == "/":
          # A trailing slash breaks out of nested name scopes, indicating a
          # fully specified scope name, for compatibility with Graph.name_scope.
          scope_name = self._name
        else:
          # Nest under the enclosing scope (if any), always ending in "/".
          name_with_trailing_slash = self._name + "/"
          scope_name = (
              self._old_name + name_with_trailing_slash
              if self._old_name else name_with_trailing_slash)
        name_scope_cache[cache_key] = scope_name
      self._ctx.scope_name = scope_name
      return scope_name
    else:
      if self._name is None and self._values is not None:
        # We only raise an error if values is not None (provided) because
        # currently tf.name_scope(None) (values=None then) is sometimes used as
        # an idiom to reset to top scope.
        raise ValueError(
            "At least one of name (%s) and default_name (%s) must be provided."
            % (self._name, self._default_name))
      if self._values is None:
        self._values = []
      # Graph mode: make the inputs' graph the default first, then push the
      # name scope on it. If the name scope cannot be created or entered, the
      # graph context is unwound before the error propagates.
      g = _get_graph_from_inputs(self._values)
      self._g_manager = g.as_default()
      self._g_manager.__enter__()
      try:
        self._name_scope = g.name_scope(self._name)
        return self._name_scope.__enter__()
      except:
        self._g_manager.__exit__(*sys.exc_info())
        raise

  def __exit__(self, type_arg, value_arg, traceback_arg):
    """End the scope block, undoing __enter__; never suppresses exceptions."""
    if self._has_symbolic_input_in_eager:
      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
    elif self._in_eager_mode:
      # Restore the scope string saved by __enter__.
      self._ctx.scope_name = self._old_name
    else:
      # Exit in the reverse order of __enter__: name scope, then graph.
      self._name_scope.__exit__(type_arg, value_arg, traceback_arg)
      self._g_manager.__exit__(type_arg, value_arg, traceback_arg)
    return False  # False values do not suppress exceptions
   6364 
   6365 
   6366 @tf_export("name_scope", v1=[])
   6367 class name_scope_v2(name_scope):
   6368   """A context manager for use when defining a Python op.
   6369 
   6370   This context manager pushes a name scope, which will make the name of all
   6371   operations added within it have a prefix.
   6372 
   6373   For example, to define a new Python op called `my_op`:
   6374 
   6375   ```python
   6376   def my_op(a, b, c, name=None):
   6377     with tf.name_scope("MyOp") as scope:
   6378       a = tf.convert_to_tensor(a, name="a")
   6379       b = tf.convert_to_tensor(b, name="b")
   6380       c = tf.convert_to_tensor(c, name="c")
   6381       # Define some computation that uses `a`, `b`, and `c`.
   6382       return foo_op(..., name=scope)
   6383   ```
   6384 
   6385   When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`,
   6386   and `MyOp/c`.
   6387 
   6388   If the scope name already exists, the name will be made unique by appending
   6389   `_n`. For example, calling `my_op` the second time will generate `MyOp_1/a`,
   6390   etc.
   6391   """
   6392 
   6393   def __init__(self, name):
   6394     """Initialize the context manager.
   6395 
   6396     Args:
   6397       name: The prefix to use on all names created within the name scope.
   6398 
   6399     Raises:
   6400       ValueError: If name is None, or not a string.
   6401     """
   6402     if name is None or not isinstance(name, six.string_types):
   6403       raise ValueError("name for name_scope must be a string.")
   6404     super(name_scope_v2, self).__init__(name=None, default_name=name)
   6405 
   6406 
   6407 def strip_name_scope(name, export_scope):
   6408   """Removes name scope from a name.
   6409 
   6410   Args:
   6411     name: A `string` name.
   6412     export_scope: Optional `string`. Name scope to remove.
   6413 
   6414   Returns:
   6415     Name with name scope removed, or the original name if export_scope
   6416     is None.
   6417   """
   6418   if export_scope:
   6419     if export_scope[-1] == "/":
   6420       export_scope = export_scope[:-1]
   6421 
   6422     try:
   6423       # Strips export_scope/, export_scope///,
   6424       # ^export_scope/, loc:@export_scope/.
   6425       str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)"
   6426       return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1)
   6427     except TypeError as e:
   6428       # If the name is not of a type we can process, simply return it.
   6429       logging.warning(e)
   6430       return name
   6431   else:
   6432     return name
   6433 
   6434 
   6435 def prepend_name_scope(name, import_scope):
   6436   """Prepends name scope to a name.
   6437 
   6438   Args:
   6439     name: A `string` name.
   6440     import_scope: Optional `string`. Name scope to add.
   6441 
   6442   Returns:
   6443     Name with name scope added, or the original name if import_scope
   6444     is None.
   6445   """
   6446   if import_scope:
   6447     if import_scope[-1] == "/":
   6448       import_scope = import_scope[:-1]
   6449 
   6450     try:
   6451       str_to_replace = r"([\^]|loc:@|^)(.*)"
   6452       return re.sub(str_to_replace, r"\1" + import_scope + r"/\2",
   6453                     compat.as_str(name))
   6454     except TypeError as e:
   6455       # If the name is not of a type we can process, simply return it.
   6456       logging.warning(e)
   6457       return name
   6458   else:
   6459     return name
   6460 
   6461 
   6462 # pylint: disable=g-doc-return-or-yield
   6463 # pylint: disable=not-context-manager
   6464 @tf_export(v1=["op_scope"])
   6465 @tf_contextlib.contextmanager
   6466 def op_scope(values, name, default_name=None):
   6467   """DEPRECATED. Same as name_scope above, just different argument order."""
   6468   logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
   6469                " use tf.name_scope(name, default_name, values)")
   6470   with name_scope(name, default_name=default_name, values=values) as scope:
   6471     yield scope
   6472 
   6473 
# Registry of (proto_type, to_proto, from_proto) tuples keyed by collection
# name; filled by register_proto_function and read by the get_*_proto_*
# lookup helpers.
_proto_function_registry = registry.Registry("proto functions")
   6475 
   6476 
   6477 def register_proto_function(collection_name,
   6478                             proto_type=None,
   6479                             to_proto=None,
   6480                             from_proto=None):
   6481   """Registers `to_proto` and `from_proto` functions for collection_name.
   6482 
   6483   `to_proto` function converts a Python object to the corresponding protocol
   6484   buffer, and returns the protocol buffer.
   6485 
   6486   `from_proto` function converts protocol buffer into a Python object, and
   6487   returns the object..
   6488 
   6489   Args:
   6490     collection_name: Name of the collection.
   6491     proto_type: Protobuf type, such as `saver_pb2.SaverDef`,
   6492       `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`..
   6493     to_proto: Function that implements Python object to protobuf conversion.
   6494     from_proto: Function that implements protobuf to Python object conversion.
   6495   """
   6496   if to_proto and not callable(to_proto):
   6497     raise TypeError("to_proto must be callable.")
   6498   if from_proto and not callable(from_proto):
   6499     raise TypeError("from_proto must be callable.")
   6500 
   6501   _proto_function_registry.register((proto_type, to_proto, from_proto),
   6502                                     collection_name)
   6503 
   6504 
   6505 def get_collection_proto_type(collection_name):
   6506   """Returns the proto_type for collection_name."""
   6507   try:
   6508     return _proto_function_registry.lookup(collection_name)[0]
   6509   except LookupError:
   6510     return None
   6511 
   6512 
   6513 def get_to_proto_function(collection_name):
   6514   """Returns the to_proto function for collection_name."""
   6515   try:
   6516     return _proto_function_registry.lookup(collection_name)[1]
   6517   except LookupError:
   6518     return None
   6519 
   6520 
   6521 def get_from_proto_function(collection_name):
   6522   """Returns the from_proto function for collection_name."""
   6523   try:
   6524     return _proto_function_registry.lookup(collection_name)[2]
   6525   except LookupError:
   6526     return None
   6527 
   6528 
   6529 def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
   6530   """Produce a nice error if someone converts an Operation to a Tensor."""
   6531   raise TypeError(("Can't convert Operation '%s' to Tensor "
   6532                    "(target dtype=%r, name=%r, as_ref=%r)") % (op.name, dtype,
   6533                                                                name, as_ref))
   6534 
   6535 
def _op_to_colocate_with(v):
  """Operation object corresponding to v to use for colocation constraints.

  Resolution order:
    1. `None` is returned unchanged.
    2. An `Operation` is returned as-is.
    3. An object whose `handle.op` is an `Operation` (duck-typed stand-in for
       a ResourceVariable check; see the comment in the body) yields that op.
    4. Anything else is passed to
       `internal_convert_to_tensor_or_indexed_slices(v, as_ref=True)` and the
       converted value's `op` attribute is returned.

  Args:
    v: None, an `Operation`, or a value accepted by
      `internal_convert_to_tensor_or_indexed_slices`.

  Returns:
    The op to colocate with (for path 4, the `op` attribute of the converted
    value, presumably an `Operation`), or None if `v` is None.
  """
  if v is None:
    return None
  if isinstance(v, Operation):
    return v
  # We always want to colocate with the reference op.
  # When 'v' is a ResourceVariable, the reference op is the handle creating op.
  #
  # What this should be is:
  # if isinstance(v, ResourceVariable):
  #   return v.handle.op
  # However, that would require a circular import dependency.
  # As of October 2018, there were attempts underway to remove
  # colocation constraints altogether. Assuming that will
  # happen soon, perhaps this hack to work around the circular
  # import dependency is acceptable.
  # NOTE(review): `v.handle` is evaluated several times below; keep this in
  # mind if `handle` is a computed property.
  if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
      v.handle.op, Operation):
    return v.handle.op
  # Fallback: convert to a tensor (or indexed slices) as a reference and use
  # the op that produces it.
  return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
   6557 
   6558 
   6559 def _is_keras_symbolic_tensor(x):
   6560   return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"
   6561 
   6562 
# Route Operation -> Tensor conversions to _operation_conversion_error so they
# fail with a descriptive TypeError naming the op.
register_tensor_conversion_function(Operation, _operation_conversion_error)
   6564