# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A class to store named variables and a scope operator to manage sharing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as collections_lib
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import sys
import traceback

import six
from six import iteritems
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export

__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
           "get_variable", "get_local_variable", "variable_scope",
           "variable_op_scope", "no_regularizer"]


class _PartitionInfo(object):
  """Holds partition info used by initializer functions.
  """

  def __init__(self, full_shape, var_offset):
    """Constructor.

    Args:
      full_shape: Tuple or list of `int` indicating the full combined shape
        of the partitioned variables.
      var_offset: Tuple or list of `int` specifying offset of this partition
        with respect to the full variable for each dimension.

    Raises:
      TypeError: If `full_shape` or `var_offset` is not a sequence.
      ValueError: If `full_shape` or `var_offset` differ in length. If
        `var_offset` exceeds `full_shape` in any dimension.
68 """ 69 if not isinstance(full_shape, collections_lib.Sequence) or isinstance( 70 full_shape, six.string_types): 71 raise TypeError( 72 "`full_shape` must be a sequence (like tuple or list) instead of " + 73 type(full_shape).__name__) 74 75 if not isinstance(var_offset, collections_lib.Sequence) or isinstance( 76 var_offset, six.string_types): 77 raise TypeError( 78 "`var_offset` must be a sequence (like tuple or list) instead of " + 79 type(var_offset).__name__) 80 81 if len(var_offset) != len(full_shape): 82 raise ValueError( 83 "Expected equal length, but `var_offset` is of length {} while " 84 "full_shape is of length {}.".format( 85 len(var_offset), len(full_shape))) 86 87 for i in xrange(len(full_shape)): 88 offset = var_offset[i] 89 shape = full_shape[i] 90 if offset < 0 or offset >= shape: 91 raise ValueError( 92 "Expected 0 <= offset < shape but found offset={}, shape={} for " 93 "var_offset={}, full_shape={}".format(offset, shape, var_offset, 94 full_shape)) 95 96 self._full_shape = full_shape 97 self._var_offset = var_offset 98 99 @property 100 def full_shape(self): 101 return self._full_shape 102 103 @property 104 def var_offset(self): 105 return self._var_offset 106 107 def single_offset(self, shape): 108 """Returns the offset when the variable is partitioned in at most one dim. 109 110 Args: 111 shape: Tuple or list of `int` indicating the shape of one specific 112 variable partition. 113 114 Returns: 115 `int` representing the offset in the dimension along which the variable is 116 partitioned. Returns 0 if the variable is not being partitioned. 117 118 Raises: 119 ValueError: Depending on self.single_slice_dim(). 120 """ 121 122 single_slice_dim = self.single_slice_dim(shape) 123 # If this variable is not being partitioned at all, single_slice_dim() could 124 # return None. 125 if single_slice_dim is None: 126 return 0 127 return self.var_offset[single_slice_dim] 128 129 def single_slice_dim(self, shape): 130 """Returns the slice dim when the variable is partitioned only in one dim. 131 132 Args: 133 shape: Tuple or list of `int` indicating the shape of one specific 134 variable partition. 135 136 Returns: 137 `int` representing the dimension that the variable is partitioned in, or 138 `None` if the variable doesn't seem to be partitioned at all. 139 140 Raises: 141 TypeError: If `shape` is not a sequence. 142 ValueError: If `shape` is not the same length as `self.full_shape`. If 143 the variable is partitioned in more than one dimension. 
144 """ 145 if not isinstance(shape, collections_lib.Sequence) or isinstance( 146 shape, six.string_types): 147 raise TypeError( 148 "`shape` must be a sequence (like tuple or list) instead of " + 149 type(shape).__name__) 150 151 if len(shape) != len(self.full_shape): 152 raise ValueError( 153 "Expected equal length, but received shape={} of length {} while " 154 "self.full_shape={} is of length {}.".format(shape, len( 155 shape), self.full_shape, len(self.full_shape))) 156 157 for i in xrange(len(shape)): 158 if self.var_offset[i] + shape[i] > self.full_shape[i]: 159 raise ValueError( 160 "With self.var_offset={}, a partition of shape={} would exceed " 161 "self.full_shape={} in dimension {}.".format( 162 self.var_offset, shape, self.full_shape, i)) 163 164 slice_dim = None 165 for i in xrange(len(shape)): 166 if shape[i] == self.full_shape[i]: 167 continue 168 if slice_dim is not None: 169 raise ValueError( 170 "Cannot use single_slice_dim() with shape={} and " 171 "self.full_shape={} since slice dim could be either dimension {} " 172 "or {}.".format(shape, self.full_shape, i, slice_dim)) 173 slice_dim = i 174 175 return slice_dim 176 177 178 class _ReuseMode(enum.Enum): 179 """Mode for variable access within a variable scope.""" 180 181 # Indicates that variables are to be fetched if they already exist or 182 # otherwise created. 183 AUTO_REUSE = 1 184 185 # TODO(alive): For TensorFlow 2.0, Deprecate True/False/None API in favor of 186 # enum values. 187 # REUSE_FALSE = 2 188 # REUSE_TRUE = 3 189 190 AUTO_REUSE = _ReuseMode.AUTO_REUSE 191 tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE") 192 AUTO_REUSE.__doc__ = """ 193 When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that 194 get_variable() should create the requested variable if it doesn't exist or, if 195 it does exist, simply return it. 196 """ 197 198 199 class _VariableStore(object): 200 """Variable store that carries a number of named Variables. 201 202 New variable names and new variables can be created; all stored 203 variables are initialized with the initializer passed to __init__. 204 205 Attributes: 206 vars: a dictionary with string names (same as passed in GetVar) as keys 207 and the corresponding TensorFlow Variables as values. 208 """ 209 210 def __init__(self): 211 """Create a variable store.""" 212 self._vars = {} # A dictionary of the stored TensorFlow variables. 213 self._partitioned_vars = {} # A dict of the stored PartitionedVariables. 214 self.variable_scopes_count = {} # Count re-used variable scopes. 215 self._store_eager_variables = False 216 217 def open_variable_scope(self, scope_name): 218 if scope_name in self.variable_scopes_count: 219 self.variable_scopes_count[scope_name] += 1 220 else: 221 self.variable_scopes_count[scope_name] = 1 222 223 def close_variable_subscopes(self, scope_name): 224 for k in self.variable_scopes_count: 225 if not scope_name or k.startswith(scope_name + "/"): 226 self.variable_scopes_count[k] = 0 227 228 def variable_scope_count(self, scope_name): 229 return self.variable_scopes_count.get(scope_name, 0) 230 231 def get_variable(self, name, shape=None, dtype=dtypes.float32, 232 initializer=None, regularizer=None, reuse=None, 233 trainable=True, collections=None, caching_device=None, 234 partitioner=None, validate_shape=True, use_resource=None, 235 custom_getter=None, constraint=None): 236 """Gets an existing variable with these parameters or create a new one. 

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `glorot_uniform_initializer`. If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    If a partitioner is provided, a `PartitionedVariable` is returned.
    Accessing this object as a `Tensor` returns the shards concatenated along
    the partition axis.

    Some useful partitioners are available. See, e.g.,
    `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the collection
        GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
        of variables. When eager execution is enabled this argument is always
        forced to be False.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True).
        When eager execution is enabled this argument is always forced to be
        True.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method.
        The signature of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Returns:
      The created or existing `Variable` (or `PartitionedVariable`, if a
      partitioner was used).

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError(
          "Passed a custom_getter which is not callable: %s" % custom_getter)

    if context.in_eager_mode():
      if not self._store_eager_variables and reuse:
        raise RuntimeError(
            "When eager execution is enabled variable reuse is only supported"
            " when an EagerVariableStore is active. See the documentation on"
            " EagerVariableStore for example usage.")
      if self._store_eager_variables:
        reuse = AUTO_REUSE
        use_resource = True

    # If a *_ref type is passed in, an error would be triggered further down
    # the stack. We prevent this using base_dtype to get a non-ref version of
    # the type, before doing anything else. When _ref types are removed in
    # favor of resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(name, shape=None, dtype=dtypes.float32,  # pylint: disable=missing-docstring
                     initializer=None, regularizer=None, reuse=None,
                     trainable=True, collections=None, caching_device=None,
                     partitioner=None, validate_shape=True, use_resource=None,
                     constraint=None):
      is_scalar = (shape is not None
                   and isinstance(shape, collections_lib.Sequence)
                   and not shape)
      # Partitioned variable case
      if partitioner is not None and not is_scalar:
        if not callable(partitioner):
          raise ValueError(
              "Partitioner must be callable, but received: %s" % partitioner)
        with ops.name_scope(None):
          return self._get_partitioned_variable(name=name,
                                                shape=shape,
                                                dtype=dtype,
                                                initializer=initializer,
                                                regularizer=regularizer,
                                                reuse=reuse,
                                                trainable=trainable,
                                                collections=collections,
                                                caching_device=caching_device,
                                                partitioner=partitioner,
                                                validate_shape=validate_shape,
                                                use_resource=use_resource,
                                                constraint=constraint)

      # Special case for partitioned variable to allow reuse without having to
      # specify partitioner.
      if (reuse is True and partitioner is None
          and name in self._partitioned_vars):
        return self._get_partitioned_variable(name=name,
                                              shape=shape,
                                              dtype=dtype,
                                              initializer=initializer,
                                              regularizer=regularizer,
                                              reuse=reuse,
                                              trainable=trainable,
                                              collections=collections,
                                              caching_device=caching_device,
                                              partitioner=None,
                                              validate_shape=validate_shape,
                                              use_resource=use_resource,
                                              constraint=constraint)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name, shape=shape, dtype=dtype,
          initializer=initializer, regularizer=regularizer, reuse=reuse,
          trainable=trainable, collections=collections,
          caching_device=caching_device, validate_shape=validate_shape,
          use_resource=use_resource, constraint=constraint)

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
      }
      # `fn_args` can handle functions, `functools.partial`, `lambda`.
429 if "constraint" in estimator_util.fn_args(custom_getter): 430 custom_getter_kwargs["constraint"] = constraint 431 return custom_getter(**custom_getter_kwargs) 432 else: 433 return _true_getter( 434 name, shape=shape, dtype=dtype, 435 initializer=initializer, regularizer=regularizer, 436 reuse=reuse, trainable=trainable, collections=collections, 437 caching_device=caching_device, partitioner=partitioner, 438 validate_shape=validate_shape, use_resource=use_resource, 439 constraint=constraint) 440 441 def _get_partitioned_variable( 442 self, name, partitioner, shape=None, dtype=dtypes.float32, 443 initializer=None, regularizer=None, reuse=None, 444 trainable=True, collections=None, caching_device=None, 445 validate_shape=True, use_resource=None, constraint=None): 446 """Gets or creates a sharded variable list with these parameters. 447 448 The `partitioner` must be a callable that accepts a fully defined 449 `TensorShape` and returns a sequence of integers (the `partitions`). 450 These integers describe how to partition the given sharded `Variable` 451 along the given dimension. That is, `partitions[1] = 3` means split 452 the `Variable` into 3 shards along dimension 1. Currently, sharding along 453 only one axis is supported. 454 455 If the list of variables with the given name (prefix) is already stored, 456 we return the stored variables. Otherwise, we create a new one. 457 458 Set `reuse` to `True` when you only want to reuse existing Variables. 459 Set `reuse` to `False` when you only want to create new Variables. 460 Set `reuse` to None (the default) or tf.AUTO_REUSE when you want 461 variables to be created if they don't exist or returned if they do. 462 463 If initializer is `None` (the default), the default initializer passed in 464 the constructor is used. If that one is `None` too, we use a new 465 `glorot_uniform_initializer`. If initializer is a Tensor, we use 466 it as a value and derive the shape from the initializer. 467 468 If the initializer is a callable, then it will be called for each 469 shard. Otherwise the initializer should match the shape of the entire 470 sharded Variable, and it will be sliced accordingly for each shard. 471 472 Some useful partitioners are available. See, e.g., 473 `variable_axis_size_partitioner` and `min_max_variable_partitioner`. 474 475 Args: 476 name: the name of the new or existing sharded variable. 477 partitioner: Optional callable that accepts a fully defined `TensorShape` 478 and `dtype` of the Variable to be created, and returns a list of 479 partitions for each axis (currently only one axis can be partitioned). 480 shape: shape of the new or existing sharded variable. 481 dtype: type of the new or existing sharded variable 482 (defaults to `DT_FLOAT`). 483 initializer: initializer for the sharded variable. 484 regularizer: a (Tensor -> Tensor or None) function; the result of 485 applying it on a newly created variable will be added to the collection 486 GraphKeys.REGULARIZATION_LOSSES and can be used for regularization. 487 reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation 488 of variables. 489 trainable: If `True` also add the variable to the graph collection 490 `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 491 collections: List of graph collections keys to add the Variable to. 492 Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). 493 caching_device: Optional device string or function describing where the 494 Variable should be cached for reading. Defaults to the Variable's 495 device. 
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates an
        experimental ResourceVariable which has well-defined semantics. Defaults
        to False (will later change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.

    Returns:
      A `PartitionedVariable` object.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        when violating reuse during variable creation, or if an existing
        sharded variable exists for the given name but with different sharding.
    """
    if context.in_eager_mode():
      raise NotImplementedError("Partitioned variables are not yet supported "
                                "when eager execution is enabled.")

    initializing_from_value = initializer is not None and isinstance(
        initializer, ops.Tensor)
    reuse_without_partition = reuse and not partitioner

    if name in self._vars:
      raise ValueError(
          "A partitioner was provided, but an unpartitioned version of the "
          "variable was found: %s. Perhaps a variable of the same name was "
          "already created without partitioning?" % name)

    shape = tensor_shape.as_shape(shape)
    if initializing_from_value:
      shape = shape.merge_with(initializer.get_shape())

    if not reuse_without_partition:
      if not shape.is_fully_defined():
        raise ValueError("Shape of a new partitioned variable (%s) must be "
                         "fully defined, but instead was %s." % (name, shape))

      if shape.ndims < 1:
        raise ValueError("A partitioned Variable must have rank at least 1, "
                         "shape: %s" % shape)

      partitions = partitioner(shape=shape, dtype=dtype)

      if not isinstance(partitions, collections_lib.Sequence):
        raise ValueError("Partitioner must return a sequence, but saw: %s"
                         % partitions)

      if len(partitions) != shape.ndims:
        raise ValueError(
            "Partitioner returned a partition list that does not match the "
            "Variable's rank: %s vs. %s" % (partitions, shape))

      if any([p < 1 for p in partitions]):
        raise ValueError(
            "Partitioner returned zero partitions for some axes: %s" %
            partitions)

    if name in self._partitioned_vars:
      if reuse is False:
        raise ValueError(
            "Partitioned variable with name %s already exists. Did you mean to "
            "set reuse=True or reuse=tf.AUTO_REUSE in VarScope?"
            % name)

      existing_var = self._partitioned_vars[name]
      if not shape.is_compatible_with(existing_var.get_shape()):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified shape %s "
            "and found shape %s."
            % (name, shape, existing_var.get_shape()))
      if not dtype.is_compatible_with(existing_var.dtype):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified dtype %s "
            "and found dtype %s."
            % (name, dtype.name, existing_var.dtype.name))

      # pylint: disable=protected-access
      if (not reuse_without_partition and
          existing_var._get_partitions() != partitions):
        raise ValueError(
            "Trying to reuse partitioned variable %s, but specified partitions "
            "%s and found partitions %s." %
            (name, partitions, existing_var._get_partitions()))
      # pylint: enable=protected-access

      return existing_var

    if reuse is True:
      raise ValueError("PartitionedVariable %s does not exist, or was not "
                       "created with tf.get_variable(). Did you mean to set "
                       "reuse=False or reuse=tf.AUTO_REUSE in VarScope?" % name)

    slice_dim, slice_shape = _compute_slice_dim_and_shape(
        shape.as_list(), partitions)

    vs = []
    num_slices = partitions[slice_dim]
    num_slices_with_excess = shape[slice_dim].value % num_slices

    slice_offset = [0] * shape.ndims

    if "%s/part_0" % name in self._vars:
      if "%s/part_%d" % (name, num_slices - 1) not in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but %s/part_%d was not."
            % (num_slices, name, name, num_slices - 1))
      if "%s/part_%d" % (name, num_slices) in self._vars:
        raise ValueError(
            "Partitioner returned a different partitioning than what was "
            "already found. Partitioner returned %d shards, and shard "
            "%s/part_0 was found, but so was the extra shard %s/part_%d."
            % (num_slices, name, name, num_slices))

    for i in xrange(num_slices):
      var_shape = slice_shape[:]
      var_offset = slice_offset[:]
      partition_info = _PartitionInfo(
          full_shape=shape.as_list(), var_offset=var_offset)
      if i < num_slices_with_excess:
        var_shape[slice_dim] += 1
      slice_offset[slice_dim] += var_shape[slice_dim]

      var_full_name = "%s/part_%d" % (name, i)
      with ops.name_scope(var_full_name + "/PartitionedInitializer"):
        # Create the tensor to initialize the variable with default value.
        if initializer is None:
          init, initializing_from_value = self._get_default_initializer(
              name=name, shape=shape, dtype=dtype)
          if initializing_from_value:
            init_shape = None
          else:
            init_shape = var_shape
        elif callable(initializer):
          init = initializer
          init_shape = var_shape
        elif isinstance(initializer, ops.Tensor):
          init = array_ops.slice(initializer, var_offset, var_shape)
          # Use the dtype of the given tensor.
          dtype = init.dtype.base_dtype
          init_shape = None
        else:
          init = ops.convert_to_tensor(initializer, dtype=dtype)
          init = array_ops.slice(init, var_offset, var_shape)
          init_shape = None

      with ops.name_scope(None):
        var = self._get_single_variable(
            name=var_full_name,
            shape=init_shape,
            dtype=dtype,
            initializer=init,
            partition_info=partition_info,
            regularizer=regularizer,
            reuse=reuse,
            trainable=trainable,
            collections=collections,
            caching_device=caching_device,
            validate_shape=validate_shape,
            use_resource=use_resource,
            constraint=constraint)

        # pylint: disable=protected-access
        var._set_save_slice_info(variables.Variable.SaveSliceInfo(
            name, shape.as_list(), var_offset, var_shape))
        vs.append(var)
        # pylint: enable=protected-access

    # pylint: disable=protected-access
    partitioned_var = variables.PartitionedVariable(name=name,
                                                    shape=shape,
                                                    dtype=dtype,
                                                    variable_list=vs,
                                                    partitions=partitions)
    # pylint: enable=protected-access

    self._partitioned_vars[name] = partitioned_var
    return partitioned_var

  def _get_single_variable(self,
                           name,
                           shape=None,
                           dtype=dtypes.float32,
                           initializer=None,
                           regularizer=None,
                           partition_info=None,
                           reuse=None,
                           trainable=True,
                           collections=None,
                           caching_device=None,
                           validate_shape=True,
                           use_resource=None,
                           constraint=None):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      collections: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      use_resource: see get_variable.
      constraint: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:
        tb = self._vars[name].op.traceback[::-1]
        # Throw away internal tf entries and only take a few lines.
        tb = [x for x in tb if "tensorflow/python" not in x[0]][:3]
        raise ValueError("Variable %s already exists, disallowed."
                         " Did you mean to set reuse=True or "
                         "reuse=tf.AUTO_REUSE in VarScope? "
                         "Originally defined at:\n\n%s" % (
                             name, "".join(traceback.format_list(tb))))
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, shape,
                                                   found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." % (name, dtype_str,
                                                   found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)
    if not shape.is_fully_defined() and not initializing_from_value:
      raise ValueError("Shape of a new variable (%s) must be fully defined, "
                       "but instead was %s." % (name, shape))

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if isinstance(initializer, type(init_ops.Initializer)):
          initializer = initializer(dtype=dtype)
        init_val = lambda: initializer(  # pylint: disable=g-long-lambda
            shape.as_list(), dtype=dtype, partition_info=partition_info)
        variable_dtype = dtype.base_dtype

    # Create the variable.
    if use_resource is None:
      # Set the default value if unspecified.
      use_resource = False
    v = variable(
        initial_value=init_val,
        name=name,
        trainable=trainable,
        collections=collections,
        caching_device=caching_device,
        dtype=variable_dtype,
        validate_shape=validate_shape,
        constraint=constraint,
        use_resource=use_resource)
    if context.in_graph_mode() or self._store_eager_variables:
      # In eager mode we do not want to keep default references to Variable
      # objects as this will prevent their memory from being released.
      self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      with ops.colocate_with(v):
        with ops.name_scope(name + "/Regularizer/"):
          loss = regularizer(v)
        if loss is not None:
          if context.in_graph_mode():
            v_name = v.name
            loss_name = loss.name
          else:
            v_name = "v_%s" % type(v)
            loss_name = "loss_%s" % type(loss)
          logging.vlog(1, "Applied regularizer to %s and added the result %s "
                       "to REGULARIZATION_LOSSES.", v_name, loss_name)
          ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
    return v

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
839 """ 840 del shape 841 # If dtype is DT_FLOAT, provide a uniform unit scaling initializer 842 if dtype.is_floating: 843 initializer = init_ops.glorot_uniform_initializer() 844 initializing_from_value = False 845 # If dtype is DT_INT/DT_UINT, provide a default value `zero` 846 # If dtype is DT_BOOL, provide a default value `FALSE` 847 elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool: 848 initializer = init_ops.zeros_initializer() 849 initializing_from_value = False 850 # NOTES:Do we need to support for handling DT_STRING and DT_COMPLEX here? 851 else: 852 raise ValueError("An initializer for variable %s of %s is required" 853 % (name, dtype.base_dtype)) 854 855 return initializer, initializing_from_value 856 857 858 # To stop regularization, use this regularizer 859 @tf_export("no_regularizer") 860 def no_regularizer(_): 861 """Use this function to prevent regularization of variables.""" 862 return None 863 864 865 # TODO(alive): support caching devices and partitioned variables in Eager mode. 866 @tf_export("VariableScope") 867 class VariableScope(object): 868 """Variable scope object to carry defaults to provide to `get_variable`. 869 870 Many of the arguments we need for `get_variable` in a variable store are most 871 easily handled with a context. This object is used for the defaults. 872 873 Attributes: 874 name: name of the current scope, used as prefix in get_variable. 875 initializer: default initializer passed to get_variable. 876 regularizer: default regularizer passed to get_variable. 877 reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in 878 get_variable. When eager execution is enabled this argument is always 879 forced to be False. 880 caching_device: string, callable, or None: the caching device passed to 881 get_variable. 882 partitioner: callable or `None`: the partitioner passed to `get_variable`. 883 custom_getter: default custom getter passed to get_variable. 884 name_scope: The name passed to `tf.name_scope`. 885 dtype: default type passed to get_variable (defaults to DT_FLOAT). 886 use_resource: if False, create a normal Variable; if True create an 887 experimental ResourceVariable with well-defined semantics. Defaults 888 to False (will later change to True). When eager execution is enabled 889 this argument is always forced to be True. 890 constraint: An optional projection function to be applied to the variable 891 after being updated by an `Optimizer` (e.g. used to implement norm 892 constraints or value constraints for layer weights). The function must 893 take as input the unprojected Tensor representing the value of the 894 variable and return the Tensor for the projected value 895 (which must have the same shape). Constraints are not safe to 896 use when doing asynchronous distributed training. 
897 """ 898 899 def __init__(self, 900 reuse, 901 name="", 902 initializer=None, 903 regularizer=None, 904 caching_device=None, 905 partitioner=None, 906 custom_getter=None, 907 name_scope="", 908 dtype=dtypes.float32, 909 use_resource=None, 910 constraint=None): 911 """Creates a new VariableScope with the given properties.""" 912 self._name = name 913 self._initializer = initializer 914 self._regularizer = regularizer 915 self._reuse = reuse 916 self._caching_device = caching_device 917 self._partitioner = partitioner 918 self._custom_getter = custom_getter 919 self._name_scope = name_scope 920 self._dtype = dtype 921 self._use_resource = use_resource 922 self._constraint = constraint 923 if context.in_eager_mode(): 924 if self._caching_device is not None: 925 raise NotImplementedError("Caching devices is not yet supported " 926 "when eager execution is enabled.") 927 if self._partitioner is not None: 928 raise NotImplementedError("Partitioned variables are not yet supported " 929 "when eager execution is enabled.") 930 self._reuse = AUTO_REUSE 931 self._use_resource = True 932 933 @property 934 def name(self): 935 return self._name 936 937 @property 938 def original_name_scope(self): 939 return self._name_scope 940 941 @property 942 def reuse(self): 943 return self._reuse 944 945 @property 946 def initializer(self): 947 return self._initializer 948 949 @property 950 def dtype(self): 951 return self._dtype 952 953 @property 954 def use_resource(self): 955 return self._use_resource 956 957 @property 958 def regularizer(self): 959 return self._regularizer 960 961 @property 962 def caching_device(self): 963 return self._caching_device 964 965 @property 966 def partitioner(self): 967 return self._partitioner 968 969 @property 970 def custom_getter(self): 971 return self._custom_getter 972 973 @property 974 def constraint(self): 975 return self._constraint 976 977 def reuse_variables(self): 978 """Reuse variables in this scope.""" 979 self._reuse = True 980 981 def set_initializer(self, initializer): 982 """Set initializer for this scope.""" 983 self._initializer = initializer 984 985 def set_dtype(self, dtype): 986 """Set data type for this scope.""" 987 self._dtype = dtype 988 989 def set_use_resource(self, use_resource): 990 """Sets whether to use ResourceVariables for this scope.""" 991 if context.in_eager_mode() and not use_resource: 992 raise ValueError("When eager execution is enabled, " 993 "use_resource cannot be set to false.") 994 self._use_resource = use_resource 995 996 def set_regularizer(self, regularizer): 997 """Set regularizer for this scope.""" 998 self._regularizer = regularizer 999 1000 def set_caching_device(self, caching_device): 1001 """Set caching_device for this scope.""" 1002 if context.in_eager_mode(): 1003 raise NotImplementedError("Caching devices are not yet supported " 1004 "when eager execution is enabled.") 1005 self._caching_device = caching_device 1006 1007 def set_partitioner(self, partitioner): 1008 """Set partitioner for this scope.""" 1009 if partitioner and context.in_eager_mode(): 1010 raise NotImplementedError("Partitioned variables are not yet supported " 1011 "when eager execution is enabled.") 1012 self._partitioner = partitioner 1013 1014 def set_custom_getter(self, custom_getter): 1015 """Set custom getter for this scope.""" 1016 self._custom_getter = custom_getter 1017 1018 def get_collection(self, name): 1019 """Get this scope's variables.""" 1020 scope = self._name + "/" if self._name else "" 1021 return ops.get_collection(name, scope) 1022 
  def trainable_variables(self):
    """Get this scope's trainable variables."""
    return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)

  def global_variables(self):
    """Get this scope's global variables."""
    return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def local_variables(self):
    """Get this scope's local variables."""
    return self.get_collection(ops.GraphKeys.LOCAL_VARIABLES)

  def get_variable(self,
                   var_store,
                   name,
                   shape=None,
                   dtype=None,
                   initializer=None,
                   regularizer=None,
                   reuse=None,
                   trainable=True,
                   collections=None,
                   caching_device=None,
                   partitioner=None,
                   validate_shape=True,
                   use_resource=None,
                   custom_getter=None,
                   constraint=None):
    """Gets an existing variable with this name or create a new one."""
    if regularizer is None:
      regularizer = self._regularizer
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if custom_getter is None:
      custom_getter = self._custom_getter
    if context.in_graph_mode():
      if reuse is None:
        reuse = self._reuse
      if use_resource is None:
        use_resource = self._use_resource
    else:
      reuse = False
      use_resource = True

    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      # Check that `initializer` dtype and `dtype` are consistent before
      # replacing them with defaults.
      if (dtype is not None and initializer is not None and
          not callable(initializer)):
        init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype
        if init_dtype != dtype:
          raise ValueError("Initializer type '%s' and explicit dtype '%s' "
                           "don't match."
                           % (init_dtype, dtype))
      if initializer is None:
        initializer = self._initializer
      if constraint is None:
        constraint = self._constraint
      if dtype is None:
        dtype = self._dtype
      return var_store.get_variable(
          full_name, shape=shape, dtype=dtype, initializer=initializer,
          regularizer=regularizer, reuse=reuse, trainable=trainable,
          collections=collections, caching_device=caching_device,
          partitioner=partitioner, validate_shape=validate_shape,
          use_resource=use_resource, custom_getter=custom_getter,
          constraint=constraint)

  def _get_partitioned_variable(self,
                                var_store,
                                name,
                                shape=None,
                                dtype=None,
                                initializer=None,
                                regularizer=None,
                                trainable=True,
                                collections=None,
                                caching_device=None,
                                partitioner=None,
                                validate_shape=True,
                                use_resource=None,
                                constraint=None):
    """Gets an existing variable with this name or create a new one."""
    if context.in_eager_mode():
      raise NotImplementedError("Partitioned variables are not yet supported "
                                "when eager execution is enabled.")
    if initializer is None:
      initializer = self._initializer
    if regularizer is None:
      regularizer = self._regularizer
    if constraint is None:
      constraint = self._constraint
    if caching_device is None:
      caching_device = self._caching_device
    if partitioner is None:
      partitioner = self._partitioner
    if dtype is None:
      dtype = self._dtype
    if use_resource is None:
      use_resource = self._use_resource

    if self._custom_getter is not None:
      raise ValueError(
          "Private access to _get_partitioned_variable is not allowed when "
          "a custom getter is set. Current custom getter: %s. "
          "It is likely that you're using create_partitioned_variables. "
          "If so, consider instead using get_variable with a non-empty "
          "partitioner parameter." % self._custom_getter)

    if partitioner is None:
      raise ValueError("No partitioner was specified")

    # This allows the variable scope name to be used as the variable name if
    # this function is invoked with an empty name arg, for backward
    # compatibility with create_partitioned_variables().
    full_name_list = []
    if self.name:
      full_name_list.append(self.name)
    if name:
      full_name_list.append(name)
    full_name = "/".join(full_name_list)

    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      # pylint: disable=protected-access
      return var_store._get_partitioned_variable(
          full_name, shape=shape, dtype=dtype, initializer=initializer,
          regularizer=regularizer, reuse=self.reuse, trainable=trainable,
          collections=collections, caching_device=caching_device,
          partitioner=partitioner, validate_shape=validate_shape,
          use_resource=use_resource, constraint=constraint)
      # pylint: enable=protected-access


_VARSTORE_KEY = ("__variable_store",)
_VARSCOPE_KEY = ("__varscope",)


@tf_export("get_variable_scope")
def get_variable_scope():
  """Returns the current variable scope."""
  scope = ops.get_collection(_VARSCOPE_KEY)
  if scope:  # This collection has at most 1 element, the default scope at [0].
    return scope[0]
  scope = VariableScope(False)
  ops.add_to_collection(_VARSCOPE_KEY, scope)
  return scope


def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


@tf_contextlib.contextmanager
def with_variable_store(store):
  store_collection = ops.get_collection_ref(_VARSTORE_KEY)
  old = list(store_collection)
  store_collection[:] = [store]
  try:
    yield
  finally:
    store_collection[:] = old


class EagerVariableStore(object):
  """Wrapper allowing functional layers to be used with eager execution.

  When eager execution is enabled Variables get deleted when they go out of
  scope, and are not stored in global collections by default. A lot of code
  (mostly the functional layers in tf.layers) assumes that variables are kept in
  a global list.

  EagerVariableStore can be used in conjunction with this code to make it
  eager-friendly. For example, to create a dense layer, use:

  ```
  container = tfe.EagerVariableStore()
  for input in dataset_iterator:
    with container.as_default():
      x = tf.layers.dense(input, name="l1")
  print(container.variables)  # Should print the variables used in the layer.
  ```
  """

  def __init__(self, store=None):
    if store is not None:
      if not store._store_eager_variables:  # pylint: disable=protected-access
        raise ValueError("Cannot construct EagerVariableStore from a "
                         "VariableStore object that does not hold eager "
                         "variables.")
      self._store = store
    else:
      self._store = _VariableStore()
    self._store._store_eager_variables = True  # pylint: disable=protected-access

  def as_default(self):
    return with_variable_store(self._store)

  def variables(self):
    return sorted(self._store._vars.values(), key=lambda x: x.name)  # pylint: disable=protected-access

  def trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if x._trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def non_trainable_variables(self):
    # pylint: disable=protected-access
    return sorted([x for x in self._store._vars.values() if not x._trainable],
                  key=lambda x: x.name)
    # pylint: enable=protected-access

  def copy(self):
    """Copy this variable store and all of its contents.

    Variables contained in this store will be copied over to the new variable
    store, meaning that they can be modified without affecting the variables in
    this store.

    Returns:
      A new EagerVariableStore instance containing copied variables.
    """
    # pylint: disable=protected-access
    new_store = EagerVariableStore()
    for key, var in iteritems(self._store._vars):
      # Strip device out of variable name.
      try:
        index = var.name.index(":")
      except ValueError:
        stripped_var_name = var.name
      else:
        stripped_var_name = var.name[:index]

      # Create new variable with same value, name, and "trainable" flag.
      new_var = resource_variable_ops.ResourceVariable(
          var.read_value(),
          name=stripped_var_name,
          trainable=var._trainable)
      new_store._store._vars[key] = new_var
    return new_store
    # pylint: enable=protected-access


@tf_export("get_variable")
def get_variable(name,
                 shape=None,
                 dtype=None,
                 initializer=None,
                 regularizer=None,
                 trainable=True,
                 collections=None,
                 caching_device=None,
                 partitioner=None,
                 validate_shape=True,
                 use_resource=None,
                 custom_getter=None,
                 constraint=None):
  return get_variable_scope().get_variable(
      _get_default_variable_store(), name, shape=shape, dtype=dtype,
      initializer=initializer, regularizer=regularizer, trainable=trainable,
      collections=collections, caching_device=caching_device,
      partitioner=partitioner, validate_shape=validate_shape,
      use_resource=use_resource, custom_getter=custom_getter,
      constraint=constraint)
get_variable_or_local_docstring = (
    """%s

%sThis function prefixes the name with the current variable scope
and performs reuse checks. See the
@{$variables$Variable Scope How To}
for an extensive description of how reusing works. Here is a basic example:

```python
def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
```

If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
`glorot_uniform_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.

Similarly, if the regularizer is `None` (the default), the default regularizer
passed in the variable scope will be used (if that is `None` too,
then by default no regularization is performed).

If a partitioner is provided, a `PartitionedVariable` is returned.
Accessing this object as a `Tensor` returns the shards concatenated along
the partition axis.

Some useful partitioners are available. See, e.g.,
`variable_axis_size_partitioner` and `min_max_variable_partitioner`.

Args:
  name: The name of the new or existing variable.
  shape: Shape of the new or existing variable.
  dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
  initializer: Initializer for the variable if one is created.
  regularizer: A (Tensor -> Tensor or None) function; the result of
    applying it on a newly created variable will be added to the collection
    @{tf.GraphKeys.REGULARIZATION_LOSSES} and can be used for regularization.
  %scollections: List of graph collections keys to add the Variable to.
    Defaults to `[%s]` (see `tf.Variable`).
  caching_device: Optional device string or function describing where the
    Variable should be cached for reading. Defaults to the Variable's
    device. If not `None`, caches on another device. Typical use is to
    cache on the device where the Ops using the Variable reside, to
    deduplicate copying through `Switch` and other conditional statements.
  partitioner: Optional callable that accepts a fully defined `TensorShape`
    and `dtype` of the Variable to be created, and returns a list of
    partitions for each axis (currently only one axis can be partitioned).
  validate_shape: If False, allows the variable to be initialized with a
    value of unknown shape. If True, the default, the shape of initial_value
    must be known.
  use_resource: If False, creates a regular Variable. If True, creates an
    experimental ResourceVariable instead with well-defined semantics.
    Defaults to False (will later change to True). When eager execution is
    enabled this argument is always forced to be True.
  custom_getter: Callable that takes as a first argument the true getter, and
    allows overwriting the internal get_variable method.
    The signature of `custom_getter` should match that of this method,
    but the most future-proof version will allow for changes:
    `def custom_getter(getter, *args, **kwargs)`. Direct access to
    all `get_variable` parameters is also allowed:
    `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
    custom getter that simply creates variables with modified names is:
    ```python
    def custom_getter(getter, name, *args, **kwargs):
      return getter(name + '_suffix', *args, **kwargs)
    ```

Returns:
  The created or existing `Variable` (or `PartitionedVariable`, if a
  partitioner was used).

Raises:
  ValueError: when creating a new variable and shape is not declared,
    when violating reuse during variable creation, or when `initializer` dtype
    and `dtype` don't match. Reuse is set inside `variable_scope`.
""")
get_variable.__doc__ = get_variable_or_local_docstring % (
    "Gets an existing variable with these parameters or create a new one.",
    "",
    "trainable: If `True` also add the variable to the graph collection\n"
    "    `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n  ",
    "GraphKeys.GLOBAL_VARIABLES")


@functools.wraps(get_variable)
@tf_export("get_local_variable")
def get_local_variable(*args, **kwargs):
  kwargs["trainable"] = False
  if "collections" in kwargs:
    kwargs["collections"] += [ops.GraphKeys.LOCAL_VARIABLES]
  else:
    kwargs["collections"] = [ops.GraphKeys.LOCAL_VARIABLES]
  return get_variable(*args, **kwargs)
get_local_variable.__doc__ = get_variable_or_local_docstring % (
    "Gets an existing *local* variable or creates a new one.",
    "Behavior is the same as in `get_variable`, except that variables are\n"
    "added to the `LOCAL_VARIABLES` collection and `trainable` is set to\n"
    "`False`.\n",
    "",
    "GraphKeys.LOCAL_VARIABLES")


def _get_partitioned_variable(name,
                              shape=None,
                              dtype=None,
                              initializer=None,
                              regularizer=None,
                              trainable=True,
                              collections=None,
                              caching_device=None,
                              partitioner=None,
                              validate_shape=True,
                              use_resource=None,
                              constraint=None):
  """Gets or creates a sharded variable list with these parameters.

  The `partitioner` must be a callable that accepts a fully defined
  `TensorShape` and returns a sequence of integers (the `partitions`).
  These integers describe how to partition the given sharded `Variable`
  along the given dimension. That is, `partitions[1] = 3` means split
  the `Variable` into 3 shards along dimension 1. Currently, sharding along
  only one axis is supported.

  If the list of variables with the given name (prefix) is already stored,
  we return the stored variables. Otherwise, we create a new one.

  If initializer is `None` (the default), the default initializer passed in
  the constructor is used. If that one is `None` too, we use a new
  `glorot_uniform_initializer`. If initializer is a Tensor, we use
  it as a value and derive the shape from the initializer.

  If the initializer is a callable, then it will be called for each
  shard. Otherwise the initializer should match the shape of the entire
  sharded Variable, and it will be sliced accordingly for each shard.

  Some useful partitioners are available. See, e.g.,
  `variable_axis_size_partitioner` and `min_max_variable_partitioner`.

  Args:
    name: The name of the new or existing variable.
    shape: Shape of the new or existing variable.
    dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: Initializer for the variable if one is created.
    regularizer: A (Tensor -> Tensor or None) function; the result of
      applying it on a newly created variable will be added to the collection
      GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    collections: List of graph collections keys to add the Variable to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
    caching_device: Optional device string or function describing where the
      Variable should be cached for reading. Defaults to the Variable's
      device. If not `None`, caches on another device. Typical use is to
      cache on the device where the Ops using the Variable reside, to
      deduplicate copying through `Switch` and other conditional statements.
    partitioner: Optional callable that accepts a fully defined `TensorShape`
      and `dtype` of the Variable to be created, and returns a list of
      partitions for each axis (currently only one axis can be partitioned).
    validate_shape: If False, allows the variable to be initialized with a
      value of unknown shape. If True, the default, the shape of initial_value
      must be known.
    use_resource: If False, creates a regular Variable. If True, creates an
      experimental ResourceVariable instead which has well-defined semantics.
      Defaults to False (will later change to True).
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.

  Returns:
    A tuple `(shards, partitions)` where `shards` is the list of `Variable`
    shards and `partitions` is the output of the partitioner on the input
    shape.

  Raises:
    ValueError: when creating a new variable and shape is not declared,
      or when violating reuse during variable creation. Reuse is set inside
      `variable_scope`.
  """
  # pylint: disable=protected-access
  scope = get_variable_scope()
  if scope.custom_getter is not None:
    raise ValueError(
        "Private access to _get_partitioned_variable is not allowed when "
        "a custom getter is set. Current custom getter: %s. "
        "It is likely that you're using create_partitioned_variables. "
        "If so, consider instead using get_variable with a non-empty "
        "partitioner parameter." % scope.custom_getter)
  return scope._get_partitioned_variable(
      _get_default_variable_store(), name, shape=shape, dtype=dtype,
      initializer=initializer, regularizer=regularizer, trainable=trainable,
      collections=collections, caching_device=caching_device,
      partitioner=partitioner, validate_shape=validate_shape,
      use_resource=use_resource, constraint=constraint)
  # pylint: enable=protected-access


# Named like a function for compatibility with the previous
# @tf_contextlib.contextmanager definition.
class _pure_variable_scope(object):  # pylint: disable=invalid-name
  """A context for the variable_scope, see `variable_scope` for docs."""

  def __init__(self,
               name_or_scope,
               reuse=None,
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               old_name_scope=None,
               dtype=dtypes.float32,
               use_resource=None,
               constraint=None):
    """Creates a context for the variable_scope, see `variable_scope` for docs.

    Note: this does not create a name scope.

    Args:
      name_or_scope: `string` or `VariableScope`: the scope to open.
      reuse: `True`, None, or tf.AUTO_REUSE; if `None`, we inherit the parent
        scope's reuse flag.
      initializer: default initializer for variables within this scope.
      regularizer: default regularizer for variables within this scope.
      caching_device: default caching device for variables within this scope.
      partitioner: default partitioner for variables within this scope.
      custom_getter: default custom getter for variables within this scope.
      old_name_scope: the original name scope when re-entering a variable
        scope.
      dtype: type of the variables within this scope (defaults to `DT_FLOAT`).
      use_resource: If False, variables in this scope will be regular
        Variables. If True, experimental ResourceVariables will be created
        instead, with well-defined semantics. Defaults to False (will later
        change to True).
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
    """
    self._name_or_scope = name_or_scope
    self._reuse = reuse
    self._initializer = initializer
    self._regularizer = regularizer
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._old_name_scope = old_name_scope
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    get_variable_scope()  # Ensure that a default exists, then get a pointer.
    # Get the reference to the collection as we want to modify it in place.
    self._default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
    self._var_store = _get_default_variable_store()
    if isinstance(self._name_or_scope, VariableScope):
      self._new_name = self._name_or_scope.name
      name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
      # Handler for the case when we jump to a shared scope. We create a new
      # VariableScope (self._cached_variable_scope_object) that contains a copy
      # of the provided shared scope, possibly with changed reuse and
      # initializer, if the user requested this.
      variable_scope_object = VariableScope(
          self._name_or_scope.reuse if not self._reuse else self._reuse,
          name=self._new_name,
          initializer=self._name_or_scope.initializer,
          regularizer=self._name_or_scope.regularizer,
          caching_device=self._name_or_scope.caching_device,
          partitioner=self._name_or_scope.partitioner,
          dtype=self._name_or_scope.dtype,
          custom_getter=self._name_or_scope.custom_getter,
          name_scope=name_scope,
          use_resource=self._name_or_scope.use_resource,
          constraint=self._constraint)
      if self._initializer is not None:
        variable_scope_object.set_initializer(self._initializer)
      if self._regularizer is not None:
        variable_scope_object.set_regularizer(self._regularizer)
      if self._caching_device is not None:
        variable_scope_object.set_caching_device(self._caching_device)
      if self._partitioner is not None:
        variable_scope_object.set_partitioner(self._partitioner)
      if self._custom_getter is not None:
        variable_scope_object.set_custom_getter(
            _maybe_wrap_custom_getter(
                self._custom_getter, self._name_or_scope.custom_getter))
      if self._dtype is not None:
        variable_scope_object.set_dtype(self._dtype)
      if self._use_resource is not None:
        variable_scope_object.set_use_resource(self._use_resource)
      self._cached_variable_scope_object = variable_scope_object

  def __enter__(self):
    """Begins the scope block.

    Returns:
      A VariableScope.
    Raises:
      ValueError: when trying to reuse within a create scope, or create within
        a reuse scope, or if reuse is not `None` or `True`.
      TypeError: when the types of some arguments are not appropriate.
    """
    self._old = self._default_varscope[0]
    if isinstance(self._name_or_scope, VariableScope):
      self._var_store.open_variable_scope(self._new_name)
      self._old_subscopes = copy.copy(self._var_store.variable_scopes_count)
      variable_scope_object = self._cached_variable_scope_object
    else:
      # Handler for the case when we just prolong the current variable scope:
      # a VariableScope with name extended by the provided one, and inherited
      # reuse and initializer (except if the user provided values to set).
      self._new_name = (
          self._old.name + "/" + self._name_or_scope if self._old.name
          else self._name_or_scope)
      self._reuse = (self._reuse
                     or self._old.reuse)  # Re-using is inherited by sub-scopes.
      if self._old_name_scope is None:
        name_scope = self._name_or_scope
      else:
        name_scope = self._old_name_scope
      variable_scope_object = VariableScope(
          self._reuse,
          name=self._new_name,
          initializer=self._old.initializer,
          regularizer=self._old.regularizer,
          caching_device=self._old.caching_device,
          partitioner=self._old.partitioner,
          dtype=self._old.dtype,
          use_resource=self._old.use_resource,
          custom_getter=self._old.custom_getter,
          name_scope=name_scope,
          constraint=self._constraint)
      if self._initializer is not None:
        variable_scope_object.set_initializer(self._initializer)
      if self._regularizer is not None:
        variable_scope_object.set_regularizer(self._regularizer)
      if self._caching_device is not None:
        variable_scope_object.set_caching_device(self._caching_device)
      if self._partitioner is not None:
        variable_scope_object.set_partitioner(self._partitioner)
      if self._custom_getter is not None:
        variable_scope_object.set_custom_getter(
            _maybe_wrap_custom_getter(self._custom_getter,
                                      self._old.custom_getter))
      if self._dtype is not None:
        variable_scope_object.set_dtype(self._dtype)
      if self._use_resource is not None:
        variable_scope_object.set_use_resource(self._use_resource)
      self._var_store.open_variable_scope(self._new_name)
    self._default_varscope[0] = variable_scope_object
    return variable_scope_object

  def __exit__(self, type_arg, value_arg, traceback_arg):
    # If jumping out from a non-prolonged scope, restore counts.
    if isinstance(self._name_or_scope, VariableScope):
      self._var_store.variable_scopes_count = self._old_subscopes
    else:
      self._var_store.close_variable_subscopes(self._new_name)
    self._default_varscope[0] = self._old


def _maybe_wrap_custom_getter(custom_getter, old_getter):
  """Wrap a call to a custom_getter to use the old_getter internally."""
  if old_getter is None:
    return custom_getter

  # The new custom_getter should call the old one.
  def wrapped_custom_getter(getter, *args, **kwargs):
    # Call:
    #   custom_getter(
    #     lambda: old_getter(true_getter, ...), *args, **kwargs)
    # which means custom_getter will call old_getter, which
    # will call the true_getter, perform any intermediate
    # processing, and return the results to the current
    # getter, which will also perform additional processing.
    return custom_getter(
        functools.partial(old_getter, getter),
        *args, **kwargs)
  return wrapped_custom_getter


def _get_unique_variable_scope(prefix):
  """Get a name with the given prefix unique in the current variable scope."""
  var_store = _get_default_variable_store()
  current_scope = get_variable_scope()
  name = current_scope.name + "/" + prefix if current_scope.name else prefix
  if var_store.variable_scope_count(name) == 0:
    return prefix
  idx = 1
  while var_store.variable_scope_count(name + ("_%d" % idx)) > 0:
    idx += 1
  return prefix + ("_%d" % idx)


# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
@tf_export("variable_scope")  # pylint: disable=invalid-name
class variable_scope(object):
  """A context manager for defining ops that create variables (layers).

  This context manager validates that the (optional) `values` are from the same
  graph, ensures that graph is the default graph, and pushes a name scope and a
  variable scope.

  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
  then `default_name` is used. In that case, if the same name has been
  previously used in the same scope, it will be made unique by appending `_N`
  to it.

  Variable scope allows you to create new variables and to share already
  created ones while providing checks to not create or share by accident. For
  details, see the @{$variables$Variable Scope How To}; here we present only a
  few basic examples.

  Simple example of how to create a new variable:

  ```python
  with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
      v = tf.get_variable("v", [1])
      assert v.name == "foo/bar/v:0"
  ```

  Basic example of sharing a variable with AUTO_REUSE:

  ```python
  def foo():
    with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
      v = tf.get_variable("v", [1])
    return v

  v1 = foo()  # Creates v.
  v2 = foo()  # Gets the same, existing v.
  assert v1 == v2
  ```

  Basic example of sharing a variable with reuse=True:

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
  with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  Sharing a variable by capturing a scope and setting reuse:

  ```python
  with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  To prevent accidental sharing of variables, we raise an exception when
  getting an existing variable in a non-reusing scope.

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
    v1 = tf.get_variable("v", [1])
    #  Raises ValueError("... v already exists ...").
  ```

  Similarly, we raise an exception when trying to get a variable that does not
  exist in reuse mode.

  ```python
  with tf.variable_scope("foo", reuse=True):
    v = tf.get_variable("v", [1])
    #  Raises ValueError("... v does not exist ...").
  ```

  Note that the `reuse` flag is inherited: if we open a reusing scope, then all
  its sub-scopes become reusing as well.

  A note about name scoping: Setting `reuse` does not impact the naming of
  other ops such as multiplication. See related discussion on
  [github#6189](https://github.com/tensorflow/tensorflow/issues/6189).

  Note that up to and including version 1.0, it was allowed (though explicitly
  discouraged) to pass False to the reuse argument, yielding undocumented
  behaviour slightly different from None. Starting at 1.1.0 passing None and
  False as reuse has exactly the same effect.
  """

  def __init__(self,
               name_or_scope,
               default_name=None,
               values=None,
               initializer=None,
               regularizer=None,
               caching_device=None,
               partitioner=None,
               custom_getter=None,
               reuse=None,
               dtype=None,
               use_resource=None,
               constraint=None,
               auxiliary_name_scope=True):
    """Initialize the context manager.

    Args:
      name_or_scope: `string` or `VariableScope`: the scope to open.
      default_name: The default name to use if the `name_or_scope` argument is
        `None`; this name will be uniquified. If `name_or_scope` is provided,
        it won't be used and therefore it is not required and can be None.
      values: The list of `Tensor` arguments that are passed to the op
        function.
      initializer: default initializer for variables within this scope.
      regularizer: default regularizer for variables within this scope.
      caching_device: default caching device for variables within this scope.
      partitioner: default partitioner for variables within this scope.
      custom_getter: default custom getter for variables within this scope.
      reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
        for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
        variables if they do not exist, and return them otherwise; if None, we
        inherit the parent scope's reuse flag. When eager execution is enabled,
        this argument is always forced to be tf.AUTO_REUSE.
      dtype: type of variables created in this scope (defaults to the type
        in the passed scope, or inherited from parent scope).
      use_resource: If False, all variables will be regular Variables. If True,
        experimental ResourceVariables with well-defined semantics will be used
        instead. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be True.
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value
        (which must have the same shape). Constraints are not safe to
        use when doing asynchronous distributed training.
      auxiliary_name_scope: If `True`, we create an auxiliary name scope with
        the scope. If `False`, we don't touch the name scope.

    Returns:
      A scope that can be captured and reused.

    Raises:
      ValueError: when trying to reuse within a create scope, or create within
        a reuse scope.
      TypeError: when the types of some arguments are not appropriate.
    """
    self._name_or_scope = name_or_scope
    self._default_name = default_name
    self._values = values
    self._initializer = initializer
    self._regularizer = regularizer
    self._caching_device = caching_device
    self._partitioner = partitioner
    self._custom_getter = custom_getter
    self._reuse = reuse
    self._dtype = dtype
    self._use_resource = use_resource
    self._constraint = constraint
    if self._default_name is None and self._name_or_scope is None:
      raise TypeError("If default_name is None then name_or_scope is required")
    if self._reuse is False:
      # We don't allow non-inheriting scopes, False = None here.
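      # (As documented in the class docstring, passing reuse=False therefore
      # behaves exactly like reuse=None starting at 1.1.0.)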
      self._reuse = None
    if not (self._reuse is True
            or self._reuse is None
            or self._reuse is AUTO_REUSE):
      raise ValueError("The reuse parameter must be True or False or None or "
                       "tf.AUTO_REUSE.")
    if self._values is None:
      self._values = []
    self._in_graph_mode = not context.in_eager_mode()
    if self._in_graph_mode:
      self._graph = ops._get_graph_from_inputs(self._values)  # pylint: disable=protected-access
    self._cached_pure_variable_scope = None
    self._current_name_scope = None
    if not isinstance(auxiliary_name_scope, bool):
      raise TypeError("The auxiliary_name_scope must be `True` or `False`, "
                      "but got {}".format(auxiliary_name_scope))
    self._auxiliary_name_scope = auxiliary_name_scope

  def __enter__(self):
    # If the default graph is building a function, then we should not replace
    # it with the cached graph.
    if ops.get_default_graph().building_function:
      self._building_function = True
    else:
      self._building_function = False
    # Initialize to None so the error path below can test it even when no
    # graph context manager was entered.
    self._graph_context_manager = None
    if self._in_graph_mode and not self._building_function:
      self._graph_context_manager = self._graph.as_default()
      self._graph_context_manager.__enter__()
    if self._cached_pure_variable_scope is not None:
      # Fast path for re-entering variable_scopes. We've held on to the pure
      # variable scope from a previous successful __enter__, so we avoid some
      # overhead by re-using that object.
      if self._current_name_scope is not None:
        self._current_name_scope.__enter__()
      return self._cached_pure_variable_scope.__enter__()

    try:
      return self._enter_scope_uncached()
    except:
      if self._graph_context_manager is not None:
        self._graph_context_manager.__exit__(*sys.exc_info())
      raise

  def _enter_scope_uncached(self):
    """Enters the context manager when there is no cached scope yet.

    Returns:
      The entered variable scope.

    Raises:
      TypeError: A wrong type is passed as `scope` at __init__().
      ValueError: `reuse` is incorrectly set at __init__().
    """
    if self._auxiliary_name_scope:
      # Create a new name scope later.
      current_name_scope = None
    else:
      # Reenter the current name scope.
      name_scope = ops.get_name_scope()
      if name_scope:
        # Hack to reenter.
        name_scope += "/"
        current_name_scope = ops.name_scope(name_scope)
      else:
        # Root scope.
        current_name_scope = ops.name_scope(name_scope)

    # IMPORTANT: Only assign to self._cached_pure_variable_scope and
    # self._current_name_scope after successful __enter__() calls.
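    # Three cases follow: (1) `name_or_scope` was provided and a name scope is
    # entered, (2) `name_or_scope` resolves to the root scope and there is no
    # name scope to enter, and (3) `name_or_scope` is None, so the uniquified
    # `default_name` is used instead.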
    if self._name_or_scope is not None:
      if not isinstance(self._name_or_scope,
                        (VariableScope,) + six.string_types):
        raise TypeError("VariableScope: name_or_scope must be a string or "
                        "VariableScope.")
      if isinstance(self._name_or_scope, six.string_types):
        name_scope = self._name_or_scope
      else:
        name_scope = self._name_or_scope.name.split("/")[-1]
      if name_scope or current_name_scope:
        current_name_scope = current_name_scope or ops.name_scope(name_scope)
        try:
          current_name_scope_name = current_name_scope.__enter__()
        except:
          current_name_scope.__exit__(*sys.exc_info())
          raise
        self._current_name_scope = current_name_scope
        if isinstance(self._name_or_scope, six.string_types):
          old_name_scope = current_name_scope_name
        else:
          old_name_scope = self._name_or_scope.original_name_scope
        pure_variable_scope = _pure_variable_scope(
            self._name_or_scope,
            reuse=self._reuse,
            initializer=self._initializer,
            regularizer=self._regularizer,
            caching_device=self._caching_device,
            partitioner=self._partitioner,
            custom_getter=self._custom_getter,
            old_name_scope=old_name_scope,
            dtype=self._dtype,
            use_resource=self._use_resource,
            constraint=self._constraint)
        try:
          entered_pure_variable_scope = pure_variable_scope.__enter__()
        except:
          pure_variable_scope.__exit__(*sys.exc_info())
          raise
        self._cached_pure_variable_scope = pure_variable_scope
        return entered_pure_variable_scope
      else:
        self._current_name_scope = None
        # This can only happen if someone is entering the root variable scope.
        pure_variable_scope = _pure_variable_scope(
            self._name_or_scope,
            reuse=self._reuse,
            initializer=self._initializer,
            regularizer=self._regularizer,
            caching_device=self._caching_device,
            partitioner=self._partitioner,
            custom_getter=self._custom_getter,
            dtype=self._dtype,
            use_resource=self._use_resource,
            constraint=self._constraint)
        try:
          entered_pure_variable_scope = pure_variable_scope.__enter__()
        except:
          pure_variable_scope.__exit__(*sys.exc_info())
          raise
        self._cached_pure_variable_scope = pure_variable_scope
        return entered_pure_variable_scope

    else:  # Here name_or_scope is None. Using default name, but made unique.
      if self._reuse:
        raise ValueError("reuse=True cannot be used without a name_or_scope")
      current_name_scope = current_name_scope or ops.name_scope(
          self._default_name)
      try:
        current_name_scope_name = current_name_scope.__enter__()
      except:
        current_name_scope.__exit__(*sys.exc_info())
        raise
      self._current_name_scope = current_name_scope
      unique_default_name = _get_unique_variable_scope(self._default_name)
      pure_variable_scope = _pure_variable_scope(
          unique_default_name,
          initializer=self._initializer,
          regularizer=self._regularizer,
          caching_device=self._caching_device,
          partitioner=self._partitioner,
          custom_getter=self._custom_getter,
          old_name_scope=current_name_scope_name,
          dtype=self._dtype,
          use_resource=self._use_resource,
          constraint=self._constraint)
      try:
        entered_pure_variable_scope = pure_variable_scope.__enter__()
      except:
        pure_variable_scope.__exit__(*sys.exc_info())
        raise
      self._cached_pure_variable_scope = pure_variable_scope
      return entered_pure_variable_scope

  def __exit__(self, type_arg, value_arg, traceback_arg):
    self._cached_pure_variable_scope.__exit__(
        type_arg, value_arg, traceback_arg)
    if self._current_name_scope:
      self._current_name_scope.__exit__(type_arg, value_arg, traceback_arg)
    if self._in_graph_mode and not self._building_function:
      self._graph_context_manager.__exit__(type_arg, value_arg, traceback_arg)


# pylint: disable=g-doc-return-or-yield
@tf_export("variable_op_scope")
@tf_contextlib.contextmanager
def variable_op_scope(values,
                      name_or_scope,
                      default_name=None,
                      initializer=None,
                      regularizer=None,
                      caching_device=None,
                      partitioner=None,
                      custom_getter=None,
                      reuse=None,
                      dtype=None,
                      use_resource=None,
                      constraint=None):
  """Deprecated: context manager for defining an op that creates variables."""
  logging.warn("tf.variable_op_scope(values, name, default_name) is "
               "deprecated, use tf.variable_scope(name, default_name, values)")
  with variable_scope(name_or_scope,
                      default_name=default_name,
                      values=values,
                      initializer=initializer,
                      regularizer=regularizer,
                      caching_device=caching_device,
                      partitioner=partitioner,
                      custom_getter=custom_getter,
                      reuse=reuse,
                      dtype=dtype,
                      use_resource=use_resource,
                      constraint=constraint) as scope:
    yield scope


def _compute_slice_dim_and_shape(full_shape, slicing):
  """Computes which dimension is being sliced and the typical slice shape."""

  slice_shape = [0] * len(full_shape)
  slice_dim = None
  for dim, num_slices in enumerate(slicing):
    dim_size = full_shape[dim]
    if num_slices <= 0 or dim_size < num_slices:
      raise ValueError("Cannot create %d slices for size %d. shape: %s, "
                       "slicing: %s" %
                       (num_slices, full_shape[dim], full_shape, slicing))
    if num_slices == 1:
      # Not slicing in this dimension.
      slice_shape[dim] = dim_size
    elif slice_dim is not None:
      # We only support slicing along one of the dimensions.
      raise ValueError("Can only slice a variable along one dimension: "
                       "shape: %s, slicing: %s" % (full_shape, slicing))
    else:
      # Note: We will add any extras onto the last slice, later.
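      # For example, a dimension of size 10 split into 3 slices gives a typical
      # slice size of 10 // 3 = 3; the final slice will later absorb the
      # remaining 10 - 2 * 3 = 4 elements.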
      slice_dim = dim
      slice_shape[dim] = dim_size // num_slices

  # Degenerate case: If "slicing" was all ones, pretend we are slicing along
  # the first dimension.
  if slice_dim is None:
    slice_dim = 0
  return slice_dim, slice_shape


def default_variable_creator(next_creator=None, **kwargs):
  """Default variable creator."""
  assert next_creator is None
  initial_value = kwargs.get("initial_value", None)
  trainable = kwargs.get("trainable", True)
  collections = kwargs.get("collections", None)
  validate_shape = kwargs.get("validate_shape", True)
  caching_device = kwargs.get("caching_device", None)
  name = kwargs.get("name", None)
  dtype = kwargs.get("dtype", None)
  constraint = kwargs.get("constraint", None)
  use_resource = kwargs.get("use_resource", None)
  if use_resource is None:
    use_resource = get_variable_scope().use_resource
  if use_resource or (use_resource is None and context.in_eager_mode()):
    return resource_variable_ops.ResourceVariable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)
  elif not use_resource and context.in_eager_mode():
    raise RuntimeError(
        "VariableScope should use resource variable when eager execution is"
        " enabled, but use_resource is False."
    )
  else:
    return variables.Variable(
        initial_value=initial_value, trainable=trainable,
        collections=collections, validate_shape=validate_shape,
        caching_device=caching_device, name=name, dtype=dtype,
        constraint=constraint)


def _make_getter(captured_getter, captured_previous):
  """Works around Python closures capturing loop variables by reference."""
  return lambda **kwargs: captured_getter(captured_previous, **kwargs)


def variable(initial_value=None,
             trainable=True,
             collections=None,
             validate_shape=True,
             caching_device=None,
             name=None,
             dtype=None,
             constraint=None,
             use_resource=None):
  previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  for getter in ops.get_default_graph()._get_variable_creator_stack():  # pylint: disable=protected-access
    previous_getter = _make_getter(getter, previous_getter)
  return previous_getter(initial_value=initial_value,
                         trainable=trainable,
                         collections=collections,
                         validate_shape=validate_shape,
                         caching_device=caching_device,
                         name=name, dtype=dtype,
                         constraint=constraint,
                         use_resource=use_resource)


@tf_contextlib.contextmanager
def variable_creator_scope(variable_creator):
  """Scope which defines a variable creation function to be used by variable().

  variable_creator is expected to be a function with the following signature:

  ```
    def variable_creator(next_creator, **kwargs)
  ```

  If the creator does want to create a variable, it is expected to eventually
  call `next_creator` rather than calling `Variable` or `ResourceVariable`
  directly. This helps make creators composable. A creator may choose to create
  multiple variables, return already existing variables, or simply register
  that a variable was created and defer to the next creators in line. Creators
  can also modify the keyword arguments seen by the next creators.
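
  As an illustrative sketch (the creator name and prefix below are arbitrary
  and not part of any API), a creator that rewrites the requested name before
  deferring to the next creator could look like:

  ```python
  def prefixing_creator(next_creator, **kwargs):  # Illustrative only.
    if kwargs.get("name") is not None:
      kwargs["name"] = "prefixed/" + kwargs["name"]  # Modify kwargs, then defer.
    return next_creator(**kwargs)

  with variable_creator_scope(prefixing_creator):
    v = variable(initial_value=1.0, name="v")  # Named "prefixed/v".
  ```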

  Custom getters in the variable scope will eventually resolve down to these
  custom creators when they do create variables.

  The valid keyword arguments in `kwargs` are:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called. In
      that case, `dtype` must be specified. (Note that initializer functions
      from init_ops.py must first be bound to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Optional device string describing where the Variable
      should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache
      on the device where the Ops using the Variable reside, to deduplicate
      copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    dtype: If set, initial_value will be converted to the given type.
      If `None`, either the datatype will be kept (if `initial_value` is
      a Tensor), or `convert_to_tensor` will decide.
    constraint: A constraint function to be applied to the variable after
      updates by some algorithms.
    use_resource: if True, a ResourceVariable is always created.

  This set may grow over time, so it is important that creators keep the
  signature given above.

  Args:
    variable_creator: the passed creator.

  Yields:
    A scope in which the creator is active.
  """
  with ops.get_default_graph()._variable_creator_scope(variable_creator):  # pylint: disable=protected-access
    yield