Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """A class to store named variables and a scope operator to manage sharing."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections as collections_lib
     23 import copy
     24 import enum  # pylint: disable=g-bad-import-order
     25 import functools
     26 import sys
     27 import traceback
     28 
     29 import six
     30 from six import iteritems
     31 from six.moves import xrange  # pylint: disable=redefined-builtin
     32 
     33 from tensorflow.python.eager import context
     34 from tensorflow.python.estimator import util as estimator_util
     35 from tensorflow.python.framework import dtypes
     36 from tensorflow.python.framework import ops
     37 from tensorflow.python.framework import tensor_shape
     38 from tensorflow.python.ops import array_ops
     39 from tensorflow.python.ops import init_ops
     40 from tensorflow.python.ops import resource_variable_ops
     41 from tensorflow.python.ops import variables
     42 from tensorflow.python.platform import tf_logging as logging
     43 from tensorflow.python.util import tf_contextlib
     44 from tensorflow.python.util.tf_export import tf_export
     45 
     46 __all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
     47            "get_variable", "get_local_variable", "variable_scope",
     48            "variable_op_scope", "no_regularizer"]
     49 
     50 
     51 class _PartitionInfo(object):
     52   """Holds partition info used by initializer functions.
     53   """
     54 
     55   def __init__(self, full_shape, var_offset):
     56     """Constructor.
     57 
     58     Args:
     59       full_shape: Tuple or list of `int` indicating the full combined shape
     60         of the partitioned variables.
     61       var_offset: Tuple or list of `int` specifying offset of this partition
     62         with respect to the full variable for each dimension.
     63 
     64     Raises:
     65       TypeError: If `full_shape` or `var_offset` is not a sequence.
     66       ValueError: If `full_shape` or `var_offset` differ in length. If
     67         `var_offset` exceeds `full_shape` in any dimension.
     68     """
     69     if not isinstance(full_shape, collections_lib.Sequence) or isinstance(
     70         full_shape, six.string_types):
     71       raise TypeError(
     72           "`full_shape` must be a sequence (like tuple or list) instead of " +
     73           type(full_shape).__name__)
     74 
     75     if not isinstance(var_offset, collections_lib.Sequence) or isinstance(
     76         var_offset, six.string_types):
     77       raise TypeError(
     78           "`var_offset` must be a sequence (like tuple or list) instead of " +
     79           type(var_offset).__name__)
     80 
     81     if len(var_offset) != len(full_shape):
     82       raise ValueError(
     83           "Expected equal length, but `var_offset` is of length {} while "
     84           "full_shape is of length {}.".format(
     85               len(var_offset), len(full_shape)))
     86 
     87     for i in xrange(len(full_shape)):
     88       offset = var_offset[i]
     89       shape = full_shape[i]
     90       if offset < 0 or offset >= shape:
     91         raise ValueError(
     92             "Expected 0 <= offset < shape but found offset={}, shape={} for "
     93             "var_offset={}, full_shape={}".format(offset, shape, var_offset,
     94                                                   full_shape))
     95 
     96     self._full_shape = full_shape
     97     self._var_offset = var_offset
     98 
     99   @property
    100   def full_shape(self):
    101     return self._full_shape
    102 
    103   @property
    104   def var_offset(self):
    105     return self._var_offset
    106 
    107   def single_offset(self, shape):
    108     """Returns the offset when the variable is partitioned in at most one dim.
    109 
    110     Args:
    111       shape: Tuple or list of `int` indicating the shape of one specific
    112         variable partition.
    113 
    114     Returns:
    115       `int` representing the offset in the dimension along which the variable is
    116        partitioned. Returns 0 if the variable is not being partitioned.
    117 
    118     Raises:
    119       ValueError: Depending on self.single_slice_dim().
    120     """
    121 
    122     single_slice_dim = self.single_slice_dim(shape)
    123     # If this variable is not being partitioned at all, single_slice_dim() could
    124     # return None.
    125     if single_slice_dim is None:
    126       return 0
    127     return self.var_offset[single_slice_dim]
    128 
    129   def single_slice_dim(self, shape):
    130     """Returns the slice dim when the variable is partitioned only in one dim.
    131 
    132     Args:
    133       shape: Tuple or list of `int` indicating the shape of one specific
    134         variable partition.
    135 
    136     Returns:
    137       `int` representing the dimension that the variable is partitioned in, or
    138       `None` if the variable doesn't seem to be partitioned at all.
    139 
    140     Raises:
    141       TypeError: If `shape` is not a sequence.
    142       ValueError: If `shape` is not the same length as `self.full_shape`. If
    143         the variable is partitioned in more than one dimension.
    144     """
    145     if not isinstance(shape, collections_lib.Sequence) or isinstance(
    146         shape, six.string_types):
    147       raise TypeError(
    148           "`shape` must be a sequence (like tuple or list) instead of " +
    149           type(shape).__name__)
    150 
    151     if len(shape) != len(self.full_shape):
    152       raise ValueError(
    153           "Expected equal length, but received shape={} of length {} while "
    154           "self.full_shape={} is of length {}.".format(shape, len(
    155               shape), self.full_shape, len(self.full_shape)))
    156 
    157     for i in xrange(len(shape)):
    158       if self.var_offset[i] + shape[i] > self.full_shape[i]:
    159         raise ValueError(
    160             "With self.var_offset={}, a partition of shape={} would exceed "
    161             "self.full_shape={} in dimension {}.".format(
    162                 self.var_offset, shape, self.full_shape, i))
    163 
    164     slice_dim = None
    165     for i in xrange(len(shape)):
    166       if shape[i] == self.full_shape[i]:
    167         continue
    168       if slice_dim is not None:
    169         raise ValueError(
    170             "Cannot use single_slice_dim() with shape={} and "
    171             "self.full_shape={} since slice dim could be either dimension {} "
    172             "or {}.".format(shape, self.full_shape, i, slice_dim))
    173       slice_dim = i
    174 
    175     return slice_dim
    176 
    177 
    178 class _ReuseMode(enum.Enum):
    179   """Mode for variable access within a variable scope."""
    180 
    181   # Indicates that variables are to be fetched if they already exist or
    182   # otherwise created.
    183   AUTO_REUSE = 1
    184 
    185   # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of
    186   #              enum values.
    187   # REUSE_FALSE = 2
    188   # REUSE_TRUE = 3
    189 
    190 AUTO_REUSE = _ReuseMode.AUTO_REUSE
    191 tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
    192 AUTO_REUSE.__doc__ = """
    193 When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
    194 get_variable() should create the requested variable if it doesn't exist or, if
    195 it does exist, simply return it.
    196 """
    197 
    198 
    199 class _VariableStore(object):
    200   """Variable store that carries a number of named Variables.
    201 
    202   New variable names and new variables can be created; all stored
    203   variables are initialized with the initializer passed to __init__.
    204 
    205   Attributes:
    206     vars: a dictionary with string names (same as passed in GetVar) as keys
    207           and the corresponding TensorFlow Variables as values.
    208   """
    209 
    210   def __init__(self):
    211     """Create a variable store."""
    212     self._vars = {}  # A dictionary of the stored TensorFlow variables.
    213     self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
    214     self.variable_scopes_count = {}  # Count re-used variable scopes.
    215     self._store_eager_variables = False
    216 
    217   def open_variable_scope(self, scope_name):
    218     if scope_name in self.variable_scopes_count:
    219       self.variable_scopes_count[scope_name] += 1
    220     else:
    221       self.variable_scopes_count[scope_name] = 1
    222 
    223   def close_variable_subscopes(self, scope_name):
    224     for k in self.variable_scopes_count:
    225       if not scope_name or k.startswith(scope_name + "/"):
    226         self.variable_scopes_count[k] = 0
    227 
    228   def variable_scope_count(self, scope_name):
    229     return self.variable_scopes_count.get(scope_name, 0)
    230 
    231   def get_variable(self, name, shape=None, dtype=dtypes.float32,
    232                    initializer=None, regularizer=None, reuse=None,
    233                    trainable=True, collections=None, caching_device=None,
    234                    partitioner=None, validate_shape=True, use_resource=None,
    235                    custom_getter=None, constraint=None):
    236     """Gets an existing variable with these parameters or create a new one.
    237 
    238     If a variable with the given name is already stored, we return the stored
    239     variable. Otherwise, we create a new one.
    240 
    241     Set `reuse` to `True` when you only want to reuse existing Variables.
    242     Set `reuse` to `False` when you only want to create new Variables.
    243     Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    244     variables to be created if they don't exist or returned if they do.
    245 
    246     If initializer is `None` (the default), the default initializer passed in
    247     the constructor is used. If that one is `None` too, we use a new
    248     `glorot_uniform_initializer`. If initializer is a Tensor, we use
    249     it as a value and derive the shape from the initializer.
    250 
    251     If a partitioner is provided, a `PartitionedVariable` is returned.
    252     Accessing this object as a `Tensor` returns the shards concatenated along
    253     the partition axis.
    254 
    255     Some useful partitioners are available.  See, e.g.,
    256     `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
    257 
    258     Args:
    259       name: The name of the new or existing variable.
    260       shape: Shape of the new or existing variable.
    261       dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
    262       initializer: Initializer for the variable.
    263       regularizer: A (Tensor -> Tensor or None) function; the result of
    264         applying it on a newly created variable will be added to the collection
    265         GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    266       reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
    267         of variables. When eager execution is enabled  this argument is always
    268         forced to be False.
    269       trainable: If `True` also add the variable to the graph collection
    270         `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    271       collections: List of graph collections keys to add the `Variable` to.
    272         Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    273       caching_device: Optional device string or function describing where the
    274         Variable should be cached for reading.  Defaults to the Variable's
    275         device.  If not `None`, caches on another device.  Typical use is to
    276         cache on the device where the Ops using the `Variable` reside, to
    277         deduplicate copying through `Switch` and other conditional statements.
    278       partitioner: Optional callable that accepts a fully defined `TensorShape`
    279         and dtype of the `Variable` to be created, and returns a list of
    280         partitions for each axis (currently only one axis can be partitioned).
    281       validate_shape: If False, allows the variable to be initialized with a
    282         value of unknown shape. If True, the default, the shape of initial_value
    283         must be known.
    284       use_resource: If False, creates a regular Variable. If True, creates
    285         instead an experimental ResourceVariable which has well-defined
    286         semantics. Defaults to False (will later change to True).
    287         When eager execution is enabled this argument is always forced to be
    288         true.
    289       custom_getter: Callable that takes as a first argument the true getter,
    290         and allows overwriting the internal get_variable method.
    291         The signature of `custom_getter` should match that of this method,
    292         but the most future-proof version will allow for changes:
    293         `def custom_getter(getter, *args, **kwargs)`.  Direct access to
    294         all `get_variable` parameters is also allowed:
    295         `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
    296         custom getter that simply creates variables with modified names is:
    297         ```python
    298         def custom_getter(getter, name, *args, **kwargs):
    299           return getter(name + '_suffix', *args, **kwargs)
    300         ```
    301       constraint: An optional projection function to be applied to the variable
    302         after being updated by an `Optimizer` (e.g. used to implement norm
    303         constraints or value constraints for layer weights). The function must
    304         take as input the unprojected Tensor representing the value of the
    305         variable and return the Tensor for the projected value
    306         (which must have the same shape). Constraints are not safe to
    307         use when doing asynchronous distributed training.
    308 
    309     Returns:
    310       The created or existing `Variable` (or `PartitionedVariable`, if a
    311       partitioner was used).
    312 
    313     Raises:
    314       ValueError: when creating a new variable and shape is not declared,
    315         when reusing a variable and specifying a conflicting shape,
    316         or when violating reuse during variable creation.
    317       RuntimeError: when eager execution is enabled and not called from an
    318         EagerVariableStore.
    319     """
    320     if custom_getter is not None and not callable(custom_getter):
    321       raise ValueError(
    322           "Passed a custom_getter which is not callable: %s" % custom_getter)
    323 
    324     if context.in_eager_mode():
    325       if not self._store_eager_variables and reuse:
    326         raise RuntimeError(
    327             "When eager execution is enabled variable reuse is only supported"
    328             " when an EagerVariableStore is active. See the documentation on"
    329             " EagerVariableStore for example usage.")
    330       if self._store_eager_variables:
    331         reuse = AUTO_REUSE
    332       use_resource = True
    333 
    334     # If a *_ref type is passed in an error would be triggered further down the
    335     # stack. We prevent this using base_dtype to get a non-ref version of the
    336     # type, before doing anything else. When _ref types are removed in favor of
    337     # resources, this line can be removed.
    338     try:
    339       dtype = dtype.base_dtype
    340     except AttributeError:
    341       # .base_dtype not existing means that we will try and use the raw dtype
    342       # which was passed in - this might be a NumPy type which is valid.
    343       pass
    344 
    345     # This is the main logic of get_variable.  However, custom_getter
    346     # may override this logic.  So we save it as a callable and pass
    347     # it to custom_getter.
    348     # Note: the parameters of _true_getter, and their documentation, match
    349     # *exactly* item-for-item with the docstring of this method.
    350     def _true_getter(name, shape=None, dtype=dtypes.float32,  # pylint: disable=missing-docstring
    351                      initializer=None, regularizer=None, reuse=None,
    352                      trainable=True, collections=None, caching_device=None,
    353                      partitioner=None, validate_shape=True, use_resource=None,
    354                      constraint=None):
    355       is_scalar = (shape is not None
    356                    and isinstance(shape, collections_lib.Sequence)
    357                    and not shape)
    358       # Partitioned variable case
    359       if partitioner is not None and not is_scalar:
    360         if not callable(partitioner):
    361           raise ValueError(
    362               "Partitioner must be callable, but received: %s" % partitioner)
    363         with ops.name_scope(None):
    364           return self._get_partitioned_variable(name=name,
    365                                                 shape=shape,
    366                                                 dtype=dtype,
    367                                                 initializer=initializer,
    368                                                 regularizer=regularizer,
    369                                                 reuse=reuse,
    370                                                 trainable=trainable,
    371                                                 collections=collections,
    372                                                 caching_device=caching_device,
    373                                                 partitioner=partitioner,
    374                                                 validate_shape=validate_shape,
    375                                                 use_resource=use_resource,
    376                                                 constraint=constraint)
    377 
    378       # Special case for partitioned variable to allow reuse without having to
    379       # specify partitioner.
    380       if (reuse is True and partitioner is None
    381           and name in self._partitioned_vars):
    382         return self._get_partitioned_variable(name=name,
    383                                               shape=shape,
    384                                               dtype=dtype,
    385                                               initializer=initializer,
    386                                               regularizer=regularizer,
    387                                               reuse=reuse,
    388                                               trainable=trainable,
    389                                               collections=collections,
    390                                               caching_device=caching_device,
    391                                               partitioner=None,
    392                                               validate_shape=validate_shape,
    393                                               use_resource=use_resource,
    394                                               constraint=constraint)
    395 
    396       # Single variable case
    397       if "%s/part_0" % name in self._vars:
    398         raise ValueError(
    399             "No partitioner was provided, but a partitioned version of the "
    400             "variable was found: %s/part_0. Perhaps a variable of the same "
    401             "name was already created with partitioning?" % name)
    402 
    403       return self._get_single_variable(
    404           name=name, shape=shape, dtype=dtype,
    405           initializer=initializer, regularizer=regularizer, reuse=reuse,
    406           trainable=trainable, collections=collections,
    407           caching_device=caching_device, validate_shape=validate_shape,
    408           use_resource=use_resource, constraint=constraint)
    409 
    410     if custom_getter is not None:
    411       # Handle backwards compatibility with getter arguments that were added
    412       # to the API after users started writing custom getters.
    413       custom_getter_kwargs = {
    414           "getter": _true_getter,
    415           "name": name,
    416           "shape": shape,
    417           "dtype": dtype,
    418           "initializer": initializer,
    419           "regularizer": regularizer,
    420           "reuse": reuse,
    421           "trainable": trainable,
    422           "collections": collections,
    423           "caching_device": caching_device,
    424           "partitioner": partitioner,
    425           "validate_shape": validate_shape,
    426           "use_resource": use_resource,
    427       }
    428       # `fn_args` can handle functions, `functools.partial`, `lambda`.
    429       if "constraint" in estimator_util.fn_args(custom_getter):
    430         custom_getter_kwargs["constraint"] = constraint
    431       return custom_getter(**custom_getter_kwargs)
    432     else:
    433       return _true_getter(
    434           name, shape=shape, dtype=dtype,
    435           initializer=initializer, regularizer=regularizer,
    436           reuse=reuse, trainable=trainable, collections=collections,
    437           caching_device=caching_device, partitioner=partitioner,
    438           validate_shape=validate_shape, use_resource=use_resource,
    439           constraint=constraint)
    440 
    441   def _get_partitioned_variable(
    442       self, name, partitioner, shape=None, dtype=dtypes.float32,
    443       initializer=None, regularizer=None, reuse=None,
    444       trainable=True, collections=None, caching_device=None,
    445       validate_shape=True, use_resource=None, constraint=None):
    446     """Gets or creates a sharded variable list with these parameters.
    447 
    448     The `partitioner` must be a callable that accepts a fully defined
    449     `TensorShape` and returns a sequence of integers (the `partitions`).
    450     These integers describe how to partition the given sharded `Variable`
    451     along the given dimension.  That is, `partitions[1] = 3` means split
    452     the `Variable` into 3 shards along dimension 1.  Currently, sharding along
    453     only one axis is supported.
    454 
    455     If the list of variables with the given name (prefix) is already stored,
    456     we return the stored variables. Otherwise, we create a new one.
    457 
    458     Set `reuse` to `True` when you only want to reuse existing Variables.
    459     Set `reuse` to `False` when you only want to create new Variables.
    460     Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    461     variables to be created if they don't exist or returned if they do.
    462 
    463     If initializer is `None` (the default), the default initializer passed in
    464     the constructor is used. If that one is `None` too, we use a new
    465     `glorot_uniform_initializer`. If initializer is a Tensor, we use
    466     it as a value and derive the shape from the initializer.
    467 
    468     If the initializer is a callable, then it will be called for each
    469     shard.  Otherwise the initializer should match the shape of the entire
    470     sharded Variable, and it will be sliced accordingly for each shard.
    471 
    472     Some useful partitioners are available.  See, e.g.,
    473     `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
    474 
    475     Args:
    476       name: the name of the new or existing sharded variable.
    477       partitioner: Optional callable that accepts a fully defined `TensorShape`
    478         and `dtype` of the Variable to be created, and returns a list of
    479         partitions for each axis (currently only one axis can be partitioned).
    480       shape: shape of the new or existing sharded variable.
    481       dtype: type of the new or existing sharded variable
    482         (defaults to `DT_FLOAT`).
    483       initializer: initializer for the sharded variable.
    484       regularizer: a (Tensor -> Tensor or None) function; the result of
    485         applying it on a newly created variable will be added to the collection
    486         GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    487       reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
    488         of variables.
    489       trainable: If `True` also add the variable to the graph collection
    490         `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    491       collections: List of graph collections keys to add the Variable to.
    492         Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    493       caching_device: Optional device string or function describing where the
    494         Variable should be cached for reading.  Defaults to the Variable's
    495         device.  If not `None`, caches on another device.  Typical use is to
    496         cache on the device where the Ops using the Variable reside, to
    497         deduplicate copying through `Switch` and other conditional statements.
    498       validate_shape: If False, allows the variable to be initialized with a
    499         value of unknown shape. If True, the default, the shape of initial_value
    500         must be known.
    501       use_resource: If False, creates a regular Variable. If True, creates an
    502         experimental ResourceVariable which has well-defined semantics. Defaults
    503         to False (will later change to True).
    504       constraint: An optional projection function to be applied to the variable
    505         after being updated by an `Optimizer` (e.g. used to implement norm
    506         constraints or value constraints for layer weights). The function must
    507         take as input the unprojected Tensor representing the value of the
    508         variable and return the Tensor for the projected value
    509         (which must have the same shape). Constraints are not safe to
    510         use when doing asynchronous distributed training.
    511 
    512     Returns:
    513       A `PartitionedVariable` object.
    514 
    515     Raises:
    516       ValueError: when creating a new variable and shape is not declared,
    517         when reusing a variable and specifying a conflicting shape,
    518         when violating reuse during variable creation, or if an existing
    519         sharded variable exists for the given name but with different sharding.
    520     """
    521     if context.in_eager_mode():
    522       raise NotImplementedError("Partitioned variables are not yet supported "
    523                                 "when eager execution is enabled.")
    524 
    525     initializing_from_value = initializer is not None and isinstance(
    526         initializer, ops.Tensor)
    527     reuse_without_partition = reuse and not partitioner
    528 
    529     if name in self._vars:
    530       raise ValueError(
    531           "A partitioner was provided, but an unpartitioned version of the "
    532           "variable was found: %s.  Perhaps a variable of the same name was "
    533           "already created without partitioning?" % name)
    534 
    535     shape = tensor_shape.as_shape(shape)
    536     if initializing_from_value:
    537       shape = shape.merge_with(initializer.get_shape())
    538 
    539     if not reuse_without_partition:
    540       if not shape.is_fully_defined():
    541         raise ValueError("Shape of a new partitioned variable (%s) must be "
    542                          "fully defined, but instead was %s." % (name, shape))
    543 
    544       if shape.ndims < 1:
    545         raise ValueError("A partitioned Variable must have rank at least 1, "
    546                          "shape: %s" % shape)
    547 
    548       partitions = partitioner(shape=shape, dtype=dtype)
    549 
    550       if not isinstance(partitions, collections_lib.Sequence):
    551         raise ValueError("Partitioner must return a sequence, but saw: %s"
    552                          % partitions)
    553 
    554       if len(partitions) != shape.ndims:
    555         raise ValueError(
    556             "Partitioner returned a partition list that does not match the "
    557             "Variable's rank: %s vs. %s" % (partitions, shape))
    558 
    559       if any([p < 1 for p in partitions]):
    560         raise ValueError(
    561             "Partitioner returned zero partitions for some axes: %s" %
    562             partitions)
    563 
    564     if name in self._partitioned_vars:
    565       if reuse is False:
    566         raise ValueError(
    567             "Partitioned variable with name %s already exists. Did you mean to "
    568             "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"
    569             % name)
    570 
    571       existing_var = self._partitioned_vars[name]
    572       if not shape.is_compatible_with(existing_var.get_shape()):
    573         raise ValueError(
    574             "Trying to reuse partitioned variable %s, but specified shape %s "
    575             "and found shape %s."
    576             % (name, shape, existing_var.get_shape()))
    577       if not dtype.is_compatible_with(existing_var.dtype):
    578         raise ValueError(
    579             "Trying to reuse partitioned variable %s, but specified dtype %s "
    580             "and found dtype %s."
    581             % (name, dtype.name, existing_var.dtype.name))
    582 
    583       # pylint: disable=protected-access
    584       if (not reuse_without_partition and
    585           existing_var._get_partitions() != partitions):
    586         raise ValueError(
    587             "Trying to reuse partitioned variable %s, but specified partitions "
    588             "%s and found partitions %s." %
    589             (name, partitions, existing_var._get_partitions()))
    590       # pylint: enable=protected-access
    591 
    592       return existing_var
    593 
    594     if reuse is True:
    595       raise ValueError("PartitionedVariable %s does not exist, or was not "
    596                        "created with tf.get_variable(). Did you mean to set "
    597                        "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)
    598 
    599     slice_dim, slice_shape = _compute_slice_dim_and_shape(
    600         shape.as_list(), partitions)
    601 
    602     vs = []
    603     num_slices = partitions[slice_dim]
    604     num_slices_with_excess = shape[slice_dim].value % num_slices
    605 
    606     slice_offset = [0] * shape.ndims
    607 
    608     if "%s/part_0" % name in self._vars:
    609       if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
    610         raise ValueError(
    611             "Partitioner returned a different partitioning than what was "
    612             "already found.  Partitioner returned %d shards, and shard "
    613             "%s/part_0 was found, but %s/part_%d was not."
    614             % (num_slices, name, name, num_slices - 1))
    615       if "%s/part_%d" % (name, num_slices) in self._vars:
    616         raise ValueError(
    617             "Partitioner returned a different partitioning than what was "
    618             "already found.  Partitioner returned %d shards, and shard "
    619             "%s/part_0 was found, but so was the extra shard %s/part_%d."
    620             % (num_slices, name, name, num_slices))
    621 
    622     for i in xrange(num_slices):
    623       var_shape = slice_shape[:]
    624       var_offset = slice_offset[:]
    625       partition_info = _PartitionInfo(
    626           full_shape=shape.as_list(), var_offset=var_offset)
    627       if i < num_slices_with_excess:
    628         var_shape[slice_dim] += 1
    629       slice_offset[slice_dim] += var_shape[slice_dim]
    630 
    631       var_full_name = "%s/part_%d" % (name, i)
    632       with ops.name_scope(var_full_name + "/PartitionedInitializer"):
    633         # Create the tensor to initialize the variable with default value.
    634         if initializer is None:
    635           init, initializing_from_value = self._get_default_initializer(
    636               name=name, shape=shape, dtype=dtype)
    637           if initializing_from_value:
    638             init_shape = None
    639           else:
    640             init_shape = var_shape
    641         elif callable(initializer):
    642           init = initializer
    643           init_shape = var_shape
    644         elif isinstance(initializer, ops.Tensor):
    645           init = array_ops.slice(initializer, var_offset, var_shape)
    646           # Use the dtype of the given tensor.
    647           dtype = init.dtype.base_dtype
    648           init_shape = None
    649         else:
    650           init = ops.convert_to_tensor(initializer, dtype=dtype)
    651           init = array_ops.slice(init, var_offset, var_shape)
    652           init_shape = None
    653 
    654       with ops.name_scope(None):
    655         var = self._get_single_variable(
    656             name=var_full_name,
    657             shape=init_shape,
    658             dtype=dtype,
    659             initializer=init,
    660             partition_info=partition_info,
    661             regularizer=regularizer,
    662             reuse=reuse,
    663             trainable=trainable,
    664             collections=collections,
    665             caching_device=caching_device,
    666             validate_shape=validate_shape,
    667             use_resource=use_resource,
    668             constraint=constraint)
    669 
    670       # pylint: disable=protected-access
    671       var._set_save_slice_info(variables.Variable.SaveSliceInfo(
    672           name, shape.as_list(), var_offset, var_shape))
    673       vs.append(var)
    674       # pylint: enable=protected-access
    675 
    676       # pylint: disable=protected-access
    677     partitioned_var = variables.PartitionedVariable(name=name,
    678                                                     shape=shape,
    679                                                     dtype=dtype,
    680                                                     variable_list=vs,
    681                                                     partitions=partitions)
    682     # pylint: enable=protected-access
    683 
    684     self._partitioned_vars[name] = partitioned_var
    685     return partitioned_var
    686 
    687   def _get_single_variable(self,
    688                            name,
    689                            shape=None,
    690                            dtype=dtypes.float32,
    691                            initializer=None,
    692                            regularizer=None,
    693                            partition_info=None,
    694                            reuse=None,
    695                            trainable=True,
    696                            collections=None,
    697                            caching_device=None,
    698                            validate_shape=True,
    699                            use_resource=None,
    700                            constraint=None):
    701     """Get or create a single Variable (e.g. a shard or entire variable).
    702 
    703     See the documentation of get_variable above (ignore partitioning components)
    704     for details.
    705 
    706     Args:
    707       name: see get_variable.
    708       shape: see get_variable.
    709       dtype: see get_variable.
    710       initializer: see get_variable.
    711       regularizer: see get_variable.
    712       partition_info: _PartitionInfo object.
    713       reuse: see get_variable.
    714       trainable: see get_variable.
    715       collections: see get_variable.
    716       caching_device: see get_variable.
    717       validate_shape: see get_variable.
    718       use_resource: see get_variable.
    719       constraint: see get_variable.
    720 
    721     Returns:
    722       A Variable.  See documentation of get_variable above.
    723 
    724     Raises:
    725       ValueError: See documentation of get_variable above.
    726     """
    727     # Set to true if initializer is a constant.
    728     initializing_from_value = False
    729     if initializer is not None and not callable(initializer):
    730       initializing_from_value = True
    731     if shape is not None and initializing_from_value:
    732       raise ValueError("If initializer is a constant, do not specify shape.")
    733 
    734     dtype = dtypes.as_dtype(dtype)
    735     shape = tensor_shape.as_shape(shape)
    736 
    737     if name in self._vars:
    738       # Here we handle the case when returning an existing variable.
    739       if reuse is False:
    740         tb = self._vars[name].op.traceback[::-1]
    741         # Throw away internal tf entries and only take a few lines.
    742         tb = [x for x in tb if "tensorflow/python" not in x[0]][:3]
    743         raise ValueError("Variable %s already exists, disallowed."
    744                          " Did you mean to set reuse=True or "
    745                          "reuse=tf.AUTO_REUSE in VarScope? "
    746                          "Originally defined at:\n\n%s" % (
    747                              name, "".join(traceback.format_list(tb))))
    748       found_var = self._vars[name]
    749       if not shape.is_compatible_with(found_var.get_shape()):
    750         raise ValueError("Trying to share variable %s, but specified shape %s"
    751                          " and found shape %s." % (name, shape,
    752                                                    found_var.get_shape()))
    753       if not dtype.is_compatible_with(found_var.dtype):
    754         dtype_str = dtype.name
    755         found_type_str = found_var.dtype.name
    756         raise ValueError("Trying to share variable %s, but specified dtype %s"
    757                          " and found dtype %s." % (name, dtype_str,
    758                                                    found_type_str))
    759       return found_var
    760 
    761     # The code below handles only the case of creating a new variable.
    762     if reuse is True:
    763       raise ValueError("Variable %s does not exist, or was not created with "
    764                        "tf.get_variable(). Did you mean to set "
    765                        "reuse=tf.AUTO_REUSE in VarScope?" % name)
    766     if not shape.is_fully_defined() and not initializing_from_value:
    767       raise ValueError("Shape of a new variable (%s) must be fully defined, "
    768                        "but instead was %s." % (name, shape))
    769 
    770     # Create the tensor to initialize the variable with default value.
    771     if initializer is None:
    772       initializer, initializing_from_value = self._get_default_initializer(
    773           name=name, shape=shape, dtype=dtype)
    774     # Enter an init scope when creating the initializer.
    775     with ops.init_scope():
    776       if initializing_from_value:
    777         init_val = initializer
    778         variable_dtype = None
    779       else:
    780         # Instantiate initializer if provided initializer is a type object.
    781         if isinstance(initializer, type(init_ops.Initializer)):
    782           initializer = initializer(dtype=dtype)
    783         init_val = lambda: initializer(  # pylint: disable=g-long-lambda
    784             shape.as_list(), dtype=dtype, partition_info=partition_info)
    785         variable_dtype = dtype.base_dtype
    786 
    787     # Create the variable.
    788     if use_resource is None:
    789       # Set the default value if unspecified.
    790       use_resource = False
    791     v = variable(
    792         initial_value=init_val,
    793         name=name,
    794         trainable=trainable,
    795         collections=collections,
    796         caching_device=caching_device,
    797         dtype=variable_dtype,
    798         validate_shape=validate_shape,
    799         constraint=constraint,
    800         use_resource=use_resource)
    801     if context.in_graph_mode() or self._store_eager_variables:
    802       # In eager mode we do not want to keep default references to Variable
    803       # objects as this will prevent their memory from being released.
    804       self._vars[name] = v
    805     logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
    806                  format(shape), initializer)
    807 
    808     # Run the regularizer if requested and save the resulting loss.
    809     if regularizer:
    810       with ops.colocate_with(v):
    811         with ops.name_scope(name + "/Regularizer/"):
    812           loss = regularizer(v)
    813         if loss is not None:
    814           if context.in_graph_mode():
    815             v_name = v.name
    816             loss_name = loss.name
    817           else:
    818             v_name = "v_%s" % type(v)
    819             loss_name = "loss_%s" % type(loss)
    820           logging.vlog(1, "Applied regularizer to %s and added the result %s "
    821                        "to REGULARIZATION_LOSSES.", v_name, loss_name)
    822           ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
    823     return v
    824 
    825   # Initialize variable when no initializer provided
    826   def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    827     """Provide a default initializer and a corresponding value.
    828 
    829     Args:
    830       name: see get_variable.
    831       shape: see get_variable.
    832       dtype: see get_variable.
    833 
    834     Returns:
    835       initializer and initializing_from_value. See get_variable above.
    836 
    837     Raises:
    838       ValueError: When giving unsupported dtype.
    839     """
    840     del shape
    841     # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
    842     if dtype.is_floating:
    843       initializer = init_ops.glorot_uniform_initializer()
    844       initializing_from_value = False
    845     # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    846     # If dtype is DT_BOOL, provide a default value `FALSE`
    847     elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
    848       initializer = init_ops.zeros_initializer()
    849       initializing_from_value = False
    850     # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here?
    851     else:
    852       raise ValueError("An initializer for variable %s of %s is required"
    853                        % (name, dtype.base_dtype))
    854 
    855     return initializer, initializing_from_value
    856 
    857 
    858 # To stop regularization, use this regularizer
    859 @tf_export("no_regularizer")
    860 def no_regularizer(_):
    861   """Use this function to prevent regularization of variables."""
    862   return None
    863 
    864 
    865 # TODO(alive): support caching devices and partitioned variables in Eager mode.
    866 @tf_export("VariableScope")
    867 class VariableScope(object):
    868   """Variable scope object to carry defaults to provide to `get_variable`.
    869 
    870   Many of the arguments we need for `get_variable` in a variable store are most
    871   easily handled with a context. This object is used for the defaults.
    872 
    873   Attributes:
    874     name: name of the current scope, used as prefix in get_variable.
    875     initializer: default initializer passed to get_variable.
    876     regularizer: default regularizer passed to get_variable.
    877     reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
    878       get_variable. When eager execution is enabled this argument is always
    879       forced to be False.
    880     caching_device: string, callable, or None: the caching device passed to
    881       get_variable.
    882     partitioner: callable or `None`: the partitioner passed to `get_variable`.
    883     custom_getter: default custom getter passed to get_variable.
    884     name_scope: The name passed to `tf.name_scope`.
    885     dtype: default type passed to get_variable (defaults to DT_FLOAT).
    886     use_resource: if False, create a normal Variable; if True create an
    887       experimental ResourceVariable with well-defined semantics. Defaults
    888       to False (will later change to True). When eager execution is enabled
    889       this argument is always forced to be True.
    890     constraint: An optional projection function to be applied to the variable
    891       after being updated by an `Optimizer` (e.g. used to implement norm
    892       constraints or value constraints for layer weights). The function must
    893       take as input the unprojected Tensor representing the value of the
    894       variable and return the Tensor for the projected value
    895       (which must have the same shape). Constraints are not safe to
    896       use when doing asynchronous distributed training.
    897   """
    898 
    899   def __init__(self,
    900                reuse,
    901                name="",
    902                initializer=None,
    903                regularizer=None,
    904                caching_device=None,
    905                partitioner=None,
    906                custom_getter=None,
    907                name_scope="",
    908                dtype=dtypes.float32,
    909                use_resource=None,
    910                constraint=None):
    911     """Creates a new VariableScope with the given properties."""
    912     self._name = name
    913     self._initializer = initializer
    914     self._regularizer = regularizer
    915     self._reuse = reuse
    916     self._caching_device = caching_device
    917     self._partitioner = partitioner
    918     self._custom_getter = custom_getter
    919     self._name_scope = name_scope
    920     self._dtype = dtype
    921     self._use_resource = use_resource
    922     self._constraint = constraint
    923     if context.in_eager_mode():
    924       if self._caching_device is not None:
    925         raise NotImplementedError("Caching devices is not yet supported "
    926                                   "when eager execution is enabled.")
    927       if self._partitioner is not None:
    928         raise NotImplementedError("Partitioned variables are not yet supported "
    929                                   "when eager execution is enabled.")
    930       self._reuse = AUTO_REUSE
    931       self._use_resource = True
    932 
    933   @property
    934   def name(self):
    935     return self._name
    936 
    937   @property
    938   def original_name_scope(self):
    939     return self._name_scope
    940 
    941   @property
    942   def reuse(self):
    943     return self._reuse
    944 
    945   @property
    946   def initializer(self):
    947     return self._initializer
    948 
    949   @property
    950   def dtype(self):
    951     return self._dtype
    952 
    953   @property
    954   def use_resource(self):
    955     return self._use_resource
    956 
    957   @property
    958   def regularizer(self):
    959     return self._regularizer
    960 
    961   @property
    962   def caching_device(self):
    963     return self._caching_device
    964 
    965   @property
    966   def partitioner(self):
    967     return self._partitioner
    968 
    969   @property
    970   def custom_getter(self):
    971     return self._custom_getter
    972 
    973   @property
    974   def constraint(self):
    975     return self._constraint
    976 
    977   def reuse_variables(self):
    978     """Reuse variables in this scope."""
    979     self._reuse = True
    980 
    981   def set_initializer(self, initializer):
    982     """Set initializer for this scope."""
    983     self._initializer = initializer
    984 
    985   def set_dtype(self, dtype):
    986     """Set data type for this scope."""
    987     self._dtype = dtype
    988 
    989   def set_use_resource(self, use_resource):
    990     """Sets whether to use ResourceVariables for this scope."""
    991     if context.in_eager_mode() and not use_resource:
    992       raise ValueError("When eager execution is enabled, "
    993                        "use_resource cannot be set to false.")
    994     self._use_resource = use_resource
    995 
    996   def set_regularizer(self, regularizer):
    997     """Set regularizer for this scope."""
    998     self._regularizer = regularizer
    999 
   1000   def set_caching_device(self, caching_device):
   1001     """Set caching_device for this scope."""
   1002     if context.in_eager_mode():
   1003       raise NotImplementedError("Caching devices are not yet supported "
   1004                                 "when eager execution is enabled.")
   1005     self._caching_device = caching_device
   1006 
   1007   def set_partitioner(self, partitioner):
   1008     """Set partitioner for this scope."""
   1009     if partitioner and context.in_eager_mode():
   1010       raise NotImplementedError("Partitioned variables are not yet supported "
   1011                                 "when eager execution is enabled.")
   1012     self._partitioner = partitioner
   1013 
   1014   def set_custom_getter(self, custom_getter):
   1015     """Set custom getter for this scope."""
   1016     self._custom_getter = custom_getter
   1017 
   1018   def get_collection(self, name):
   1019     """Get this scope's variables."""
   1020     scope = self._name + "/" if self._name else ""
   1021     return ops.get_collection(name, scope)
   1022 
   1023   def trainable_variables(self):
   1024     """Get this scope's trainable variables."""
   1025     return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
   1026 
   1027   def global_variables(self):
   1028     """Get this scope's global variables."""
   1029     return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
   1030 
   1031   def local_variables(self):
   1032     """Get this scope's local variables."""
   1033     return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
   1034 
   1035   def get_variable(self,
   1036                    var_store,
   1037                    name,
   1038                    shape=None,
   1039                    dtype=None,
   1040                    initializer=None,
   1041                    regularizer=None,
   1042                    reuse=None,
   1043                    trainable=True,
   1044                    collections=None,
   1045                    caching_device=None,
   1046                    partitioner=None,
   1047                    validate_shape=True,
   1048                    use_resource=None,
   1049                    custom_getter=None,
   1050                    constraint=None):
   1051     """Gets an existing variable with this name or create a new one."""
   1052     if regularizer is None:
   1053       regularizer = self._regularizer
   1054     if caching_device is None:
   1055       caching_device = self._caching_device
   1056     if partitioner is None:
   1057       partitioner = self._partitioner
   1058     if custom_getter is None:
   1059       custom_getter = self._custom_getter
   1060     if context.in_graph_mode():
   1061       if reuse is None:
   1062         reuse = self._reuse
   1063       if use_resource is None:
   1064         use_resource = self._use_resource
   1065     else:
   1066       reuse = False
   1067       use_resource = True
   1068 
   1069     full_name = self.name + "/" + name if self.name else name
   1070     # Variable names only depend on variable_scope (full_name here),
   1071     # not name_scope, so we reset it below for the time of variable creation.
   1072     with ops.name_scope(None):
   1073       # Check that `initializer` dtype and `dtype` are consistent before
   1074       # replacing them with defaults.
   1075       if (dtype is not None and initializer is not None and
   1076           not callable(initializer)):
   1077         init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
   1078         if init_dtype != dtype:
   1079           raise ValueError("Initializer type '%s' and explicit dtype '%s' "
   1080                            "don't match." % (init_dtype, dtype))
   1081       if initializer is None:
   1082         initializer = self._initializer
   1083       if constraint is None:
   1084         constraint = self._constraint
   1085       if dtype is None:
   1086         dtype = self._dtype
   1087       return var_store.get_variable(
   1088           full_name, shape=shape, dtype=dtype, initializer=initializer,
   1089           regularizer=regularizer, reuse=reuse, trainable=trainable,
   1090           collections=collections, caching_device=caching_device,
   1091           partitioner=partitioner, validate_shape=validate_shape,
   1092           use_resource=use_resource, custom_getter=custom_getter,
   1093           constraint=constraint)
   1094 
   1095   def _get_partitioned_variable(self,
   1096                                 var_store,
   1097                                 name,
   1098                                 shape=None,
   1099                                 dtype=None,
   1100                                 initializer=None,
   1101                                 regularizer=None,
   1102                                 trainable=True,
   1103                                 collections=None,
   1104                                 caching_device=None,
   1105                                 partitioner=None,
   1106                                 validate_shape=True,
   1107                                 use_resource=None,
   1108                                 constraint=None):
   1109     """Gets an existing variable with this name or create a new one."""
   1110     if context.in_eager_mode():
   1111       raise NotImplementedError("Partitioned variables are not yet supported "
   1112                                 "when eager execution is enabled.")
   1113     if initializer is None:
   1114       initializer = self._initializer
   1115     if regularizer is None:
   1116       regularizer = self._regularizer
   1117     if constraint is None:
   1118       constraint = self._constraint
   1119     if caching_device is None:
   1120       caching_device = self._caching_device
   1121     if partitioner is None:
   1122       partitioner = self._partitioner
   1123     if dtype is None:
   1124       dtype = self._dtype
   1125     if use_resource is None:
   1126       use_resource = self._use_resource
   1127 
   1128     if self._custom_getter is not None:
   1129       raise ValueError(
   1130           "Private access to _get_partitioned_variable is not allowed when "
   1131           "a custom getter is set.  Current custom getter: %s.  "
   1132           "It is likely that you're using create_partitioned_variables.  "
   1133           "If so, consider instead using get_variable with a non-empty "
   1134           "partitioner parameter instead." % self._custom_getter)
   1135 
   1136     if partitioner is None:
   1137       raise ValueError("No partitioner was specified")
   1138 
   1139     # This allows the variable scope name to be used as the variable name if
   1140     # this function is invoked with an empty name arg, for backward
   1141     # compatibility with create_partitioned_variables().
   1142     full_name_list = []
   1143     if self.name:
   1144       full_name_list.append(self.name)
   1145     if name:
   1146       full_name_list.append(name)
   1147     full_name = "/".join(full_name_list)
   1148 
   1149     # Variable names only depend on variable_scope (full_name here),
   1150     # not name_scope, so we reset it below for the time of variable creation.
   1151     with ops.name_scope(None):
   1152       # pylint: disable=protected-access
   1153       return var_store._get_partitioned_variable(
   1154           full_name, shape=shape, dtype=dtype, initializer=initializer,
   1155           regularizer=regularizer, reuse=self.reuse, trainable=trainable,
   1156           collections=collections, caching_device=caching_device,
   1157           partitioner=partitioner, validate_shape=validate_shape,
   1158           use_resource=use_resource, constraint=constraint)
   1159       # pylint: enable=protected-access
   1160 
   1161 
   1162 _VARSTORE_KEY = ("__variable_store",)
   1163 _VARSCOPE_KEY = ("__varscope",)
   1164 
   1165 
   1166 @tf_export("get_variable_scope")
   1167 def get_variable_scope():
   1168   """Returns the current variable scope."""
   1169   scope = ops.get_collection(_VARSCOPE_KEY)
   1170   if scope:  # This collection has at most 1 element, the default scope at [0].
   1171     return scope[0]
   1172   scope = VariableScope(False)
   1173   ops.add_to_collection(_VARSCOPE_KEY, scope)
   1174   return scope
   1175 
   1176 
   1177 def _get_default_variable_store():
   1178   store = ops.get_collection(_VARSTORE_KEY)
   1179   if store:
   1180     return store[0]
   1181   store = _VariableStore()
   1182   ops.add_to_collection(_VARSTORE_KEY, store)
   1183   return store
   1184 
   1185 
   1186 @tf_contextlib.contextmanager
   1187 def with_variable_store(store):
   1188   store_collection = ops.get_collection_ref(_VARSTORE_KEY)
   1189   old = list(store_collection)
   1190   store_collection[:] = [store]
   1191   try:
   1192     yield
   1193   finally:
   1194     store_collection[:] = old
   1195 
   1196 
   1197 class EagerVariableStore(object):
   1198   """Wrapper allowing functional layers to be used with eager execution.
   1199 
   1200   When eager execution is enabled Variables get deleted when they go out of
   1201   scope, and are not stored in global collections by default. A lot of code
   1202   (mostly the functional layers in tf.layers) assumes that variables are kept in
   1203   a global list.
   1204 
   1205   EagerVariableStore can be used in conjunction with this code to make it
   1206   eager-friendly. For example, to create a dense layer, use:
   1207 
   1208   ```
   1209     container = tfe.EagerVariableStore()
   1210     for input in dataset_iterator:
   1211       with container.as_default():
   1212         x = tf.layers.dense(input, name="l1")
   1213     print(container.variables)  # Should print the variables used in the layer.
   1214   ```
   1215   """
   1216 
   1217   def __init__(self, store=None):
   1218     if store is not None:
   1219       if not store._store_eager_variables:  # pylint: disable=protected-access
   1220         raise ValueError("Cannot construct EagerVariableStore from a "
   1221                          "VariableStore object that does not hold eager "
   1222                          "variables.")
   1223       self._store = store
   1224     else:
   1225       self._store = _VariableStore()
   1226     self._store._store_eager_variables = True  # pylint: disable=protected-access
   1227 
   1228   def as_default(self):
   1229     return with_variable_store(self._store)
   1230 
   1231   def variables(self):
   1232     return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access
   1233 
   1234   def trainable_variables(self):
   1235     # pylint: disable=protected-access
   1236     return sorted([x for x in self._store._vars.values() if x._trainable],
   1237                   key=lambda x: x.name)
   1238     # pylint: enable=protected-access
   1239 
   1240   def non_trainable_variables(self):
   1241     # pylint: disable=protected-access
   1242     return sorted([x for x in self._store._vars.values() if not x._trainable],
   1243                   key=lambda x: x.name)
   1244     # pylint: enable=protected-access
   1245 
   1246   def copy(self):
   1247     """Copy this variable store and all of its contents.
   1248 
   1249     Variables contained in this store will be copied over to the new variable
   1250     store, meaning that they can be modified without affecting the variables in
   1251     this store.
   1252 
   1253     Returns:
   1254       A new EagerVariableStore instance containing copied variables.
   1255     """
   1256     # pylint: disable=protected-access
   1257     new_store = EagerVariableStore()
   1258     for key, var in iteritems(self._store._vars):
   1259       # Strip device out of variable name.
   1260       try:
   1261         index = var.name.index(":")
   1262       except ValueError:
   1263         stripped_var_name = var.name
   1264       else:
   1265         stripped_var_name = var.name[:index]
   1266 
   1267       # Create new variable with same value, name, and "trainable" flag.
   1268       new_var = resource_variable_ops.ResourceVariable(
   1269           var.read_value(),
   1270           name=stripped_var_name,
   1271           trainable=var._trainable)
   1272       new_store._store._vars[key] = new_var
   1273     return new_store
   1274     # pylint: enable=protected-access
   1275 
   1276 
   1277 @tf_export("get_variable")
   1278 def get_variable(name,
   1279                  shape=None,
   1280                  dtype=None,
   1281                  initializer=None,
   1282                  regularizer=None,
   1283                  trainable=True,
   1284                  collections=None,
   1285                  caching_device=None,
   1286                  partitioner=None,
   1287                  validate_shape=True,
   1288                  use_resource=None,
   1289                  custom_getter=None,
   1290                  constraint=None):
   1291   return get_variable_scope().get_variable(
   1292       _get_default_variable_store(), name, shape=shape, dtype=dtype,
   1293       initializer=initializer, regularizer=regularizer, trainable=trainable,
   1294       collections=collections, caching_device=caching_device,
   1295       partitioner=partitioner, validate_shape=validate_shape,
   1296       use_resource=use_resource, custom_getter=custom_getter,
   1297       constraint=constraint)
   1298 get_variable_or_local_docstring = (
   1299     """%s
   1300 
   1301 %sThis function prefixes the name with the current variable scope
   1302 and performs reuse checks. See the
   1303 @{$variables$Variable Scope How To}
   1304 for an extensive description of how reusing works. Here is a basic example:
   1305 
   1306 ```python
   1307 def foo():
   1308   with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
   1309     v = tf.get_variable("v", [1])
   1310   return v
   1311 
   1312 v1 = foo()  # Creates v.
   1313 v2 = foo()  # Gets the same, existing v.
   1314 assert v1 == v2
   1315 ```
   1316 
   1317 If initializer is `None` (the default), the default initializer passed in
   1318 the variable scope will be used. If that one is `None` too, a
   1319 `glorot_uniform_initializer` will be used. The initializer can also be
   1320 a Tensor, in which case the variable is initialized to this value and shape.
   1321 
   1322 Similarly, if the regularizer is `None` (the default), the default regularizer
   1323 passed in the variable scope will be used (if that is `None` too,
   1324 then by default no regularization is performed).
   1325 
   1326 If a partitioner is provided, a `PartitionedVariable` is returned.
   1327 Accessing this object as a `Tensor` returns the shards concatenated along
   1328 the partition axis.
   1329 
   1330 Some useful partitioners are available.  See, e.g.,
   1331 `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
   1332 
   1333 Args:
   1334   name: The name of the new or existing variable.
   1335   shape: Shape of the new or existing variable.
   1336   dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
   1337   initializer: Initializer for the variable if one is created.
   1338   regularizer: A (Tensor -> Tensor or None) function; the result of
   1339     applying it on a newly created variable will be added to the collection
   1340     @{tf.GraphKeys.REGULARIZATION_LOSSES} and can be used for regularization.
   1341   %scollections: List of graph collections keys to add the Variable to.
   1342     Defaults to `[%s]` (see `tf.Variable`).
   1343   caching_device: Optional device string or function describing where the
   1344     Variable should be cached for reading.  Defaults to the Variable's
   1345     device.  If not `None`, caches on another device.  Typical use is to
   1346     cache on the device where the Ops using the Variable reside, to
   1347     deduplicate copying through `Switch` and other conditional statements.
   1348   partitioner: Optional callable that accepts a fully defined `TensorShape`
   1349     and `dtype` of the Variable to be created, and returns a list of
   1350     partitions for each axis (currently only one axis can be partitioned).
   1351   validate_shape: If False, allows the variable to be initialized with a
   1352       value of unknown shape. If True, the default, the shape of initial_value
   1353       must be known.
   1354   use_resource: If False, creates a regular Variable. If true, creates an
   1355     experimental ResourceVariable instead with well-defined semantics.
   1356     Defaults to False (will later change to True). When eager execution is
   1357     enabled this argument is always forced to be True.
   1358   custom_getter: Callable that takes as a first argument the true getter, and
   1359     allows overwriting the internal get_variable method.
   1360     The signature of `custom_getter` should match that of this method,
   1361     but the most future-proof version will allow for changes:
   1362     `def custom_getter(getter, *args, **kwargs)`.  Direct access to
   1363     all `get_variable` parameters is also allowed:
   1364     `def custom_getter(getter, name, *args, **kwargs)`.  A simple identity
   1365     custom getter that simply creates variables with modified names is:
   1366     ```python
   1367     def custom_getter(getter, name, *args, **kwargs):
   1368       return getter(name + '_suffix', *args, **kwargs)
   1369     ```
   1370 
   1371 Returns:
   1372   The created or existing `Variable` (or `PartitionedVariable`, if a
   1373   partitioner was used).
   1374 
   1375 Raises:
   1376   ValueError: when creating a new variable and shape is not declared,
   1377     when violating reuse during variable creation, or when `initializer` dtype
   1378     and `dtype` don't match. Reuse is set inside `variable_scope`.
   1379 """)
   1380 get_variable.__doc__ = get_variable_or_local_docstring % (
   1381     "Gets an existing variable with these parameters or create a new one.",
   1382     "",
   1383     "trainable: If `True` also add the variable to the graph collection\n"
   1384     "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
   1385     "GraphKeys.GLOBAL_VARIABLES")
   1386 
   1387 
   1388 @functools.wraps(get_variable)
   1389 @tf_export("get_local_variable")
   1390 def get_local_variable(*args, **kwargs):
   1391   kwargs["trainable"] = False
   1392   if "collections" in kwargs:
   1393     kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
   1394   else:
   1395     kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
   1396   return get_variable(*args, **kwargs)
   1397 get_local_variable.__doc__ = get_variable_or_local_docstring % (
   1398     "Gets an existing *local* variable or creates a new one.",
   1399     "Behavior is the same as in `get_variable`, except that variables are\n"
   1400     "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
   1401     "`False`.\n",
   1402     "",
   1403     "GraphKeys.LOCAL_VARIABLES")
   1404 
   1405 
   1406 def _get_partitioned_variable(name,
   1407                               shape=None,
   1408                               dtype=None,
   1409                               initializer=None,
   1410                               regularizer=None,
   1411                               trainable=True,
   1412                               collections=None,
   1413                               caching_device=None,
   1414                               partitioner=None,
   1415                               validate_shape=True,
   1416                               use_resource=None,
   1417                               constraint=None):
   1418   """Gets or creates a sharded variable list with these parameters.
   1419 
   1420   The `partitioner` must be a callable that accepts a fully defined
   1421   `TensorShape` and returns a sequence of integers (the `partitions`).
   1422   These integers describe how to partition the given sharded `Variable`
   1423   along the given dimension.  That is, `partitions[1] = 3` means split
   1424   the `Variable` into 3 shards along dimension 1.  Currently, sharding along
   1425   only one axis is supported.
   1426 
   1427   If the list of variables with the given name (prefix) is already stored,
   1428   we return the stored variables. Otherwise, we create a new one.
   1429 
   1430   If initializer is `None` (the default), the default initializer passed in
   1431   the constructor is used. If that one is `None` too, we use a new
   1432   `glorot_uniform_initializer`. If initializer is a Tensor, we use
   1433   it as a value and derive the shape from the initializer.
   1434 
   1435   If the initializer is a callable, then it will be called for each
   1436   shard.  Otherwise the initializer should match the shape of the entire
   1437   sharded Variable, and it will be sliced accordingly for each shard.
   1438 
   1439   Some useful partitioners are available.  See, e.g.,
   1440   `variable_axis_size_partitioner` and `min_max_variable_partitioner`.
   1441 
   1442   Args:
   1443     name: The name of the new or existing variable.
   1444     shape: Shape of the new or existing variable.
   1445     dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
   1446     initializer: Initializer for the variable if one is created.
   1447     regularizer: A (Tensor -> Tensor or None) function; the result of
   1448       applying it on a newly created variable will be added to the collection
   1449       GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
   1450     trainable: If `True` also add the variable to the graph collection
   1451       `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
   1452     collections: List of graph collections keys to add the Variable to.
   1453       Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
   1454     caching_device: Optional device string or function describing where the
   1455       Variable should be cached for reading.  Defaults to the Variable's
   1456       device.  If not `None`, caches on another device.  Typical use is to
   1457       cache on the device where the Ops using the Variable reside, to
   1458       deduplicate copying through `Switch` and other conditional statements.
   1459     partitioner: Optional callable that accepts a fully defined `TensorShape`
   1460       and `dtype` of the Variable to be created, and returns a list of
   1461       partitions for each axis (currently only one axis can be partitioned).
   1462     validate_shape: If False, allows the variable to be initialized with a
   1463         value of unknown shape. If True, the default, the shape of initial_value
   1464         must be known.
   1465     use_resource: If False, creates a regular Variable. If True, creates an
   1466       experimental ResourceVariable instead which has well-defined semantics.
   1467       Defaults to False (will later change to True).
   1468     constraint: An optional projection function to be applied to the variable
   1469       after being updated by an `Optimizer` (e.g. used to implement norm
   1470       constraints or value constraints for layer weights). The function must
   1471       take as input the unprojected Tensor representing the value of the
   1472       variable and return the Tensor for the projected value
   1473       (which must have the same shape). Constraints are not safe to
   1474       use when doing asynchronous distributed training.
   1475 
   1476   Returns:
   1477     A tuple `(shards, partitions)` where `shards` is the list of `Variable`
   1478     shards and `partitions` is the output of the partitioner on the input
   1479     shape.
   1480 
   1481   Raises:
   1482     ValueError: when creating a new variable and shape is not declared,
   1483       or when violating reuse during variable creation. Reuse is set inside
   1484       `variable_scope`.
   1485   """
   1486   # pylint: disable=protected-access
   1487   scope = get_variable_scope()
   1488   if scope.custom_getter is not None:
   1489     raise ValueError(
   1490         "Private access to _get_partitioned_variable is not allowed when "
   1491         "a custom getter is set.  Current custom getter: %s.  "
   1492         "It is likely that you're using create_partitioned_variables.  "
   1493         "If so, consider instead using get_variable with a non-empty "
   1494         "partitioner parameter instead." % scope.custom_getter)
   1495   return scope._get_partitioned_variable(
   1496       _get_default_variable_store(), name, shape=shape, dtype=dtype,
   1497       initializer=initializer, regularizer=regularizer, trainable=trainable,
   1498       collections=collections, caching_device=caching_device,
   1499       partitioner=partitioner, validate_shape=validate_shape,
   1500       use_resource=use_resource, constraint=constraint)
   1501   # pylint: enable=protected-access
   1502 
   1503 
   1504 # Named like a function for compatibility with the previous
   1505 # @tf_contextlib.contextmanager definition.
   1506 class _pure_variable_scope(object):  # pylint: disable=invalid-name
   1507   """A context for the variable_scope, see `variable_scope` for docs."""
   1508 
   1509   def __init__(self,
   1510                name_or_scope,
   1511                reuse=None,
   1512                initializer=None,
   1513                regularizer=None,
   1514                caching_device=None,
   1515                partitioner=None,
   1516                custom_getter=None,
   1517                old_name_scope=None,
   1518                dtype=dtypes.float32,
   1519                use_resource=None,
   1520                constraint=None):
   1521     """Creates a context for the variable_scope, see `variable_scope` for docs.
   1522 
   1523     Note: this does not create a name scope.
   1524 
   1525     Args:
   1526       name_or_scope: `string` or `VariableScope`: the scope to open.
   1527       reuse: `True` or None, or tf.AUTO_REUSE; if `None`, we inherit the parent
   1528         scope's reuse flag.
   1529       initializer: default initializer for variables within this scope.
   1530       regularizer: default regularizer for variables within this scope.
   1531       caching_device: default caching device for variables within this scope.
   1532       partitioner: default partitioner for variables within this scope.
   1533       custom_getter: default custom getter for variables within this scope.
   1534       old_name_scope: the original name scope when re-entering a variable scope.
   1535       dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
   1536       use_resource: If False, variables in this scope will be regular Variables.
   1537         If True, experimental ResourceVariables will be creates instead, with
   1538         well-defined semantics. Defaults to False (will later change to True).
   1539       constraint: An optional projection function to be applied to the variable
   1540         after being updated by an `Optimizer` (e.g. used to implement norm
   1541         constraints or value constraints for layer weights). The function must
   1542         take as input the unprojected Tensor representing the value of the
   1543         variable and return the Tensor for the projected value
   1544         (which must have the same shape). Constraints are not safe to
   1545         use when doing asynchronous distributed training.
   1546     """
   1547     self._name_or_scope = name_or_scope
   1548     self._reuse = reuse
   1549     self._initializer = initializer
   1550     self._regularizer = regularizer
   1551     self._caching_device = caching_device
   1552     self._partitioner = partitioner
   1553     self._custom_getter = custom_getter
   1554     self._old_name_scope = old_name_scope
   1555     self._dtype = dtype
   1556     self._use_resource = use_resource
   1557     self._constraint = constraint
   1558     get_variable_scope()  # Ensure that a default exists, then get a pointer.
   1559     # Get the reference to the collection as we want to modify it in place.
   1560     self._default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
   1561     self._var_store = _get_default_variable_store()
   1562     if isinstance(self._name_or_scope, VariableScope):
   1563       self._new_name = self._name_or_scope.name
   1564       name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
   1565       # Handler for the case when we jump to a shared scope.  We create a new
   1566       #   VariableScope (self._var_scope_object) that contains a copy of the
   1567       #   provided shared scope, possibly with changed reuse and initializer, if
   1568       #   the user requested this.
   1569       variable_scope_object = VariableScope(
   1570           self._name_or_scope.reuse if not self._reuse else self._reuse,
   1571           name=self._new_name,
   1572           initializer=self._name_or_scope.initializer,
   1573           regularizer=self._name_or_scope.regularizer,
   1574           caching_device=self._name_or_scope.caching_device,
   1575           partitioner=self._name_or_scope.partitioner,
   1576           dtype=self._name_or_scope.dtype,
   1577           custom_getter=self._name_or_scope.custom_getter,
   1578           name_scope=name_scope,
   1579           use_resource=self._name_or_scope.use_resource,
   1580           constraint=self._constraint)
   1581       if self._initializer is not None:
   1582         variable_scope_object.set_initializer(self._initializer)
   1583       if self._regularizer is not None:
   1584         variable_scope_object.set_regularizer(self._regularizer)
   1585       if self._caching_device is not None:
   1586         variable_scope_object.set_caching_device(self._caching_device)
   1587       if self._partitioner is not None:
   1588         variable_scope_object.set_partitioner(self._partitioner)
   1589       if self._custom_getter is not None:
   1590         variable_scope_object.set_custom_getter(
   1591             _maybe_wrap_custom_getter(
   1592                 self._custom_getter, self._name_or_scope.custom_getter))
   1593       if self._dtype is not None:
   1594         variable_scope_object.set_dtype(self._dtype)
   1595       if self._use_resource is not None:
   1596         variable_scope_object.set_use_resource(self._use_resource)
   1597       self._cached_variable_scope_object = variable_scope_object
   1598 
   1599   def __enter__(self):
   1600     """Begins the scope block.
   1601 
   1602     Returns:
   1603       A VariableScope.
   1604     Raises:
   1605       ValueError: when trying to reuse within a create scope, or create within
   1606         a reuse scope, or if reuse is not `None` or `True`.
   1607       TypeError: when the types of some arguments are not appropriate.
   1608     """
   1609     self._old = self._default_varscope[0]
   1610     if isinstance(self._name_or_scope, VariableScope):
   1611       self._var_store.open_variable_scope(self._new_name)
   1612       self._old_subscopes = copy.copy(self._var_store.variable_scopes_count)
   1613       variable_scope_object = self._cached_variable_scope_object
   1614     else:
   1615       # Handler for the case when we just prolong current variable scope.
   1616       #   VariableScope with name extended by the provided one, and inherited
   1617       #   reuse and initializer (except if the user provided values to set).
   1618       self._new_name = (
   1619           self._old.name + "/" + self._name_or_scope if self._old.name
   1620           else self._name_or_scope)
   1621       self._reuse = (self._reuse
   1622                      or self._old.reuse)  # Re-using is inherited by sub-scopes.
   1623       if self._old_name_scope is None:
   1624         name_scope = self._name_or_scope
   1625       else:
   1626         name_scope = self._old_name_scope
   1627       variable_scope_object = VariableScope(
   1628           self._reuse,
   1629           name=self._new_name,
   1630           initializer=self._old.initializer,
   1631           regularizer=self._old.regularizer,
   1632           caching_device=self._old.caching_device,
   1633           partitioner=self._old.partitioner,
   1634           dtype=self._old.dtype,
   1635           use_resource=self._old.use_resource,
   1636           custom_getter=self._old.custom_getter,
   1637           name_scope=name_scope,
   1638           constraint=self._constraint)
   1639       if self._initializer is not None:
   1640         variable_scope_object.set_initializer(self._initializer)
   1641       if self._regularizer is not None:
   1642         variable_scope_object.set_regularizer(self._regularizer)
   1643       if self._caching_device is not None:
   1644         variable_scope_object.set_caching_device(self._caching_device)
   1645       if self._partitioner is not None:
   1646         variable_scope_object.set_partitioner(self._partitioner)
   1647       if self._custom_getter is not None:
   1648         variable_scope_object.set_custom_getter(
   1649             _maybe_wrap_custom_getter(self._custom_getter,
   1650                                       self._old.custom_getter))
   1651       if self._dtype is not None:
   1652         variable_scope_object.set_dtype(self._dtype)
   1653       if self._use_resource is not None:
   1654         variable_scope_object.set_use_resource(self._use_resource)
   1655       self._var_store.open_variable_scope(self._new_name)
   1656     self._default_varscope[0] = variable_scope_object
   1657     return variable_scope_object
   1658 
   1659   def __exit__(self, type_arg, value_arg, traceback_arg):
   1660     # If jumping out from a non-prolonged scope, restore counts.
   1661     if isinstance(self._name_or_scope, VariableScope):
   1662       self._var_store.variable_scopes_count = self._old_subscopes
   1663     else:
   1664       self._var_store.close_variable_subscopes(self._new_name)
   1665     self._default_varscope[0] = self._old
   1666 
   1667 
   1668 def _maybe_wrap_custom_getter(custom_getter, old_getter):
   1669   """Wrap a call to a custom_getter to use the old_getter internally."""
   1670   if old_getter is None:
   1671     return custom_getter
   1672 
   1673   # The new custom_getter should call the old one
   1674   def wrapped_custom_getter(getter, *args, **kwargs):
   1675     # Call:
   1676     #  custom_getter(
   1677     #    lambda: old_getter(true_getter, ...), *args, **kwargs)
   1678     # which means custom_getter will call old_getter, which
   1679     # will call the true_getter, perform any intermediate
   1680     # processing, and return the results to the current
   1681     # getter, which will also perform additional processing.
   1682     return custom_getter(
   1683         functools.partial(old_getter, getter),
   1684         *args, **kwargs)
   1685   return wrapped_custom_getter
   1686 
   1687 
   1688 def _get_unique_variable_scope(prefix):
   1689   """Get a name with the given prefix unique in the current variable scope."""
   1690   var_store = _get_default_variable_store()
   1691   current_scope = get_variable_scope()
   1692   name = current_scope.name + "/" + prefix if current_scope.name else prefix
   1693   if var_store.variable_scope_count(name) == 0:
   1694     return prefix
   1695   idx = 1
   1696   while var_store.variable_scope_count(name + ("_%d" % idx)) > 0:
   1697     idx += 1
   1698   return prefix + ("_%d" % idx)
   1699 
   1700 
   1701 # Named like a function for backwards compatibility with the
   1702 # @tf_contextlib.contextmanager version, which was switched to a class to avoid
   1703 # some object creation overhead.
   1704 @tf_export("variable_scope")  # pylint: disable=invalid-name
   1705 class variable_scope(object):
   1706   """A context manager for defining ops that creates variables (layers).
   1707 
   1708   This context manager validates that the (optional) `values` are from the same
   1709   graph, ensures that graph is the default graph, and pushes a name scope and a
   1710   variable scope.
   1711 
   1712   If `name_or_scope` is not None, it is used as is. If `scope` is None, then
   1713   `default_name` is used.  In that case, if the same name has been previously
   1714   used in the same scope, it will be made unique by appending `_N` to it.
   1715 
   1716   Variable scope allows you to create new variables and to share already created
   1717   ones while providing checks to not create or share by accident. For details,
   1718   see the @{$variables$Variable Scope How To}, here we present only a few basic
   1719   examples.
   1720 
   1721   Simple example of how to create a new variable:
   1722 
   1723   ```python
   1724   with tf.variable_scope("foo"):
   1725       with tf.variable_scope("bar"):
   1726           v = tf.get_variable("v", [1])
   1727           assert v.name == "foo/bar/v:0"
   1728   ```
   1729 
   1730   Basic example of sharing a variable AUTO_REUSE:
   1731 
   1732   ```python
   1733   def foo():
   1734     with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
   1735       v = tf.get_variable("v", [1])
   1736     return v
   1737 
   1738   v1 = foo()  # Creates v.
   1739   v2 = foo()  # Gets the same, existing v.
   1740   assert v1 == v2
   1741   ```
   1742 
   1743   Basic example of sharing a variable with reuse=True:
   1744 
   1745   ```python
   1746   with tf.variable_scope("foo"):
   1747       v = tf.get_variable("v", [1])
   1748   with tf.variable_scope("foo", reuse=True):
   1749       v1 = tf.get_variable("v", [1])
   1750   assert v1 == v
   1751   ```
   1752 
   1753   Sharing a variable by capturing a scope and setting reuse:
   1754 
   1755   ```python
   1756   with tf.variable_scope("foo") as scope:
   1757       v = tf.get_variable("v", [1])
   1758       scope.reuse_variables()
   1759       v1 = tf.get_variable("v", [1])
   1760   assert v1 == v
   1761   ```
   1762 
   1763   To prevent accidental sharing of variables, we raise an exception when getting
   1764   an existing variable in a non-reusing scope.
   1765 
   1766   ```python
   1767   with tf.variable_scope("foo"):
   1768       v = tf.get_variable("v", [1])
   1769       v1 = tf.get_variable("v", [1])
   1770       #  Raises ValueError("... v already exists ...").
   1771   ```
   1772 
   1773   Similarly, we raise an exception when trying to get a variable that does not
   1774   exist in reuse mode.
   1775 
   1776   ```python
   1777   with tf.variable_scope("foo", reuse=True):
   1778       v = tf.get_variable("v", [1])
   1779       #  Raises ValueError("... v does not exists ...").
   1780   ```
   1781 
   1782   Note that the `reuse` flag is inherited: if we open a reusing scope, then all
   1783   its sub-scopes become reusing as well.
   1784 
   1785   A note about name scoping: Setting `reuse` does not impact the naming of other
   1786   ops such as mult. See related discussion on
   1787   [github#6189](https://github.com/tensorflow/tensorflow/issues/6189)
   1788 
   1789   Note that up to and including version 1.0, it was allowed (though explicitly
   1790   discouraged) to pass False to the reuse argument, yielding undocumented
   1791   behaviour slightly different from None. Starting at 1.1.0 passing None and
   1792   False as reuse has exactly the same effect.
   1793   """
   1794 
   1795   def __init__(self,
   1796                name_or_scope,
   1797                default_name=None,
   1798                values=None,
   1799                initializer=None,
   1800                regularizer=None,
   1801                caching_device=None,
   1802                partitioner=None,
   1803                custom_getter=None,
   1804                reuse=None,
   1805                dtype=None,
   1806                use_resource=None,
   1807                constraint=None,
   1808                auxiliary_name_scope=True):
   1809     """Initialize the context manager.
   1810 
   1811     Args:
   1812       name_or_scope: `string` or `VariableScope`: the scope to open.
   1813       default_name: The default name to use if the `name_or_scope` argument is
   1814         `None`, this name will be uniquified. If name_or_scope is provided it
   1815         won't be used and therefore it is not required and can be None.
   1816       values: The list of `Tensor` arguments that are passed to the op function.
   1817       initializer: default initializer for variables within this scope.
   1818       regularizer: default regularizer for variables within this scope.
   1819       caching_device: default caching device for variables within this scope.
   1820       partitioner: default partitioner for variables within this scope.
   1821       custom_getter: default custom getter for variables within this scope.
   1822       reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
   1823         for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
   1824         variables if they do not exist, and return them otherwise; if None, we
   1825         inherit the parent scope's reuse flag. When eager execution is enabled,
   1826         this argument is always forced to be tf.AUTO_REUSE.
   1827       dtype: type of variables created in this scope (defaults to the type
   1828         in the passed scope, or inherited from parent scope).
   1829       use_resource: If False, all variables will be regular Variables. If True,
   1830         experimental ResourceVariables with well-defined semantics will be used
   1831         instead. Defaults to False (will later change to True). When eager
   1832         execution is enabled this argument is always forced to be True.
   1833       constraint: An optional projection function to be applied to the variable
   1834         after being updated by an `Optimizer` (e.g. used to implement norm
   1835         constraints or value constraints for layer weights). The function must
   1836         take as input the unprojected Tensor representing the value of the
   1837         variable and return the Tensor for the projected value
   1838         (which must have the same shape). Constraints are not safe to
   1839         use when doing asynchronous distributed training.
   1840       auxiliary_name_scope: If `True`, we create an auxiliary name scope with
   1841         the scope. If `False`, we don't touch name scope.
   1842 
   1843     Returns:
   1844       A scope that can be captured and reused.
   1845 
   1846     Raises:
   1847       ValueError: when trying to reuse within a create scope, or create within
   1848         a reuse scope.
   1849       TypeError: when the types of some arguments are not appropriate.
   1850     """
   1851     self._name_or_scope = name_or_scope
   1852     self._default_name = default_name
   1853     self._values = values
   1854     self._initializer = initializer
   1855     self._regularizer = regularizer
   1856     self._caching_device = caching_device
   1857     self._partitioner = partitioner
   1858     self._custom_getter = custom_getter
   1859     self._reuse = reuse
   1860     self._dtype = dtype
   1861     self._use_resource = use_resource
   1862     self._constraint = constraint
   1863     if self._default_name is None and self._name_or_scope is None:
   1864       raise TypeError("If default_name is None then name_or_scope is required")
   1865     if self._reuse is False:
   1866       # We don't allow non-inheriting scopes, False = None here.
   1867       self._reuse = None
   1868     if not (self._reuse is True
   1869             or self._reuse is None
   1870             or self._reuse is AUTO_REUSE):
   1871       raise ValueError("The reuse parameter must be True or False or None.")
   1872     if self._values is None:
   1873       self._values = []
   1874     self._in_graph_mode = not context.in_eager_mode()
   1875     if self._in_graph_mode:
   1876       self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
   1877     self._cached_pure_variable_scope = None
   1878     self._current_name_scope = None
   1879     if not isinstance(auxiliary_name_scope, bool):
   1880       raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
   1881                       "while get {}".format(auxiliary_name_scope))
   1882     self._auxiliary_name_scope = auxiliary_name_scope
   1883 
   1884   def __enter__(self):
   1885     # If the default graph is building a function, then we should not replace it
   1886     # with the cached graph.
   1887     if ops.get_default_graph().building_function:
   1888       self._building_function = True
   1889     else:
   1890       self._building_function = False
   1891     if self._in_graph_mode and not self._building_function:
   1892       self._graph_context_manager = self._graph.as_default()
   1893       self._graph_context_manager.__enter__()
   1894     if self._cached_pure_variable_scope is not None:
   1895       # Fast path for re-entering variable_scopes. We've held on to the pure
   1896       # variable scope from a previous successful __enter__, so we avoid some
   1897       # overhead by re-using that object.
   1898       if self._current_name_scope is not None:
   1899         self._current_name_scope.__enter__()
   1900       return self._cached_pure_variable_scope.__enter__()
   1901 
   1902     try:
   1903       return self._enter_scope_uncached()
   1904     except:
   1905       if self._graph_context_manager is not None:
   1906         self._graph_context_manager.__exit__(*sys.exc_info())
   1907       raise
   1908 
   1909   def _enter_scope_uncached(self):
   1910     """Enters the context manager when there is no cached scope yet.
   1911 
   1912     Returns:
   1913       The entered variable scope.
   1914 
   1915     Raises:
   1916       TypeError: A wrong type is passed as `scope` at __init__().
   1917       ValueError: `reuse` is incorrectly set at __init__().
   1918     """
   1919     if self._auxiliary_name_scope:
   1920       # Create a new name scope later
   1921       current_name_scope = None
   1922     else:
   1923       # Reenter the current name scope
   1924       name_scope = ops.get_name_scope()
   1925       if name_scope:
   1926         # Hack to reenter
   1927         name_scope += "/"
   1928         current_name_scope = ops.name_scope(name_scope)
   1929       else:
   1930         # Root scope
   1931         current_name_scope = ops.name_scope(name_scope)
   1932 
   1933     # IMPORTANT: Only assign to self._cached_pure_variable_scope and
   1934     # self._current_name_scope after successful __enter__() calls.
   1935     if self._name_or_scope is not None:
   1936       if not isinstance(self._name_or_scope,
   1937                         (VariableScope,) + six.string_types):
   1938         raise TypeError("VariableScope: name_or_scope must be a string or "
   1939                         "VariableScope.")
   1940       if isinstance(self._name_or_scope, six.string_types):
   1941         name_scope = self._name_or_scope
   1942       else:
   1943         name_scope = self._name_or_scope.name.split("/")[-1]
   1944       if name_scope or current_name_scope:
   1945         current_name_scope = current_name_scope or ops.name_scope(name_scope)
   1946         try:
   1947           current_name_scope_name = current_name_scope.__enter__()
   1948         except:
   1949           current_name_scope.__exit__(*sys.exc_info())
   1950           raise
   1951         self._current_name_scope = current_name_scope
   1952         if isinstance(self._name_or_scope, six.string_types):
   1953           old_name_scope = current_name_scope_name
   1954         else:
   1955           old_name_scope = self._name_or_scope.original_name_scope
   1956         pure_variable_scope = _pure_variable_scope(
   1957             self._name_or_scope,
   1958             reuse=self._reuse,
   1959             initializer=self._initializer,
   1960             regularizer=self._regularizer,
   1961             caching_device=self._caching_device,
   1962             partitioner=self._partitioner,
   1963             custom_getter=self._custom_getter,
   1964             old_name_scope=old_name_scope,
   1965             dtype=self._dtype,
   1966             use_resource=self._use_resource,
   1967             constraint=self._constraint)
   1968         try:
   1969           entered_pure_variable_scope = pure_variable_scope.__enter__()
   1970         except:
   1971           pure_variable_scope.__exit__(*sys.exc_info())
   1972           raise
   1973         self._cached_pure_variable_scope = pure_variable_scope
   1974         return entered_pure_variable_scope
   1975       else:
   1976         self._current_name_scope = None
   1977         # This can only happen if someone is entering the root variable scope.
   1978         pure_variable_scope = _pure_variable_scope(
   1979             self._name_or_scope,
   1980             reuse=self._reuse,
   1981             initializer=self._initializer,
   1982             regularizer=self._regularizer,
   1983             caching_device=self._caching_device,
   1984             partitioner=self._partitioner,
   1985             custom_getter=self._custom_getter,
   1986             dtype=self._dtype,
   1987             use_resource=self._use_resource,
   1988             constraint=self._constraint)
   1989         try:
   1990           entered_pure_variable_scope = pure_variable_scope.__enter__()
   1991         except:
   1992           pure_variable_scope.__exit__(*sys.exc_info())
   1993           raise
   1994         self._cached_pure_variable_scope = pure_variable_scope
   1995         return entered_pure_variable_scope
   1996 
   1997     else:  # Here name_or_scope is None. Using default name, but made unique.
   1998       if self._reuse:
   1999         raise ValueError("reuse=True cannot be used without a name_or_scope")
   2000       current_name_scope = current_name_scope or ops.name_scope(
   2001           self._default_name)
   2002       try:
   2003         current_name_scope_name = current_name_scope.__enter__()
   2004       except:
   2005         current_name_scope.__exit__(*sys.exc_info())
   2006         raise
   2007       self._current_name_scope = current_name_scope
   2008       unique_default_name = _get_unique_variable_scope(self._default_name)
   2009       pure_variable_scope = _pure_variable_scope(
   2010           unique_default_name,
   2011           initializer=self._initializer,
   2012           regularizer=self._regularizer,
   2013           caching_device=self._caching_device,
   2014           partitioner=self._partitioner,
   2015           custom_getter=self._custom_getter,
   2016           old_name_scope=current_name_scope_name,
   2017           dtype=self._dtype,
   2018           use_resource=self._use_resource,
   2019           constraint=self._constraint)
   2020       try:
   2021         entered_pure_variable_scope = pure_variable_scope.__enter__()
   2022       except:
   2023         pure_variable_scope.__exit__(*sys.exc_info())
   2024         raise
   2025       self._cached_pure_variable_scope = pure_variable_scope
   2026       return entered_pure_variable_scope
   2027 
   2028   def __exit__(self, type_arg, value_arg, traceback_arg):
   2029     self._cached_pure_variable_scope.__exit__(
   2030         type_arg, value_arg, traceback_arg)
   2031     if self._current_name_scope:
   2032       self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg)
   2033     if self._in_graph_mode and not self._building_function:
   2034       self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg)
   2035 
   2036 
   2037 # pylint: disable=g-doc-return-or-yield
   2038 @tf_export("variable_op_scope")
   2039 @tf_contextlib.contextmanager
   2040 def variable_op_scope(values,
   2041                       name_or_scope,
   2042                       default_name=None,
   2043                       initializer=None,
   2044                       regularizer=None,
   2045                       caching_device=None,
   2046                       partitioner=None,
   2047                       custom_getter=None,
   2048                       reuse=None,
   2049                       dtype=None,
   2050                       use_resource=None,
   2051                       constraint=None):
   2052   """Deprecated: context manager for defining an op that creates variables."""
   2053   logging.warn("tf.variable_op_scope(values, name, default_name) is deprecated,"
   2054                " use tf.variable_scope(name, default_name, values)")
   2055   with variable_scope(name_or_scope,
   2056                       default_name=default_name,
   2057                       values=values,
   2058                       initializer=initializer,
   2059                       regularizer=regularizer,
   2060                       caching_device=caching_device,
   2061                       partitioner=partitioner,
   2062                       custom_getter=custom_getter,
   2063                       reuse=reuse,
   2064                       dtype=dtype,
   2065                       use_resource=use_resource,
   2066                       constraint=constraint) as scope:
   2067     yield scope
   2068 
   2069 
   2070 def _compute_slice_dim_and_shape(full_shape, slicing):
   2071   """Computes which dimension is being sliced and the typical slice shape."""
   2072 
   2073   slice_shape = [0] * len(full_shape)
   2074   slice_dim = None
   2075   for dim, num_slices in enumerate(slicing):
   2076     dim_size = full_shape[dim]
   2077     if num_slices <= 0 or dim_size < num_slices:
   2078       raise ValueError("Cannot create %d slices for size %d. shape: %s, "
   2079                        "slicing: %s" %
   2080                        (num_slices, full_shape[dim], full_shape, slicing))
   2081     if num_slices == 1:
   2082       # Not slicing in this dimension.
   2083       slice_shape[dim] = dim_size
   2084     elif slice_dim is not None:
   2085       # We only support slicing along one of the dimensions.
   2086       raise ValueError("Can only slice a variable along one dimension: "
   2087                        "shape: %s, slicing: %s" % (full_shape, slicing))
   2088     else:
   2089       # Note: We will add any extras onto the last slice, later.
   2090       slice_dim = dim
   2091       slice_shape[dim] = dim_size // num_slices
   2092 
   2093   # Degenerate case: If "slicing" was all ones, pretend we are slicing along
   2094   # the first dimension.
   2095   if slice_dim is None:
   2096     slice_dim = 0
   2097   return slice_dim, slice_shape
   2098 
   2099 
   2100 def default_variable_creator(next_creator=None, **kwargs):
   2101   """Default variable creator."""
   2102   assert next_creator is None
   2103   initial_value = kwargs.get("initial_value", None)
   2104   trainable = kwargs.get("trainable", True)
   2105   collections = kwargs.get("collections", None)
   2106   validate_shape = kwargs.get("validate_shape", True)
   2107   caching_device = kwargs.get("caching_device", None)
   2108   name = kwargs.get("name", None)
   2109   dtype = kwargs.get("dtype", None)
   2110   constraint = kwargs.get("constraint", None)
   2111   use_resource = kwargs.get("use_resource", None)
   2112   if use_resource is None:
   2113     use_resource = get_variable_scope().use_resource
   2114   if use_resource or (use_resource is None and context.in_eager_mode()):
   2115     return resource_variable_ops.ResourceVariable(
   2116         initial_value=initial_value, trainable=trainable,
   2117         collections=collections, validate_shape=validate_shape,
   2118         caching_device=caching_device, name=name, dtype=dtype,
   2119         constraint=constraint)
   2120   elif not use_resource and context.in_eager_mode():
   2121     raise RuntimeError(
   2122         "VariableScope should use resource variable when eager execution is"
   2123         " enabled, but use_resource is False."
   2124     )
   2125   else:
   2126     return variables.Variable(
   2127         initial_value=initial_value, trainable=trainable,
   2128         collections=collections, validate_shape=validate_shape,
   2129         caching_device=caching_device, name=name, dtype=dtype,
   2130         constraint=constraint)
   2131 
   2132 
   2133 def _make_getter(captured_getter, captured_previous):
   2134   """Gets around capturing loop variables in python being broken."""
   2135   return lambda **kwargs: captured_getter(captured_previous, **kwargs)
   2136 
   2137 
   2138 def variable(initial_value=None,
   2139              trainable=True,
   2140              collections=None,
   2141              validate_shape=True,
   2142              caching_device=None,
   2143              name=None,
   2144              dtype=None,
   2145              constraint=None,
   2146              use_resource=None):
   2147   previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
   2148   for getter in ops.get_default_graph()._get_variable_creator_stack():  # pylint: disable=protected-access
   2149     previous_getter = _make_getter(getter, previous_getter)
   2150   return previous_getter(initial_value=initial_value,
   2151                          trainable=trainable,
   2152                          collections=collections,
   2153                          validate_shape=validate_shape,
   2154                          caching_device=caching_device,
   2155                          name=name, dtype=dtype,
   2156                          constraint=constraint,
   2157                          use_resource=use_resource)
   2158 
   2159 
   2160 @tf_contextlib.contextmanager
   2161 def variable_creator_scope(variable_creator):
   2162   """Scope which defines a variable creation function to be used by variable().
   2163 
   2164   variable_creator is expected to be a function with the following signature:
   2165 
   2166   ```
   2167     def variable_creator(next_creator, **kwargs)
   2168   ```
   2169 
   2170   The creator is supposed to eventually call the next_creator to create a
   2171   variable if it does want to create a variable and not call Variable or
   2172   ResourceVariable directly. This helps make creators composable. A creator may
   2173   choose to create multiple variables, return already existing variables, or
   2174   simply register that a variable was created and defer to the next creators in
   2175   line. Creators can also modify the keyword arguments seen by the next
   2176   creators.
   2177 
   2178   Custom getters in the variable scope will eventually resolve down to these
   2179   custom creators when they do create variables.
   2180 
   2181   The valid keyword arguments in kwds are:
   2182       initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
   2183         which is the initial value for the Variable. The initial value must have
   2184         a shape specified unless `validate_shape` is set to False. Can also be a
   2185         callable with no argument that returns the initial value when called. In
   2186         that case, `dtype` must be specified. (Note that initializer functions
   2187         from init_ops.py must first be bound to a shape before being used here.)
   2188       trainable: If `True`, the default, also adds the variable to the graph
   2189         collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
   2190         the default list of variables to use by the `Optimizer` classes.
   2191       collections: List of graph collections keys. The new variable is added to
   2192         these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
   2193       validate_shape: If `False`, allows the variable to be initialized with a
   2194         value of unknown shape. If `True`, the default, the shape of
   2195         `initial_value` must be known.
   2196       caching_device: Optional device string describing where the Variable
   2197         should be cached for reading.  Defaults to the Variable's device.
   2198         If not `None`, caches on another device.  Typical use is to cache
   2199         on the device where the Ops using the Variable reside, to deduplicate
   2200         copying through `Switch` and other conditional statements.
   2201       name: Optional name for the variable. Defaults to `'Variable'` and gets
   2202         uniquified automatically.
   2203       dtype: If set, initial_value will be converted to the given type.
   2204         If `None`, either the datatype will be kept (if `initial_value` is
   2205         a Tensor), or `convert_to_tensor` will decide.
   2206       constraint: A constraint function to be applied to the variable after
   2207         updates by some algorithms.
   2208       use_resource: if True, a ResourceVariable is always created.
   2209 
   2210   This set may grow over time, so it's important the signature of creators is as
   2211   mentioned above.
   2212 
   2213   Args:
   2214     variable_creator: the passed creator
   2215 
   2216   Yields:
   2217     A scope in which the creator is active
   2218   """
   2219   with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
   2220     yield
   2221