# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to use variables as resources."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.training import checkpointable
from tensorflow.python.util import compat


def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference."""
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  if not graph_mode:
    # When in eager mode use a uid for the shared_name, to prevent accidental
    # sharing.
    shared_name = str(ops.uid())
  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                   shared_name=shared_name,
                                                   name=name,
                                                   container=container)
  if graph_mode:
    return handle

  # We do not want two distinct ResourceVariable objects for the same
  # underlying resource in the runtime.
  # When in eager mode, explicitly ensure so here. When in graph mode, it's
  # ensured by always generating different variable names.
  exists = gen_resource_variable_ops.var_is_initialized_op(handle)
  if exists:
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." %
                     shared_name)
  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                shared_name=shared_name,
                                                name=name,
                                                container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for when
    # the handle is captured by an eager mode function.
    handle._handle_data = h._handle_data  # pylint: disable=protected-access
  # Clean up our reference cycles to avoid making the garbage collector run.
  # pylint: disable=protected-access
  # OrderedDict, constructed on Graph creation, makes a simple reference loop
  # and hides it in an __attribute in some Python versions. We don't need to
  # throw an error if we can't find it, but if we do find it we can break the
  # loop to avoid creating work for the garbage collector.
  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
  # pylint: enable=protected-access
  if problematic_cycle:
    try:
      del problematic_cycle[0][:]
    except TypeError:
      # This is probably not one of the problematic Python versions. Continue
      # with the rest of our cleanup.
      pass
  # Now clean up our own reference cycles by clearing all of the attributes for
  # the Graph and op we created.
  h.__dict__ = {}
  graph.__dict__ = {}
  return handle


class EagerResourceDeleter(object):
  """An object which cleans up a resource handle.

  An alternative to defining a __del__ method on an object. The intended use is
  that ResourceVariables or other objects with resource handles will maintain a
  single reference to this object. When the parent object is collected, this
  object will be too. Even if the parent object is part of a reference cycle,
  the cycle will be collectable.
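
  The intended usage pattern looks roughly like this (a sketch; the holder
  class below is illustrative, not an API in this module):

  ```python
  class _SomeResourceHolder(object):

    def __init__(self, handle):
      self._handle = handle
      # Hold the only reference to the deleter: when the holder is garbage
      # collected the deleter is collected too, and its __del__ destroys the
      # underlying resource, even if the holder is part of a reference cycle.
      self._handle_deleter = EagerResourceDeleter(
          handle=handle, handle_device=handle.device)
  ```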
    108   """
    109 
    110   def __init__(self, handle, handle_device):
    111     if not isinstance(handle, ops.Tensor):
    112       raise ValueError(
    113           ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle "
    114            "Tensor." % (handle,)))
    115     self._handle = handle
    116     self._handle_device = handle_device
    117 
    118   def __del__(self):
    119     # Resources follow object-identity when executing eagerly, so it is safe to
    120     # delete the resource we have a handle to. Each Graph has a unique container
    121     # name, which prevents resource sharing.
    122     try:
    123       # This resource was created in eager mode. However, this destructor may be
    124       # running in graph mode (especially during unit tests). To clean up
    125       # successfully, we switch back into eager mode temporarily.
    126       with context.eager_mode():
    127         with ops.device(self._handle_device):
    128           gen_resource_variable_ops.destroy_resource_op(
    129               self._handle, ignore_lookup_error=True)
    130     except TypeError:
    131       # Suppress some exceptions, mainly for the case when we're running on
    132       # module deletion. Things that can go wrong include the context module
    133       # already being unloaded, self._handle._handle_data no longer being
    134       # valid, and so on. Printing warnings in these cases is silly
    135       # (exceptions raised from __del__ are printed as warnings to stderr).
    136       pass  # 'NoneType' object is not callable when the handle has been
    137             # partially unloaded.
    138     except AttributeError:
    139       pass  # 'NoneType' object has no attribute 'eager_mode' when context has
    140             # been unloaded. Will catch other module unloads as well.
    141 
    142 
    143 def shape_safe_assign_variable_handle(handle, shape, value, name=None):
    144   """Helper that checks shape compatibility and assigns variable."""
    145   value_tensor = ops.convert_to_tensor(value)
    146   shape.assert_is_compatible_with(value_tensor.shape)
    147   return gen_resource_variable_ops.assign_variable_op(handle,
    148                                                       value_tensor,
    149                                                       name=name)
    150 
    151 
    152 class ResourceVariable(variables.Variable):
    153   """Variable based on resource handles.
    154 
    155   See the ${variables} documentation for more details.
    156 
    157   A `ResourceVariable` allows you to maintain state across subsequent calls to
    158   session.run.
    159 
    160   The `ResourceVariable` constructor requires an initial value for the variable,
    161   which can be a `Tensor` of any type and shape. The initial value defines the
    162   type and shape of the variable. After construction, the type and shape of
    163   the variable are fixed. The value can be changed using one of the assign
    164   methods.
    165 
    166   Just like any `Tensor`, variables created with `ResourceVariable()` can be
    167   used as inputs for other Ops in the graph. Additionally, all the operators
    168   overloaded for the `Tensor` class are carried over to variables, so you can
    169   also add nodes to the graph by just doing arithmetic on variables.
    170 
  Unlike tf.Variable, a tf.ResourceVariable has well-defined semantics. Each
  usage of a ResourceVariable in a TensorFlow graph adds a read_value operation
  to the graph. The Tensors returned by a read_value operation are guaranteed
  to see all modifications to the value of the variable made by any operation
  on which the read_value depends (directly, indirectly, or via a control
  dependency), and are guaranteed not to see any modification made by
  operations on which the read_value operation does not depend.

  For example, if there is more than one assignment to a ResourceVariable in
  a single session.run call, there is a well-defined value for each operation
  that uses the variable's value, provided the assignments and the read are
  connected by edges in the graph. Consider the following example, in which
  two writes can cause tf.Variable and tf.ResourceVariable to behave
  differently:

  ```python
    a = tf.ResourceVariable(1.0)
    a.initializer.run()

    assign = a.assign(2.0)
    with tf.control_dependencies([assign]):
      b = a.read_value()
    with tf.control_dependencies([b]):
      other_assign = a.assign(3.0)
    with tf.control_dependencies([other_assign]):
      # Will print 2.0 because the value was read before other_assign ran. If
      # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed.
      tf.Print(b, [b]).eval()
  ```

  To enforce these consistency properties, tf.ResourceVariable might make more
  copies than an equivalent tf.Variable under the hood, so tf.Variable is still
  not deprecated.
    203   """
    204 
    205   def __init__(self,
    206                initial_value=None,
    207                trainable=True,
    208                collections=None,
    209                validate_shape=True,
    210                caching_device=None,
    211                name=None,
    212                dtype=None,
    213                variable_def=None,
    214                import_scope=None,
    215                constraint=None):
    216     """Creates a variable.
    217 
    218     Args:
    219       initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
    220         which is the initial value for the Variable. The initial value must have
    221         a shape specified unless `validate_shape` is set to False. Can also be a
    222         callable with no argument that returns the initial value when called.
    223         (Note that initializer functions from init_ops.py must first be bound
    224          to a shape before being used here.)
    225       trainable: If `True`, the default, also adds the variable to the graph
    226         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
    227         the default list of variables to use by the `Optimizer` classes.
    228       collections: List of graph collections keys. The new variable is added to
    229         these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    230       validate_shape: Ignored. Provided for compatibility with tf.Variable.
    231       caching_device: Optional device string or function describing where the
    232         Variable should be cached for reading.  Defaults to the Variable's
    233         device.  If not `None`, caches on another device.  Typical use is to
    234         cache on the device where the Ops using the Variable reside, to
    235         deduplicate copying through `Switch` and other conditional statements.
    236       name: Optional name for the variable. Defaults to `'Variable'` and gets
    237         uniquified automatically.
    238       dtype: If set, initial_value will be converted to the given type.
    239         If None, either the datatype will be kept (if initial_value is
    240         a Tensor) or float32 will be used (if it is a Python object convertible
    241         to a Tensor).
    242       variable_def: `VariableDef` protocol buffer. If not None, recreates the
    243         `ResourceVariable` object with its contents. `variable_def` and other
    244         arguments (except for import_scope) are mutually exclusive.
    245       import_scope: Optional `string`. Name scope to add to the
    246         ResourceVariable. Only used when `variable_def` is provided.
    247       constraint: An optional projection function to be applied to the variable
    248         after being updated by an `Optimizer` (e.g. used to implement norm
    249         constraints or value constraints for layer weights). The function must
    250         take as input the unprojected Tensor representing the value of the
    251         variable and return the Tensor for the projected value
    252         (which must have the same shape). Constraints are not safe to
    253         use when doing asynchronous distributed training.
    254 
    255     Raises:
    256       ValueError: If the initial value is not specified, or does not have a
    257         shape and `validate_shape` is `True`.
    258 
    259     @compatibility(eager)
    260     When Eager Execution is enabled, the default for the `collections` argument
    261     is `None`, which signifies that this `Variable` will not be added to any
    262     collections.
    263     @end_compatibility
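
    For example, a small sketch of typical graph-mode usage:

    ```python
    v = ResourceVariable([1.0, 2.0], name="v")
    add = v.assign_add([1.0, 1.0])
    with tf.Session() as sess:
      sess.run(v.initializer)
      sess.run(add)
      print(sess.run(v.read_value()))  # ==> [2.0, 3.0]
    ```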
    264     """
    265     if variable_def:
    266       if initial_value is not None:
    267         raise ValueError("variable_def and initial_value are mutually "
    268                          "exclusive.")
    269       if not context.in_graph_mode():
    270         raise ValueError("Creating ResourceVariable from variable_def"
    271                          " only supported in GRAPH mode.")
    272       self._init_from_proto(variable_def, import_scope=import_scope)
    273     else:
    274       self._init_from_args(
    275           initial_value=initial_value,
    276           trainable=trainable,
    277           collections=collections,
    278           validate_shape=validate_shape,
    279           caching_device=caching_device,
    280           name=name,
    281           dtype=dtype,
    282           constraint=constraint)
    283 
    284   # pylint: disable=unused-argument
    285   def _init_from_args(self,
    286                       initial_value=None,
    287                       trainable=True,
    288                       collections=None,
    289                       validate_shape=True,
    290                       caching_device=None,
    291                       name=None,
    292                       dtype=None,
    293                       constraint=None):
    294     """Creates a variable.
    295 
    296     Args:
    297       initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
    298         which is the initial value for the Variable. The initial value must have
    299         a shape specified unless `validate_shape` is set to False. Can also be a
    300         callable with no argument that returns the initial value when called.
    301         (Note that initializer functions from init_ops.py must first be bound
    302          to a shape before being used here.)
    303       trainable: If `True`, the default, also adds the variable to the graph
    304         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
    305         the default list of variables to use by the `Optimizer` classes.
    306       collections: List of graph collections keys. The new variable is added to
    307         these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    308       validate_shape: Ignored. Provided for compatibility with tf.Variable.
    309       caching_device: Optional device string or function describing where the
    310         Variable should be cached for reading.  Defaults to the Variable's
    311         device.  If not `None`, caches on another device.  Typical use is to
    312         cache on the device where the Ops using the Variable reside, to
    313         deduplicate copying through `Switch` and other conditional statements.
    314       name: Optional name for the variable. Defaults to `'Variable'` and gets
    315         uniquified automatically.
    316       dtype: If set, initial_value will be converted to the given type.
    317         If None, either the datatype will be kept (if initial_value is
    318        a Tensor) or float32 will be used (if it is a Python object convertible
    319        to a Tensor).
    320       constraint: An optional projection function to be applied to the variable
    321         after being updated by an `Optimizer` (e.g. used to implement norm
    322         constraints or value constraints for layer weights). The function must
    323         take as input the unprojected Tensor representing the value of the
    324         variable and return the Tensor for the projected value
    325         (which must have the same shape). Constraints are not safe to
    326         use when doing asynchronous distributed training.
    327 
    328     Raises:
    329       ValueError: If the initial value is not specified, or does not have a
    330         shape and `validate_shape` is `True`.
    331 
    332     @compatibility(eager)
    When Eager Execution is enabled, the variable is never added to
    collections. It is not implicitly added to the `GLOBAL_VARIABLES` or
    `TRAINABLE_VARIABLES` collections, and the `collections` argument is
    ignored.
    @end_compatibility
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

    if isinstance(initial_value, checkpointable.CheckpointInitialValue):
      self._maybe_initialize_checkpointable()
      self._update_uid = initial_value.checkpoint_position.restore_uid
      initial_value = initial_value.wrapped_value

    self._trainable = trainable
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    # Store the graph key so optimizers know how to only retrieve variables from
    # this graph.
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.init_scope():
      self._in_graph_mode = context.in_graph_mode()
      with ops.name_scope(name, "Variable", []
                          if init_from_fn else [initial_value]) as name:
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          if self._in_graph_mode:
            attr = attr_value_pb2.AttrValue(
                list=attr_value_pb2.AttrValue.ListValue(
                    s=[compat.as_bytes("loc:@%s" % handle_name)]))
            with ops.get_default_graph()._attr_scope({"_class": attr}):
              with ops.name_scope("Initializer"), ops.device(None):
                initial_value = ops.convert_to_tensor(
                    initial_value(), name="initial_value", dtype=dtype)
              self._handle = _eager_safe_variable_handle(
                  shape=initial_value.get_shape(),
                  dtype=initial_value.dtype.base_dtype,
                  shared_name=handle_name,
                  name=name,
                  graph_mode=self._in_graph_mode)
              self._handle_device = (
                  self._handle.device if self._in_graph_mode else
                  context.get_default_context().device_name)
              self._shape = initial_value.get_shape()
          else:
            initial_value = initial_value()
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._handle_device = (self._handle.device if self._in_graph_mode else
                                 context.get_default_context().device_name)
          self._shape = initial_value.get_shape()

        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle_device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec.  Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if context.in_graph_mode():
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so that
      # when this object is garbage collected the deleter will be too. This
      # means ResourceVariables can be part of reference cycles without those
      # cycles being uncollectable, and means that no __del__ will be defined at
      # all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle_device)

  def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes from `VariableDef` proto."""
    # Note that init_from_proto is currently not supported in Eager mode.
    assert context.in_graph_mode()
    self._in_graph_mode = True
    assert isinstance(variable_def, variable_pb2.VariableDef)
    if not variable_def.is_resource:
      raise ValueError("Trying to restore Variable as ResourceVariable.")

    # Create from variable_def.
    g = ops.get_default_graph()
    self._handle = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.variable_name, import_scope=import_scope))
    self._shape = tensor_shape.TensorShape(
        self._handle.op.get_attr("shape"))
    self._handle_device = self._handle.device
    self._handle_name = self._handle.name
    self._initializer_op = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.initializer_name, import_scope=import_scope))
    # Check whether initial_value_name exists for backwards compatibility.
    if (hasattr(variable_def, "initial_value_name") and
        variable_def.initial_value_name):
      self._initial_value = g.as_graph_element(
          ops.prepend_name_scope(variable_def.initial_value_name,
                                 import_scope=import_scope))
    else:
      self._initial_value = None
    if variable_def.snapshot_name:
      self._cached_value = g.as_graph_element(
          ops.prepend_name_scope(
              variable_def.snapshot_name, import_scope=import_scope))
    else:
      self._cached_value = None
    if variable_def.HasField("save_slice_info_def"):
      self._save_slice_info = variables.Variable.SaveSliceInfo(
          save_slice_info_def=variable_def.save_slice_info_def,
          import_scope=import_scope)
    else:
      self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._graph_element = self.value()
    self._constraint = None

  def __nonzero__(self):
    return self.__bool__()

  def __bool__(self):
    return bool(self.read_value())

  @property
  def dtype(self):
    """The dtype of this variable."""
    return self._dtype

  @property
  def device(self):
    """The device this variable is on."""
    return self._handle_device

  @property
  def graph(self):
    """The `Graph` of this variable."""
    return self._handle.graph

  @property
  def name(self):
    """The name of the handle for this variable."""
    return self._handle_name

  @property
  def shape(self):
    """The shape of this variable."""
    return self._shape

  @property
  def create(self):
    """The op responsible for initializing this variable."""
    if not self._in_graph_mode:
      raise RuntimeError("Calling create in EAGER mode not supported.")
    return self._initializer_op

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    return self._handle

  def value(self):
    """A cached operation which reads the value of this variable."""
    if self._cached_value is not None:
      return self._cached_value
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._handle_device):
        return self._read_variable_op()

  def _as_graph_element(self):
    """Conversion function for Graph.as_graph_element()."""
    return self._graph_element

  @property
  def initializer(self):
    """The op responsible for initializing this variable."""
    return self._initializer_op

  @property
  def initial_value(self):
    """Returns the Tensor used as the initial value for the variable."""
    if context.in_eager_mode():
      raise RuntimeError("initial_value not supported in EAGER mode.")
    return self._initial_value

  @property
  def constraint(self):
    """Returns the constraint function associated with this variable.

    Returns:
      The constraint function that was passed to the variable constructor.
      Can be `None` if no constraint was passed.
    """
    return self._constraint

  @property
  def op(self):
    """The op for this variable."""
    return self._handle.op

  def eval(self, session=None):
    """Evaluates and returns the value of this variable."""
    if context.in_eager_mode():
      raise RuntimeError("Trying to eval in EAGER mode")
    return self._graph_element.eval(session=session)

  def numpy(self):
    if context.in_graph_mode():
      raise NotImplementedError(
          "numpy() is only available when eager execution is enabled.")
    return self.read_value().numpy()

  def count_up_to(self, limit):
    """Increments this variable until it reaches `limit`.

    When the returned Op is run it tries to increment the variable by `1`. If
    incrementing the variable would bring it above `limit` then the Op raises
    the exception `OutOfRangeError`.

    If no error is raised, the Op outputs the value of the variable before
    the increment.

    This is essentially a shortcut for `tf.count_up_to(self, limit)`.

    Args:
      limit: value at which incrementing the variable raises an error.

    Returns:
      A `Tensor` that will hold the variable value before the increment. If no
      other Op modifies this variable, the values produced will all be
      distinct.
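
    For example (a sketch; assumes graph mode, an initialized variable, and a
    running `Session` named `sess`):

    ```python
    v = ResourceVariable(0, name="counter")
    increment = v.count_up_to(2)
    sess.run(increment)  # ==> 0 (the variable now holds 1)
    sess.run(increment)  # ==> 1 (the variable now holds 2)
    sess.run(increment)  # Raises OutOfRangeError.
    ```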
    654     """
    655     return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
    656                                               T=self.dtype)
    657 
    658   def _set_save_slice_info(self, save_slice_info):
    659     """Sets the slice info for this `ResourceVariable`.
    660 
    661     Args:
    662       save_slice_info: A `Variable.SaveSliceInfo` object.
    663     """
    664     self._save_slice_info = save_slice_info
    665 
    666   def _get_save_slice_info(self):
    667     return self._save_slice_info
    668 
    669   def _read_variable_op(self):
    670     if hasattr(self, "_trainable") and self._trainable:
    671       tape.watch_variable(self)
    672     return gen_resource_variable_ops.read_variable_op(self._handle,
    673                                                       self._dtype)
    674 
    675   def read_value(self):
    676     """Constructs an op which reads the value of this variable.
    677 
    678     Should be used when there are multiple reads, or when it is desirable to
    679     read the value only after some condition is true.
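
    For example, forcing a read to happen after an assignment (a sketch of
    graph-mode usage):

    ```python
    v = ResourceVariable(0.0)
    assign = v.assign(1.0)
    with tf.control_dependencies([assign]):
      # The value read here is guaranteed to reflect the assignment above.
      value = v.read_value()
    ```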

    Returns:
      The value of this variable, as a `Tensor`.
    """
    684     with ops.name_scope("Read"):
    685       # Ensure we read the variable in the same device as the handle.
    686       with ops.device(self._handle_device):
    687         value = self._read_variable_op()
    688     # Return an identity so it can get placed on whatever device the context
    689     # specifies instead of the device where the variable is.
    690     return array_ops.identity(value)
    691 
    692   def sparse_read(self, indices, name=None):
    693     """Reads the value of this variable sparsely, using `gather`."""
    694     with ops.name_scope("Gather" if name is None else name) as name:
    695       if self._trainable:
    696         tape.watch_variable(self)
    697       value = gen_resource_variable_ops.resource_gather(
    698           self._handle, indices, dtype=self._dtype, name=name)
    699     return array_ops.identity(value)
    700 
    701   def to_proto(self, export_scope=None):
    702     """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
    703 
    704     Args:
    705       export_scope: Optional `string`. Name scope to remove.
    706 
    707     Raises:
    708       RuntimeError: If run in EAGER mode.
    709 
    710     Returns:
    711       A `VariableDef` protocol buffer, or `None` if the `Variable` is not
    712       in the specified name scope.
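
    A rough sketch of a round trip in graph mode (this mirrors what happens
    when variables are exported to and re-imported from a `MetaGraphDef`):

    ```python
    v = ResourceVariable(1.0, name="v")
    var_def = v.to_proto()  # var_def.is_resource is True.
    v_again = ResourceVariable.from_proto(var_def)  # Re-wraps the same ops.
    ```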
    713     """
    714     if context.in_eager_mode():
    715       raise RuntimeError("to_proto not supported in EAGER mode.")
    716     if export_scope is None or self.handle.name.startswith(export_scope):
    717       var_def = variable_pb2.VariableDef()
    718       var_def.variable_name = ops.strip_name_scope(self.handle.name,
    719                                                    export_scope)
    720       if self._initial_value is not None:
    721         # This is inside an if-statement for backwards compatibility, since
    722         # self._initial_value might be None for variables constructed from old
    723         # protos.
    724         var_def.initial_value_name = ops.strip_name_scope(
    725             self._initial_value.name, export_scope)
    726       var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
    727                                                       export_scope)
    728       if self._cached_value is not None:
    729         var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
    730                                                      export_scope)
    731       var_def.is_resource = True
    732       if self._save_slice_info:
    733         var_def.save_slice_info_def.MergeFrom(
    734             self._save_slice_info.to_proto(export_scope=export_scope))
    735       return var_def
    736     else:
    737       return None
    738 
    739   @staticmethod
    740   def from_proto(variable_def, import_scope=None):
    741     if context.in_eager_mode():
    742       raise RuntimeError("from_proto not supported in EAGER mode.")
    743     return ResourceVariable(
    744         variable_def=variable_def, import_scope=import_scope)
    745 
    746   @staticmethod
    747   def _OverloadAllOperators():  # pylint: disable=invalid-name
    748     """Register overloads for all operators."""
    749     for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
    750       ResourceVariable._OverloadOperator(operator)
    751     # For slicing, bind getitem differently than a tensor (use SliceHelperVar
    752     # instead)
    753     # pylint: disable=protected-access
    754     setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)
    755 
    756   def _AsTensor(self):
    757     return self.value()
    758 
    759   def _ref(self):
    760     """Unsupported."""
    761     raise NotImplementedError("ResourceVariable does not implement _ref()")
    762 
    763   def set_shape(self, shape):
    764     """Unsupported."""
    765     raise NotImplementedError("ResourceVariable does not implement set_shape()")
    766 
    767   @staticmethod
    768   def _OverloadOperator(operator):  # pylint: disable=invalid-name
    769     """Defer an operator overload to `ops.Tensor`.
    770 
    771     We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
    772 
    773     Args:
    774       operator: string. The operator name.
    775     """
    776 
    777     def _run_op(a, *args):
    778       # pylint: disable=protected-access
    779       value = a._AsTensor()
    780       return getattr(ops.Tensor, operator)(value, *args)
    781 
    782     # Propagate __doc__ to wrapper
    783     try:
    784       _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
    785     except AttributeError:
    786       pass
    787 
    788     setattr(ResourceVariable, operator, _run_op)
    789 
    790   __array_priority__ = 100
    791 
  def assign_sub(self, delta, use_locking=None, name=None):
    """Subtracts `delta` from this variable."""
    # TODO(apassos): this here and below is not atomic. Consider making it
    # atomic if there's a way to do so without a performance cost for those who
    # don't need it.
    return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def assign_add(self, delta, use_locking=None, name=None):
    """Adds `delta` to this variable."""
    return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op(
        self.handle,
        ops.convert_to_tensor(delta, dtype=self.dtype),
        name=name))

  def _lazy_read(self, op):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
    return _UnreadVariable(
        self._handle, self.dtype, self._handle_device, self._shape,
        self._in_graph_mode,
        self._handle_deleter if not self._in_graph_mode else None, op)

  def assign(self, value, use_locking=None, name=None):
    """Assigns `value` to this variable, checking shape compatibility."""
    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
    self._shape.assert_is_compatible_with(value_tensor.shape)
    return self._lazy_read(
        gen_resource_variable_ops.assign_variable_op(
            self.handle,
            value_tensor,
            name=name))

  def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                            end_mask, ellipsis_mask, new_axis_mask,
                            shrink_axis_mask):
    return self._lazy_read(
        gen_array_ops.resource_strided_slice_assign(
            ref=self.handle,
            begin=begin,
            end=end,
            strides=strides,
            value=value,
            name=name,
            begin_mask=begin_mask,
            end_mask=end_mask,
            ellipsis_mask=ellipsis_mask,
            new_axis_mask=new_axis_mask,
            shrink_axis_mask=shrink_axis_mask))

  def __int__(self):
    if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
      raise TypeError("Non-integer variable can't be converted to integer.")
    return int(self.value().numpy())

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    del name
    if dtype is not None and dtype != self.dtype:
      print("trying to switch the dtype to ", dtype, " from ", self.dtype)
      return NotImplemented
    if as_ref:
      return self.read_value().op.inputs[0]
    else:
      return self.value()

  def __iadd__(self, unused_other):
    raise RuntimeError("Variable += value not supported. Use "
                       "variable.assign_add(value) to modify the variable "
                       "value and variable = variable + value to get a new "
                       "Tensor object.")

  def __isub__(self, unused_other):
    raise RuntimeError("Variable -= value not supported. Use "
                       "variable.assign_sub(value) to modify the variable "
                       "value and variable = variable - value to get a new "
                       "Tensor object.")

  def __imul__(self, unused_other):
    raise RuntimeError("Variable *= value not supported. Use "
                       "variable.assign_mul(value) to modify the variable "
                       "value and variable = variable * value to get a new "
                       "Tensor object.")

  def __idiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __itruediv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __irealdiv__(self, unused_other):
    raise RuntimeError("Variable /= value not supported. Use "
                       "variable.assign_div(value) to modify the variable "
                       "value and variable = variable / value to get a new "
                       "Tensor object.")

  def __ipow__(self, unused_other):
    893     raise RuntimeError("Variable **= value not supported. Use "
    894                        "value and variable = variable ** value to get a new "
    895                        "Tensor object.")


def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


class _UnreadVariable(ResourceVariable):
  """Represents a future for a read of a variable.

  Pretends to be the tensor if anyone looks.
  """

  def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
               shape, in_graph_mode, deleter, parent_op):
    # We do not call super init on purpose.
    self._trainable = False
    self._save_slice_info = None
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._in_graph_mode = in_graph_mode
    self._handle = handle
    self._handle_device = handle_device
    self._shape = shape
    self._initial_value = None
    if isinstance(self._handle, ops.EagerTensor):
      self._handle_name = ""
    else:
      self._handle_name = self._handle.name
    self._dtype = dtype
    self._constraint = None
    self._cached_value = None
    self._is_initialized_op = None
    self._initializer_op = None
    self._parent_op = parent_op
    if context.in_graph_mode():
      self._graph_element = self.read_value()
    else:
      self._graph_element = None
    self._handle_deleter = deleter

  def value(self):
    return self._read_variable_op()

  def read_value(self):
    return self._read_variable_op()

  def _read_variable_op(self):
    with ops.control_dependencies([self._parent_op]):
      return gen_resource_variable_ops.read_variable_op(self._handle,
                                                        self._dtype)

  def set_shape(self, shape):
    self._shape = shape

  @property
  def op(self):
    """The op for this variable."""
    return self._parent_op

ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.

# Note: registering for Variable after ResourceVariable because inheritance will
# otherwise lead to the wrong behavior.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access

# pylint: disable=protected-access
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)
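
# For example (illustrative), once these registrations are in place a
# ResourceVariable can be passed anywhere a `Tensor` is expected:
#
#   v = ResourceVariable([1.0, 2.0])
#   w = tf.add(v, 1.0)  # `v` is implicitly converted to a Tensor holding
#                       # its current value.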


@ops.RegisterGradient("ReadVariableOp")
def _ReadGrad(_, grad):
  """Gradient for read op."""
  return grad


@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
  # Build appropriately shaped IndexedSlices
  handle = op.inputs[0]
  indices = op.inputs[1]
  params_shape = gen_resource_variable_ops.variable_shape(handle)
  size = array_ops.expand_dims(array_ops.size(indices), 0)
  values_shape = array_ops.concat([size, params_shape[1:]], 0)
  values = array_ops.reshape(grad, values_shape)
  indices = array_ops.reshape(indices, size)
  return (ops.IndexedSlices(values, indices, params_shape), None)


def _to_proto_fn(v, export_scope=None):
  """Converts Variable and ResourceVariable to VariableDef for collections."""
  return v.to_proto(export_scope=export_scope)


def _from_proto_fn(v, import_scope=None):
  """Creates Variable or ResourceVariable from VariableDef as needed."""
  if v.is_resource:
    return ResourceVariable.from_proto(v, import_scope=import_scope)
  return variables.Variable.from_proto(v, import_scope=import_scope)


ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.LOCAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)
ops.register_proto_function(
    ops.GraphKeys.MODEL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=_to_proto_fn,
    from_proto=_from_proto_fn)


def is_resource_variable(var):
   1031   """"Returns True if `var` is to be considered a ResourceVariable."""
  return isinstance(var, ResourceVariable) or hasattr(
      var, "_should_act_as_resource_variable")