Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for variable store."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gc
     22 
     23 import numpy
     24 
     25 from tensorflow.python.eager import context
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import errors
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import test_util
     31 from tensorflow.python.ops import array_ops
     32 from tensorflow.python.ops import control_flow_ops
     33 from tensorflow.python.ops import init_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.ops import resource_variable_ops
     36 from tensorflow.python.ops import state_ops
     37 from tensorflow.python.ops import variable_scope
     38 from tensorflow.python.ops import variables as variables_lib
     39 from tensorflow.python.platform import test
     40 
     41 
     42 class VariableScopeTest(test.TestCase):
     43 
     44   def tearDown(self):
     45     gc.collect()
     46     # This will only contain uncollectable garbage, i.e. reference cycles
     47     # involving objects with __del__ defined.
     48     self.assertEqual(0, len(gc.garbage))
     49 
     50   def testGetVar(self):
     51     vs = variable_scope._get_default_variable_store()
     52     v = vs.get_variable("v", [1])
     53     v1 = vs.get_variable("v", [1])
     54     self.assertEqual(v, v1)
     55 
     56   @test_util.run_in_graph_and_eager_modes()
     57   def testResource(self):
     58     vs = variable_scope._get_default_variable_store()
     59     v1 = vs.get_variable("v", [1], use_resource=True)
     60     self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
     61 
     62   def testNameExists(self):
     63     vs = variable_scope._get_default_variable_store()
     64     # No check by default, so we can both create and get existing names.
     65     v = vs.get_variable("v", [1])
     66     v1 = vs.get_variable("v", [1])
     67     self.assertEqual(v, v1)
     68 
     69     # When reuse is False, we fail when variables are already there.
     70     vs.get_variable("w", [1], reuse=False)  # That's ok.
     71     with self.assertRaises(ValueError):
     72       vs.get_variable("v", [1], reuse=False)  # That fails.
     73     # When reuse is True, we fail when variables are new.
     74     vs.get_variable("v", [1], reuse=True)  # That's ok.
     75     with self.assertRaises(ValueError):
     76       vs.get_variable("u", [1], reuse=True)  # That fails.
     77 
     78   def testNamelessStore(self):
     79     vs = variable_scope._get_default_variable_store()
     80     vs.get_variable("v1", [2])
     81     vs.get_variable("v2", [2])
     82     expected_names = ["%s:0" % name for name in ["v1", "v2"]]
     83     self.assertEqual(
     84         set(expected_names), set([v.name for v in vs._vars.values()]))
     85 
     86   @test_util.run_in_graph_and_eager_modes()
     87   def testVarScopeInitializer(self):
     88     init = init_ops.constant_initializer(0.3)
     89     with variable_scope.variable_scope("tower0") as tower:
     90       with variable_scope.variable_scope("foo", initializer=init):
     91         v = variable_scope.get_variable("v", [])
     92         self.evaluate(variables_lib.variables_initializer([v]))
     93         self.assertAllClose(self.evaluate(v.value()), 0.3)
     94       with variable_scope.variable_scope(tower, initializer=init):
     95         w = variable_scope.get_variable("w", [])
     96         self.evaluate(variables_lib.variables_initializer([w]))
     97         self.assertAllClose(self.evaluate(w.value()), 0.3)
     98 
     99   @test_util.run_in_graph_and_eager_modes()
    100   def testVarScopeConstraint(self):
    101     constraint = lambda x: 0. * x
    102     with variable_scope.variable_scope("tower1") as tower:
    103       with variable_scope.variable_scope("foo", constraint=constraint):
    104         v = variable_scope.get_variable("v", [])
    105         self.assertEqual(v.constraint, constraint)
    106       with variable_scope.variable_scope(tower, constraint=constraint):
    107         w = variable_scope.get_variable("w", [])
    108         self.assertEqual(w.constraint, constraint)
    109 
    110   @test_util.run_in_graph_and_eager_modes()
    111   def testVarScopeDType(self):
    112     with variable_scope.variable_scope("tower2") as tower:
    113       with variable_scope.variable_scope("foo", dtype=dtypes.float16):
    114         v = variable_scope.get_variable("v", [])
    115         self.assertEqual(v.dtype.base_dtype, dtypes.float16)
    116       with variable_scope.variable_scope(tower, dtype=dtypes.float16):
    117         w = variable_scope.get_variable("w", [])
    118         self.assertEqual(w.dtype.base_dtype, dtypes.float16)
    119 
    120   def testEagerVariableStore(self):
    121     with context.eager_mode():
    122       store = variable_scope.EagerVariableStore()
    123       with store.as_default():
    124         v = variable_scope.get_variable("v", shape=(), trainable=True)
    125         w = variable_scope.get_variable("w", shape=(), trainable=False)
    126 
    127       self.assertTrue(v in store.variables())
    128       self.assertTrue(w in store.variables())
    129       self.assertTrue(v in store.trainable_variables())
    130       self.assertFalse(w in store.trainable_variables())
    131       self.assertFalse(v in store.non_trainable_variables())
    132       self.assertTrue(w in store.non_trainable_variables())
    133 
    134       # Test copying.
    135       new_store = store.copy()
    136       with new_store.as_default():
    137         new_v = variable_scope.get_variable("v")
    138         new_w = variable_scope.get_variable("w")
    139       self.assertEqual(new_v.numpy(), v.numpy())
    140       self.assertEqual(new_w.numpy(), w.numpy())
    141       self.assertTrue(new_v in new_store.variables())
    142       self.assertTrue(new_w in new_store.variables())
    143       self.assertTrue(new_v in new_store.trainable_variables())
    144       self.assertFalse(new_w in new_store.trainable_variables())
    145       self.assertFalse(new_v in new_store.non_trainable_variables())
    146       self.assertTrue(new_w in new_store.non_trainable_variables())
    147 
    148       # Check that variables are separate instances.
    149       for v in store.variables():
    150         v.assign(-1)
    151       for v in new_store.variables():
    152         v.assign(1)
    153       for v in store.variables():
    154         self.assertEqual(v.numpy(), -1)
    155       for v in new_store.variables():
    156         self.assertEqual(v.numpy(), 1)
    157 
    158   @test_util.run_in_graph_and_eager_modes()
    159   def testInitFromNonTensorValue(self):
    160     v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
    161     self.evaluate(variables_lib.variables_initializer([v]))
    162     self.assertAllClose(self.evaluate(v.value()), 4)
    163 
    164     w = variable_scope.get_variable(
    165         "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
    166     self.evaluate(variables_lib.variables_initializer([w]))
    167     self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
    168 
    169     if context.in_graph_mode():
    170       with self.assertRaises(TypeError):
    171         variable_scope.get_variable("x4", initializer={})
    172     else:
    173       with self.assertRaises(ValueError):
    174         variable_scope.get_variable("x4", initializer={})
    175 
    176   @test_util.run_in_graph_and_eager_modes()
    177   def testInitFromNonInitializer(self):
    178     # Test various dtypes with zeros initializer as following:
    179     types = [
    180         dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
    181         dtypes.int64, dtypes.bool
    182     ]
    183 
    184     # Use different variable_name to distinguish various dtypes
    185     for (i, dtype) in enumerate(types):
    186       x = variable_scope.get_variable(
    187           name="xx%d" % i, shape=(3, 4), dtype=dtype)
    188       y = variable_scope.get_variable(
    189           name="yy%d" % i,
    190           shape=(3, 4),
    191           dtype=dtype,
    192           initializer=init_ops.zeros_initializer(dtype=dtype))
    193 
    194       self.evaluate(variables_lib.global_variables_initializer())
    195       self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
    196 
    197   # TODO(alive): support variable partitioning/caching in eager mode.
    198   def testVarScopeCachingDevice(self):
    199     with self.test_session():
    200       caching_device = "/job:moo"
    201       with variable_scope.variable_scope("tower"):
    202         with variable_scope.variable_scope(
    203             "caching", caching_device=caching_device):
    204           v = variable_scope.get_variable("v", [])
    205           self.assertTrue(v.value().device.startswith(caching_device))
    206 
    207           with variable_scope.variable_scope("child"):
    208             v2 = variable_scope.get_variable("v", [])
    209             self.assertTrue(v2.value().device.startswith(caching_device))
    210 
    211           with variable_scope.variable_scope("not_cached", caching_device=""):
    212             v2_not_cached = variable_scope.get_variable("v", [])
    213             self.assertFalse(v2_not_cached.value().device.startswith(
    214                 caching_device))
    215 
    216           with variable_scope.variable_scope(
    217               "not_cached_identity_device",
    218               caching_device=lambda op: op.device):
    219             v2_identity_device = variable_scope.get_variable("v", [])
    220             self.assertFalse(v2_identity_device.value().device.startswith(
    221                 caching_device))
    222 
    223           with variable_scope.variable_scope("we_will_do_it_live") as vs_live:
    224             vs_live.set_caching_device("/job:live")
    225             v_live = variable_scope.get_variable("v", [])
    226             self.assertTrue(v_live.value().device.startswith("/job:live"))
    227 
    228         v_tower = variable_scope.get_variable("v", [])
    229         self.assertFalse(v_tower.value().device.startswith(caching_device))
    230 
    231   @test_util.run_in_graph_and_eager_modes()
    232   def testVarScopeRegularizer(self):
    233     init = init_ops.constant_initializer(0.3)
    234 
    235     def regularizer1(v):
    236       return math_ops.reduce_mean(v) + 0.1
    237 
    238     def regularizer2(v):
    239       return math_ops.reduce_mean(v) + 0.2
    240 
    241     with variable_scope.variable_scope(
    242         "tower3", regularizer=regularizer1) as tower:
    243       with variable_scope.variable_scope("foo", initializer=init):
    244         v = variable_scope.get_variable("v", [])
    245         self.evaluate(variables_lib.variables_initializer([v]))
    246         losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    247         self.assertEqual(1, len(losses))
    248         self.assertAllClose(self.evaluate(losses[0]), 0.4)
    249       with variable_scope.variable_scope(tower, initializer=init) as vs:
    250         u = variable_scope.get_variable("u", [])
    251         vs.set_regularizer(regularizer2)
    252         w = variable_scope.get_variable("w", [])
    253         # Next 3 variable not regularized to test disabling regularization.
    254         x = variable_scope.get_variable(
    255             "x", [], regularizer=variable_scope.no_regularizer)
    256         with variable_scope.variable_scope(
    257             "baz", regularizer=variable_scope.no_regularizer):
    258           y = variable_scope.get_variable("y", [])
    259         vs.set_regularizer(variable_scope.no_regularizer)
    260         z = variable_scope.get_variable("z", [])
    261         # Check results.
    262         losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    263         self.assertEqual(3, len(losses))
    264         self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
    265         self.assertAllClose(self.evaluate(losses[0]), 0.4)
    266         self.assertAllClose(self.evaluate(losses[1]), 0.4)
    267         self.assertAllClose(self.evaluate(losses[2]), 0.5)
    268       with variable_scope.variable_scope("foo", reuse=True):
    269         # reuse=True is for now only supported when eager execution is disabled.
    270         if context.in_graph_mode():
    271           v = variable_scope.get_variable("v",
    272                                           [])  # "v" is alredy there, reused
    273           losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    274           self.assertEqual(3, len(losses))  # No new loss added.
    275 
    276   @test_util.run_in_graph_and_eager_modes()
    277   def testInitializeFromValue(self):
    278     init = constant_op.constant(0.1)
    279     w = variable_scope.get_variable("v", initializer=init)
    280     self.evaluate(variables_lib.variables_initializer([w]))
    281     self.assertAllClose(self.evaluate(w.value()), 0.1)
    282 
    283     with self.assertRaisesRegexp(ValueError, "shape"):
    284       # We disallow explicit shape specification when initializer is constant.
    285       variable_scope.get_variable("u", [1], initializer=init)
    286 
    287     with variable_scope.variable_scope("foo", initializer=init):
    288       # Constant initializer can be passed through scopes if needed.
    289       v = variable_scope.get_variable("v")
    290       self.evaluate(variables_lib.variables_initializer([v]))
    291       self.assertAllClose(self.evaluate(v.value()), 0.1)
    292 
    293     # Check that non-float32 initializer creates a non-float32 variable.
    294     init = constant_op.constant(1, dtype=dtypes.int32)
    295     t = variable_scope.get_variable("t", initializer=init)
    296     self.assertEqual(t.dtype.base_dtype, dtypes.int32)
    297 
    298     # Raise error if `initializer` dtype and `dtype` are not identical.
    299     with self.assertRaisesRegexp(ValueError, "don't match"):
    300       variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
    301 
    302   def testControlDeps(self):
    303     with self.test_session() as sess:
    304       v0 = variable_scope.get_variable(
    305           "v0", [1], initializer=init_ops.constant_initializer(0))
    306       with ops.control_dependencies([v0.value()]):
    307         v1 = variable_scope.get_variable(
    308             "v1", [1], initializer=init_ops.constant_initializer(1))
    309         add = v1 + v0
    310       # v0 should be uninitialized.
    311       with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
    312         sess.run(v0)
    313       # We should be able to initialize and run v1 without initializing
    314       # v0, even if the variable was created with a control dep on v0.
    315       sess.run(v1.initializer)
    316       self.assertEqual(1, sess.run(v1))
    317       # v0 should still be uninitialized.
    318       with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
    319         sess.run(v0)
    320       with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
    321         sess.run(add)
    322       # If we initialize v0 we should be able to run 'add'.
    323       sess.run(v0.initializer)
    324       sess.run(add)
    325 
    326   def testControlFlow(self):
    327     with self.test_session() as sess:
    328       v0 = variable_scope.get_variable(
    329           "v0", [], initializer=init_ops.constant_initializer(0))
    330       var_dict = {}
    331 
    332       # Call get_variable in each of the cond clauses.
    333       def var_in_then_clause():
    334         v1 = variable_scope.get_variable(
    335             "v1", [1], initializer=init_ops.constant_initializer(1))
    336         var_dict["v1"] = v1
    337         return v1 + v0
    338 
    339       def var_in_else_clause():
    340         v2 = variable_scope.get_variable(
    341             "v2", [1], initializer=init_ops.constant_initializer(2))
    342         var_dict["v2"] = v2
    343         return v2 + v0
    344 
    345       add = control_flow_ops.cond(
    346           math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause)
    347       v1 = var_dict["v1"]
    348       v2 = var_dict["v2"]
    349       # We should be able to initialize and run v1 and v2 without initializing
    350       # v0, even if the variable was created with a control dep on v0.
    351       sess.run(v1.initializer)
    352       self.assertEqual([1], sess.run(v1))
    353       sess.run(v2.initializer)
    354       self.assertEqual([2], sess.run(v2))
    355       # v0 should still be uninitialized.
    356       with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
    357         sess.run(v0)
    358       # We should not be able to run 'add' yet.
    359       with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
    360         sess.run(add)
    361       # If we initialize v0 we should be able to run 'add'.
    362       sess.run(v0.initializer)
    363       sess.run(add)
    364 
    365   @test_util.run_in_graph_and_eager_modes()
    366   def testGetVariableScope(self):
    367     # Test the get_variable_scope() function and setting properties of result.
    368     init = init_ops.constant_initializer(0.3)
    369     with variable_scope.variable_scope("bar"):
    370       new_init1 = variable_scope.get_variable_scope().initializer
    371       self.assertEqual(new_init1, None)
    372       # Check that we can set initializer like this.
    373       variable_scope.get_variable_scope().set_initializer(init)
    374       v = variable_scope.get_variable("v", [])
    375       self.evaluate(variables_lib.variables_initializer([v]))
    376       self.assertAllClose(self.evaluate(v.value()), 0.3)
    377       if context.in_graph_mode():
    378         # Check that we can set reuse.
    379         variable_scope.get_variable_scope().reuse_variables()
    380         with self.assertRaises(ValueError):  # Fail, w does not exist yet.
    381           variable_scope.get_variable("w", [1])
    382     # Check that the set initializer goes away.
    383     new_init = variable_scope.get_variable_scope().initializer
    384     self.assertEqual(new_init, None)
    385 
    386   @test_util.run_in_graph_and_eager_modes()
    387   def testVarScope(self):
    388     with variable_scope.variable_scope("tower4") as tower:
    389       self.assertEqual(tower.name, "tower4")
    390       with ops.name_scope("scope") as sc:
    391         self.assertEqual(sc, "tower4/scope/")
    392 
    393     with variable_scope.variable_scope("tower5"):
    394       with variable_scope.variable_scope("bar") as bar:
    395         self.assertEqual(bar.name, "tower5/bar")
    396         with ops.name_scope("scope") as sc:
    397           self.assertEqual(sc, "tower5/bar/scope/")
    398 
    399     with variable_scope.variable_scope("tower6"):
    400       with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
    401         self.assertEqual(tower_shared.name, "tower4")
    402         with ops.name_scope("scope") as sc:
    403           self.assertEqual(sc, "tower6/tower4/scope/")
    404 
    405   @test_util.run_in_graph_and_eager_modes()
    406   def testVarScopeNameScope(self):
    407     with ops.name_scope("testVarScopeNameScope1"):
    408       with variable_scope.variable_scope("tower") as tower:
    409         with ops.name_scope("scope2") as sc2:
    410           self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
    411       if context.in_graph_mode():
    412         with variable_scope.variable_scope(
    413             tower):  # Re-entering acts like another "tower".
    414           with ops.name_scope("scope2") as sc2:
    415             self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/")
    416         with variable_scope.variable_scope(
    417             "tower"):  # Re-entering by string acts the same.
    418           with ops.name_scope("scope2") as sc2:
    419             self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/")
    420 
    421     with ops.name_scope("testVarScopeNameScope2"):
    422       with variable_scope.variable_scope("tower"):
    423         with ops.name_scope("scope2") as sc2:
    424           self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
    425       if context.in_graph_mode():
    426         with variable_scope.variable_scope(tower):
    427           with ops.name_scope("scope2") as sc2:
    428             self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
    429 
    430     root_var_scope = variable_scope.get_variable_scope()
    431     with ops.name_scope("testVarScopeNameScope3"):
    432       with variable_scope.variable_scope(root_var_scope):
    433         with ops.name_scope("scope2") as sc2:
    434           self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
    435 
    436   def testVarScopeOriginalNameScope(self):
    437     with self.test_session():
    438       with ops.name_scope("scope1"):
    439         with variable_scope.variable_scope("tower") as tower:
    440           self.assertEqual(tower.original_name_scope, "scope1/tower/")
    441           with ops.name_scope("scope2") as sc2:
    442             self.assertEqual(sc2, "scope1/tower/scope2/")
    443       with ops.name_scope("scope2"):
    444         with variable_scope.variable_scope(tower) as tower1:
    445           # Re-entering preserves original name scope.
    446           self.assertEqual(tower1.original_name_scope, "scope1/tower/")
    447           with ops.name_scope("foo") as sc2:
    448             self.assertEqual(sc2, "scope2/tower/foo/")
    449         # Test re-entering original name scope.
    450         with ops.name_scope(tower.original_name_scope):
    451           with ops.name_scope("bar") as sc3:
    452             self.assertEqual(sc3, "scope1/tower/bar/")
    453       with ops.name_scope("scope2"):
    454         with variable_scope.variable_scope(tower):
    455           with ops.name_scope(tower.original_name_scope):
    456             with ops.name_scope("bar") as sc3:
    457               self.assertEqual(sc3, "scope1/tower/bar_1/")
    458 
    459   def testVarScopeObjectReuse(self):
    460     with self.test_session():
    461       vs = None
    462       with variable_scope.variable_scope("jump", reuse=True) as scope:
    463         vs = scope
    464 
    465       with variable_scope.variable_scope(vs) as jump:
    466         self.assertTrue(jump.reuse)
    467 
    468       with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
    469         self.assertTrue(jump_reuse.reuse)
    470 
    471       with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
    472         self.assertTrue(jump_no_reuse.reuse)  # Inherited, cannot be undone.
    473 
    474       with variable_scope.variable_scope("jump", reuse=False) as scope:
    475         vs = scope
    476 
    477       with variable_scope.variable_scope(vs) as jump:
    478         self.assertFalse(jump.reuse)
    479 
    480       with variable_scope.variable_scope(vs, reuse=True) as jump_reuse:
    481         self.assertTrue(jump_reuse.reuse)
    482 
    483       with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
    484         self.assertFalse(jump_no_reuse.reuse)
    485 
    486   def testVarScopeGetOrCreateReuse(self):
    487     with self.test_session():
    488       def test_value(value):
    489         x = constant_op.constant(value)
    490         with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
    491                                            reuse=variable_scope.AUTO_REUSE):
    492           _ = state_ops.assign(variable_scope.get_variable("var", []), x)
    493         with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
    494                                            reuse=variable_scope.AUTO_REUSE):
    495           _ = variable_scope.get_variable("var", [])
    496         self.assertEqual(value, x.eval())
    497       test_value(42.)  # Variable is created.
    498       test_value(13.)  # Variable is reused hereafter.
    499       test_value(17.)
    500 
    501   def testVarOpScope(self):
    502     with self.test_session():
    503       with ops.name_scope("testVarOpScope1"):
    504         with variable_scope.variable_scope("tower", "default", []):
    505           self.assertEqual(
    506               variable_scope.get_variable("w", []).name, "tower/w:0")
    507           with ops.name_scope("testVarOpScope2") as sc2:
    508             self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/")
    509         with variable_scope.variable_scope("tower", "default", []):
    510           with self.assertRaises(ValueError):
    511             variable_scope.get_variable("w", [])
    512           with ops.name_scope("testVarOpScope2") as sc2:
    513             self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/")
    514 
    515       with ops.name_scope("testVarOpScope2"):
    516         with variable_scope.variable_scope(None, "default", []):
    517           self.assertEqual(
    518               variable_scope.get_variable("w", []).name, "default/w:0")
    519           with ops.name_scope("testVarOpScope2") as sc2:
    520             self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/")
    521         with variable_scope.variable_scope(None, "default", []):
    522           self.assertEqual(
    523               variable_scope.get_variable("w", []).name, "default_1/w:0")
    524           with ops.name_scope("testVarOpScope2") as sc2:
    525             self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
    526 
    527   def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
    528     with self.test_session():
    529       with variable_scope.variable_scope(None, "defaultScope1"):
    530         with variable_scope.variable_scope(None, "layer"):
    531           self.assertEqual(
    532               variable_scope.get_variable("w", []).name,
    533               "defaultScope1/layer/w:0")
    534       with variable_scope.variable_scope(None, "defaultScope1"):
    535         with variable_scope.variable_scope(None, "layer"):
    536           self.assertEqual(
    537               variable_scope.get_variable("w", []).name,
    538               "defaultScope1_1/layer/w:0")
    539       with variable_scope.variable_scope(None, "defaultScope"):
    540         with variable_scope.variable_scope(None, "layer"):
    541           self.assertEqual(
    542               variable_scope.get_variable("w", []).name,
    543               "defaultScope/layer/w:0")
    544       with variable_scope.variable_scope(None, "defaultScope1"):
    545         with variable_scope.variable_scope(None, "layer"):
    546           self.assertEqual(
    547               variable_scope.get_variable("w", []).name,
    548               "defaultScope1_2/layer/w:0")
    549 
    550   def testVarOpScopeUniqueNamesWithJump(self):
    551     with self.test_session():
    552       with variable_scope.variable_scope("default") as default:
    553         with variable_scope.variable_scope(None, "layer"):
    554           self.assertEqual(
    555               variable_scope.get_variable("w", []).name,
    556               "default/layer/w:0")
    557         with variable_scope.variable_scope(None, "layer"):
    558           self.assertEqual(
    559               variable_scope.get_variable("w", []).name,
    560               "default/layer_1/w:0")
    561         with variable_scope.variable_scope(default):
    562           pass
    563         # No matter the jump in the middle, unique numbering continues.
    564         with variable_scope.variable_scope(None, "layer"):
    565           self.assertEqual(
    566               variable_scope.get_variable("w", []).name,
    567               "default/layer_2/w:0")
    568 
    569   def testVarOpScopeReuse(self):
    570     with self.test_session():
    571       with variable_scope.variable_scope("outer") as outer:
    572         with variable_scope.variable_scope("tower", "default", []):
    573           self.assertEqual(
    574               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
    575           with ops.name_scope("scope2") as sc2:
    576             self.assertEqual(sc2, "outer/tower/scope2/")
    577         with variable_scope.variable_scope(None, "default", []):
    578           self.assertEqual(
    579               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    580           with ops.name_scope("scope2") as sc2:
    581             self.assertEqual(sc2, "outer/default/scope2/")
    582 
    583       with variable_scope.variable_scope(outer, reuse=True) as outer:
    584         with variable_scope.variable_scope("tower", "default", []):
    585           self.assertEqual(
    586               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
    587           with ops.name_scope("scope2") as sc2:
    588             self.assertEqual(sc2, "outer_1/tower/scope2/")
    589         with variable_scope.variable_scope(None, "default", []):
    590           self.assertEqual(
    591               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    592           with ops.name_scope("scope2") as sc2:
    593             self.assertEqual(sc2, "outer_1/default/scope2/")
    594 
    595   def testVarScopeGetVar(self):
    596     with self.test_session():
    597       with variable_scope.variable_scope("root"):
    598         with variable_scope.variable_scope("towerA") as tower_a:
    599           va = variable_scope.get_variable("v", [1])
    600           self.assertEqual(va.name, "root/towerA/v:0")
    601 
    602         with variable_scope.variable_scope(tower_a, reuse=True):
    603           va2 = variable_scope.get_variable("v", [1])
    604           self.assertEqual(va2, va)
    605 
    606         with variable_scope.variable_scope("towerB"):
    607           vb = variable_scope.get_variable("v", [1])
    608           self.assertEqual(vb.name, "root/towerB/v:0")
    609 
    610         with self.assertRaises(ValueError):
    611           with variable_scope.variable_scope("towerA"):
    612             va2 = variable_scope.get_variable("v", [1])
    613 
    614         with variable_scope.variable_scope("towerA", reuse=True):
    615           va2 = variable_scope.get_variable("v", [1])
    616           self.assertEqual(va2, va)
    617 
    618         with variable_scope.variable_scope("foo"):
    619           with variable_scope.variable_scope("bar"):
    620             v = variable_scope.get_variable("v", [1])
    621             self.assertEqual(v.name, "root/foo/bar/v:0")
    622             with variable_scope.variable_scope(tower_a, reuse=True):
    623               va3 = variable_scope.get_variable("v", [1])
    624               self.assertEqual(va, va3)
    625 
    626         with self.assertRaises(ValueError):
    627           with variable_scope.variable_scope(tower_a, reuse=True):
    628             with variable_scope.variable_scope("baz"):
    629               variable_scope.get_variable("v", [1])
    630 
    631         with self.assertRaises(ValueError) as exc:
    632           with variable_scope.variable_scope(tower_a, reuse=True):
    633             variable_scope.get_variable("v", [2])  # Different shape.
    634         self.assertEqual("shape" in str(exc.exception), True)
    635 
    636         with self.assertRaises(ValueError) as exc:
    637           with variable_scope.variable_scope(tower_a, reuse=True):
    638             variable_scope.get_variable("v", [1], dtype=dtypes.int32)
    639         self.assertEqual("dtype" in str(exc.exception), True)
    640 
    641   def testVarScopeOuterScope(self):
    642     with self.test_session():
    643       with variable_scope.variable_scope("outer") as outer:
    644         pass
    645       with variable_scope.variable_scope(outer):
    646         self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
    647         with ops.name_scope("scope2") as sc2:
    648           self.assertEqual(sc2, "outer_1/scope2/")
    649         with variable_scope.variable_scope("default"):
    650           self.assertEqual(
    651               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    652           with ops.name_scope("scope2") as sc2:
    653             self.assertEqual(sc2, "outer_1/default/scope2/")
    654 
    655       with variable_scope.variable_scope(outer, reuse=True):
    656         self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
    657         with ops.name_scope("scope2") as sc2:
    658           self.assertEqual(sc2, "outer_2/scope2/")
    659         with variable_scope.variable_scope("default", reuse=True):
    660           self.assertEqual(
    661               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    662           with ops.name_scope("scope2") as sc2:
    663             self.assertEqual(sc2, "outer_2/default/scope2/")
    664 
    665   def testVarScopeNestedOuterScope(self):
    666     with self.test_session():
    667       with variable_scope.variable_scope("outer") as outer:
    668         with variable_scope.variable_scope(outer):
    669           self.assertEqual(
    670               variable_scope.get_variable("w", []).name, "outer/w:0")
    671           with ops.name_scope("scope2") as sc2:
    672             self.assertEqual(sc2, "outer/outer/scope2/")
    673         with variable_scope.variable_scope("default"):
    674           self.assertEqual(
    675               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    676           with ops.name_scope("scope2") as sc2:
    677             self.assertEqual(sc2, "outer/default/scope2/")
    678 
    679         with variable_scope.variable_scope(outer, reuse=True):
    680           self.assertEqual(
    681               variable_scope.get_variable("w", []).name, "outer/w:0")
    682           with ops.name_scope("scope2") as sc2:
    683             self.assertEqual(sc2, "outer/outer_1/scope2/")
    684         with variable_scope.variable_scope("default", reuse=True):
    685           self.assertEqual(
    686               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    687           with ops.name_scope("scope2") as sc2:
    688             self.assertEqual(sc2, "outer/default_1/scope2/")
    689 
    690   def testVarOpScopeReuseParam(self):
    691     with self.test_session():
    692       with variable_scope.variable_scope("outer") as outer:
    693         with variable_scope.variable_scope("tower", "default", []):
    694           self.assertEqual(
    695               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
    696           with ops.name_scope("scope2") as sc2:
    697             self.assertEqual(sc2, "outer/tower/scope2/")
    698         with variable_scope.variable_scope(None, "default", []):
    699           self.assertEqual(
    700               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    701           with ops.name_scope("scope2") as sc2:
    702             self.assertEqual(sc2, "outer/default/scope2/")
    703 
    704       with variable_scope.variable_scope(outer) as outer:
    705         with variable_scope.variable_scope("tower", "default", reuse=True):
    706           self.assertEqual(
    707               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
    708           with ops.name_scope("scope2") as sc2:
    709             self.assertEqual(sc2, "outer_1/tower/scope2/")
    710         outer.reuse_variables()
    711         with variable_scope.variable_scope(None, "default", []):
    712           self.assertEqual(
    713               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    714           with ops.name_scope("scope2") as sc2:
    715             self.assertEqual(sc2, "outer_1/default/scope2/")
    716 
    717   def testVarOpScopeReuseError(self):
    718     with self.test_session():
    719       with self.assertRaises(ValueError):
    720         with variable_scope.variable_scope(None, "default", reuse=True):
    721           self.assertEqual(
    722               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
    723 
    724   def testVarOpScopeOuterScope(self):
    725     with self.test_session():
    726       with variable_scope.variable_scope("outer") as outer:
    727         pass
    728       with variable_scope.variable_scope(outer, "default", []):
    729         self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
    730         with ops.name_scope("scope2") as sc2:
    731           self.assertEqual(sc2, "outer_1/scope2/")
    732         with variable_scope.variable_scope(None, "default", []):
    733           self.assertEqual(
    734               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    735           with ops.name_scope("scope2") as sc2:
    736             self.assertEqual(sc2, "outer_1/default/scope2/")
    737 
    738       with variable_scope.variable_scope(outer, "default", reuse=True):
    739         self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
    740         with ops.name_scope("scope2") as sc2:
    741           self.assertEqual(sc2, "outer_2/scope2/")
    742         outer.reuse_variables()
    743         with variable_scope.variable_scope(None, "default", []):
    744           self.assertEqual(
    745               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    746           with ops.name_scope("scope2") as sc2:
    747             self.assertEqual(sc2, "outer_2/default/scope2/")
    748 
    749   def testVarOpScopeNestedOuterScope(self):
    750     with self.test_session():
    751       with variable_scope.variable_scope("outer") as outer:
    752         with variable_scope.variable_scope(outer, "default", []):
    753           self.assertEqual(
    754               variable_scope.get_variable("w", []).name, "outer/w:0")
    755           with ops.name_scope("scope2") as sc2:
    756             self.assertEqual(sc2, "outer/outer/scope2/")
    757         with variable_scope.variable_scope(None, "default", []):
    758           self.assertEqual(
    759               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    760           with ops.name_scope("scope2") as sc2:
    761             self.assertEqual(sc2, "outer/default/scope2/")
    762 
    763       with variable_scope.variable_scope(outer, "default", reuse=True):
    764         self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
    765         with ops.name_scope("scope2") as sc2:
    766           self.assertEqual(sc2, "outer_1/scope2/")
    767         with variable_scope.variable_scope(None, "default", []):
    768           self.assertEqual(
    769               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    770           with ops.name_scope("scope2") as sc2:
    771             self.assertEqual(sc2, "outer_1/default/scope2/")
    772 
    773   def testBasicWhenAuxiliaryNameScopeIsFalse(self):
    774     with self.test_session():
    775       with variable_scope.variable_scope(
    776           "scope", auxiliary_name_scope=False) as scope:
    777         self.assertEqual(scope.original_name_scope, "")
    778         self.assertEqual(variable_scope.get_variable("w", []).name, "scope/w:0")
    779         self.assertEqual(constant_op.constant([], name="c").name, "c:0")
    780       with variable_scope.variable_scope(scope, auxiliary_name_scope=False):
    781         self.assertEqual(scope.original_name_scope, "")
    782         self.assertEqual(
    783             variable_scope.get_variable("w1", []).name, "scope/w1:0")
    784         self.assertEqual(constant_op.constant([], name="c1").name, "c1:0")
    785       # Recheck: new name scope is NOT created before
    786       with ops.name_scope("scope"):
    787         self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0")
    788 
    789       with variable_scope.variable_scope("outer"):
    790         with variable_scope.variable_scope(
    791             "inner", auxiliary_name_scope=False) as inner:
    792           self.assertEqual(inner.original_name_scope, "outer/")
    793           self.assertEqual(
    794               variable_scope.get_variable("w", []).name, "outer/inner/w:0")
    795           self.assertEqual(constant_op.constant([], name="c").name, "outer/c:0")
    796         with variable_scope.variable_scope(
    797             inner, auxiliary_name_scope=False) as inner1:
    798           self.assertEqual(inner1.original_name_scope, "outer/")
    799           self.assertEqual(
    800               variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
    801           self.assertEqual(
    802               constant_op.constant([], name="c1").name, "outer/c1:0")
    803         # Recheck: new name scope is NOT created before
    804         with ops.name_scope("inner"):
    805           self.assertEqual(
    806               constant_op.constant([], name="c").name, "outer/inner/c:0")
    807 
    808   def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
    809     with self.test_session():
    810       with variable_scope.variable_scope(
    811           None, default_name="default", auxiliary_name_scope=False) as scope:
    812         self.assertEqual(scope.original_name_scope, "")
    813         self.assertEqual(
    814             variable_scope.get_variable("w", []).name, "default/w:0")
    815         self.assertEqual(constant_op.constant([], name="c").name, "c:0")
    816       # Recheck: new name scope is NOT created before
    817       with ops.name_scope("default"):
    818         self.assertEqual(constant_op.constant([], name="c").name, "default/c:0")
    819 
    820       with variable_scope.variable_scope("outer"):
    821         with variable_scope.variable_scope(
    822             None, default_name="default", auxiliary_name_scope=False) as inner:
    823           self.assertEqual(inner.original_name_scope, "outer/")
    824           self.assertEqual(
    825               variable_scope.get_variable("w", []).name, "outer/default/w:0")
    826           self.assertEqual(constant_op.constant([], name="c").name, "outer/c:0")
    827         # Recheck: new name scope is NOT created before
    828         with ops.name_scope("default"):
    829           self.assertEqual(
    830               constant_op.constant([], name="c").name, "outer/default/c:0")
    831 
    832   def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
    833     with self.test_session():
    834       root_scope = variable_scope.get_variable_scope()
    835       with variable_scope.variable_scope(
    836           root_scope, auxiliary_name_scope=False) as scope:
    837         self.assertEqual(scope.original_name_scope, "")
    838         self.assertEqual(variable_scope.get_variable("w", []).name, "w:0")
    839         self.assertEqual(constant_op.constant([], name="c").name, "c:0")
    840 
    841       with variable_scope.variable_scope("outer"):
    842         with variable_scope.variable_scope(
    843             root_scope, auxiliary_name_scope=False) as inner:
    844           self.assertEqual(inner.original_name_scope, "")
    845           self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0")
    846           self.assertEqual(
    847               constant_op.constant([], name="c1").name, "outer/c1:0")
    848 
    849   def testAuxiliaryNameScopeIsInvalid(self):
    850     with self.test_session():
    851       with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
    852         with variable_scope.variable_scope(
    853             None, default_name="scope", auxiliary_name_scope="invalid"):
    854           pass
    855 
    856       with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
    857         with variable_scope.variable_scope(
    858             "scope", auxiliary_name_scope="invalid"):
    859           pass
    860 
    861       with variable_scope.variable_scope("scope") as scope:
    862         pass
    863       with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
    864         with variable_scope.variable_scope(
    865             scope, auxiliary_name_scope="invalid"):
    866           pass
    867 
    868   def testReuseScopeWithoutNameScopeCollision(self):
    869     # Github issue: #13429
    870     with self.test_session():
    871       with variable_scope.variable_scope("outer"):
    872         with variable_scope.variable_scope("inner") as inner:
    873           pass
    874 
    875       with variable_scope.variable_scope(
    876           inner, auxiliary_name_scope=False) as scope:
    877         with ops.name_scope(scope.original_name_scope):
    878           self.assertEqual(
    879               variable_scope.get_variable("w", []).name, "outer/inner/w:0")
    880           self.assertEqual(
    881               constant_op.constant([], name="c").name, "outer/inner/c:0")
    882         with ops.name_scope("inner"):
    883           self.assertEqual(constant_op.constant([], name="c").name, "inner/c:0")
    884 
    885       with variable_scope.variable_scope("another"):
    886         with variable_scope.variable_scope(
    887             inner, auxiliary_name_scope=False) as scope1:
    888           with ops.name_scope(scope1.original_name_scope):
    889             self.assertEqual(
    890                 variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
    891             self.assertEqual(
    892                 constant_op.constant([], name="c1").name, "outer/inner/c1:0")
    893           with ops.name_scope("inner"):
    894             self.assertEqual(
    895                 constant_op.constant([], name="c").name, "another/inner/c:0")
    896 
    897   @test_util.run_in_graph_and_eager_modes()
    898   def testGetLocalVar(self):
    899     # Check that local variable respects naming.
    900     with variable_scope.variable_scope("outer") as outer:
    901       with variable_scope.variable_scope(outer, "default", []):
    902         local_var = variable_scope.get_local_variable(
    903             "w", [], collections=["foo"])
    904         self.assertEqual(local_var.name, "outer/w:0")
    905 
    906     # Since variable is local, it should be in the local variable collection
    907     # but not the trainable collection.
    908     if context.in_graph_mode():
    909       self.assertIn(local_var,
    910                     ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
    911       self.assertIn(local_var, ops.get_collection("foo"))
    912       self.assertNotIn(local_var,
    913                        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
    914 
    915     # Check that local variable respects `reuse`.
    916     if context.in_graph_mode():
    917       with variable_scope.variable_scope(outer, "default", reuse=True):
    918         self.assertEqual(
    919             variable_scope.get_local_variable("w", []).name, "outer/w:0")
    920 
    921   def testGetVarWithDevice(self):
    922     g = ops.Graph()
    923     varname_type = []
    924 
    925     def device_func(op):
    926       if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
    927         varname_type.append((op.name, op.get_attr("dtype")))
    928       return "/device:GPU:0"
    929 
    930     with g.as_default():
    931       with ops.device(device_func):
    932         _ = variable_scope.get_variable("x", (100, 200))
    933         _ = variable_scope.get_variable(
    934             "y", dtype=dtypes.int64, initializer=numpy.arange(73))
    935     self.assertEqual(varname_type[0], ("x", dtypes.float32))
    936     self.assertEqual(varname_type[1], ("y", dtypes.int64))
    937 
    938   def testGetCollection(self):
    939     with self.test_session():
    940       _ = variable_scope.get_variable("testGetCollection_a", [])
    941       _ = variable_scope.get_variable("testGetCollection_b", [],
    942                                       trainable=False)
    943       with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
    944         _ = variable_scope.get_variable("testGetCollection_a", [])
    945         _ = variable_scope.get_variable("testGetCollection_b", [],
    946                                         trainable=False)
    947         self.assertEqual([
    948             v.name
    949             for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    950         ], ["testGetCollection_foo_/testGetCollection_a:0"])
    951         self.assertEqual([
    952             v.name
    953             for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    954         ], [
    955             "testGetCollection_foo_/testGetCollection_a:0",
    956             "testGetCollection_foo_/testGetCollection_b:0"
    957         ])
    958       with variable_scope.variable_scope("testGetCollection_foo") as scope2:
    959         _ = variable_scope.get_variable("testGetCollection_a", [])
    960         _ = variable_scope.get_variable("testGetCollection_b", [],
    961                                         trainable=False)
    962         self.assertEqual([
    963             v.name
    964             for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    965         ], ["testGetCollection_foo/testGetCollection_a:0"])
    966         self.assertEqual([
    967             v.name
    968             for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    969         ], [
    970             "testGetCollection_foo/testGetCollection_a:0",
    971             "testGetCollection_foo/testGetCollection_b:0"
    972         ])
    973       scope = variable_scope.get_variable_scope()
    974       self.assertEqual([
    975           v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    976       ], [
    977           "testGetCollection_a:0", "testGetCollection_b:0",
    978           "testGetCollection_foo_/testGetCollection_a:0",
    979           "testGetCollection_foo_/testGetCollection_b:0",
    980           "testGetCollection_foo/testGetCollection_a:0",
    981           "testGetCollection_foo/testGetCollection_b:0"
    982       ])
    983       self.assertEqual([
    984           v.name
    985           for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    986       ], [
    987           "testGetCollection_a:0",
    988           "testGetCollection_foo_/testGetCollection_a:0",
    989           "testGetCollection_foo/testGetCollection_a:0"
    990       ])
    991 
    992   def testGetTrainableVariables(self):
    993     with self.test_session():
    994       _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
    995       with variable_scope.variable_scope(
    996           "testGetTrainableVariables_foo") as scope:
    997         _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
    998         _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
    999                                         trainable=False)
   1000         self.assertEqual([v.name
   1001                           for v in scope.trainable_variables()],
   1002                          ["testGetTrainableVariables_foo/"
   1003                           "testGetTrainableVariables_b:0"])
   1004 
   1005   def testGetGlobalVariables(self):
   1006     with self.test_session():
   1007       _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
   1008       with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
   1009         _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
   1010         self.assertEqual([v.name
   1011                           for v in scope.global_variables()],
   1012                          ["testGetGlobalVariables_foo/"
   1013                           "testGetGlobalVariables_b:0"])
   1014 
   1015   def testGetLocalVariables(self):
   1016     with self.test_session():
   1017       _ = variable_scope.get_variable(
   1018           "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
   1019       with variable_scope.variable_scope("foo") as scope:
   1020         _ = variable_scope.get_variable(
   1021             "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
   1022         _ = variable_scope.get_variable(
   1023             "c", [])
   1024         self.assertEqual([v.name
   1025                           for v in scope.local_variables()], ["foo/b:0"])
   1026 
   1027   def testGetVariableWithRefDtype(self):
   1028     v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
   1029     # Ensure it is possible to do get_variable with a _ref dtype passed in.
   1030     _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)
   1031 
   1032   def testTwoGraphs(self):
   1033 
   1034     def f():
   1035       g1 = ops.Graph()
   1036       g2 = ops.Graph()
   1037       with g1.as_default():
   1038         with g2.as_default():
   1039           with variable_scope.variable_scope("_"):
   1040             pass
   1041 
   1042     self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)
   1043 
   1044 
   1045 def axis0_into1_partitioner(shape=None, **unused_kwargs):
   1046   part = [1] * len(shape)
   1047   return part
   1048 
   1049 
   1050 def axis0_into2_partitioner(shape=None, **unused_kwargs):
   1051   part = [1] * len(shape)
   1052   part[0] = 2
   1053   return part
   1054 
   1055 
   1056 def axis0_into3_partitioner(shape=None, **unused_kwargs):
   1057   part = [1] * len(shape)
   1058   part[0] = 3
   1059   return part
   1060 
   1061 
   1062 class VariableScopeWithPartitioningTest(test.TestCase):
   1063 
   1064   def testResultNameMatchesRequested(self):
   1065     with variable_scope.variable_scope(
   1066         "scope0", partitioner=axis0_into2_partitioner):
   1067       v = variable_scope.get_variable("name0", shape=(3, 1, 1))
   1068       self.assertEqual(v.name, "scope0/name0")
   1069       v_concat = v.as_tensor()
   1070       self.assertEqual(v_concat.name, "scope0/name0:0")
   1071       variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
   1072       self.assertIn("scope0/name0/part_0:0", [x.name for x in variables])
   1073       self.assertIn("scope0/name0/part_1:0", [x.name for x in variables])
   1074       self.assertNotIn("scope0/name0/part_2:0", [x.name for x in variables])
   1075 
   1076   def testBreaksIfPartitioningChanges(self):
   1077     with variable_scope.variable_scope(
   1078         "scope0", partitioner=axis0_into2_partitioner):
   1079       variable_scope.get_variable("name0", shape=(3, 1, 1))
   1080 
   1081     with variable_scope.variable_scope(
   1082         "scope0", partitioner=axis0_into3_partitioner, reuse=True):
   1083       with self.assertRaisesRegexp(
   1084           ValueError,
   1085           "Trying to reuse partitioned variable .* but specified partitions .* "
   1086           "and found partitions .*"):
   1087         variable_scope.get_variable("name0", shape=(3, 1, 1))
   1088 
   1089     with variable_scope.variable_scope(
   1090         "scope0", partitioner=axis0_into1_partitioner, reuse=True):
   1091       with self.assertRaisesRegexp(
   1092           ValueError,
   1093           "Trying to reuse partitioned variable .* but specified partitions .* "
   1094           "and found partitions .*"):
   1095         variable_scope.get_variable("name0", shape=(3, 1, 1))
   1096 
   1097   def testReturnsExistingConcatenatedValueIfReuse(self):
   1098     with variable_scope.variable_scope(
   1099         "scope0", partitioner=axis0_into2_partitioner):
   1100       v_concat = variable_scope.get_variable("name0", shape=(3, 1, 1))
   1101       variable_scope.get_variable_scope().reuse_variables()
   1102       v_concat_2 = variable_scope.get_variable("name0", shape=(3, 1, 1))
   1103       self.assertEqual(v_concat, v_concat_2)
   1104 
   1105   def testAllowsReuseWithoutPartitioner(self):
   1106     with variable_scope.variable_scope(
   1107         "scope0", partitioner=axis0_into2_partitioner):
   1108       v = variable_scope.get_variable("name0", shape=(3, 1, 1))
   1109     with variable_scope.variable_scope("scope0", reuse=True):
   1110       v_reused = variable_scope.get_variable("name0")
   1111     self.assertEqual(v, v_reused)
   1112 
   1113   def testPropagatePartitionerOnReopening(self):
   1114     with variable_scope.variable_scope(
   1115         "scope0", partitioner=axis0_into2_partitioner) as vs:
   1116       self.assertEqual(axis0_into2_partitioner, vs.partitioner)
   1117       with variable_scope.variable_scope(vs) as vs1:
   1118         self.assertEqual(axis0_into2_partitioner, vs1.partitioner)
   1119 
   1120   def testScalarIgnoresPartitioner(self):
   1121     with variable_scope.variable_scope(
   1122         "scope0", partitioner=axis0_into2_partitioner):
   1123       v = variable_scope.get_variable("name0", shape=())
   1124       self.assertEqual(v.name, "scope0/name0:0")
   1125       variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
   1126       self.assertIn("scope0/name0:0", [x.name for x in variables])
   1127 
   1128   def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource):
   1129 
   1130     def _part_axis_0(**unused_kwargs):
   1131       return (2, 1, 1)
   1132 
   1133     def _part_axis_1(**unused_kwargs):
   1134       return (1, 2, 1)
   1135 
   1136     with variable_scope.variable_scope("root", use_resource=use_resource):
   1137       v0 = variable_scope.get_variable(
   1138           "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
   1139       v1 = variable_scope.get_variable(
   1140           "n1", shape=(2, 2, 2), partitioner=_part_axis_1)
   1141 
   1142     self.assertEqual(v0.get_shape(), (2, 2, 2))
   1143     self.assertEqual(v1.get_shape(), (2, 2, 2))
   1144 
   1145     n0_0 = list(v0)[0]
   1146     n0_1 = list(v0)[1]
   1147     self.assertEqual(n0_0.get_shape(), (1, 2, 2))
   1148     self.assertEqual(n0_1.get_shape(), (1, 2, 2))
   1149 
   1150     n1_0 = list(v1)[0]
   1151     n1_1 = list(v1)[1]
   1152     self.assertEqual(n1_0.get_shape(), (2, 1, 2))
   1153     self.assertEqual(n1_1.get_shape(), (2, 1, 2))
   1154 
   1155   def testPartitionConcatenatesAlongCorrectAxis(self):
   1156     self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False)
   1157 
   1158   def testPartitionConcatenatesAlongCorrectAxisResource(self):
   1159     self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
   1160 
   1161 
   1162 class VariableScopeWithCustomGetterTest(test.TestCase):
   1163 
   1164   def testNonCallableGetterFails(self):
   1165     with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
   1166       with variable_scope.variable_scope("scope0", custom_getter=3):
   1167         variable_scope.get_variable("name0")
   1168     with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
   1169       variable_scope.get_variable("name0", custom_getter=3)
   1170 
   1171   def testNoSideEffectsWithIdentityCustomGetter(self):
   1172     called = [0]
   1173 
   1174     def custom_getter(getter, *args, **kwargs):
   1175       called[0] += 1
   1176       return getter(*args, **kwargs)
   1177 
   1178     with variable_scope.variable_scope(
   1179         "scope", custom_getter=custom_getter) as scope:
   1180       v = variable_scope.get_variable("v", [1])
   1181     with variable_scope.variable_scope(scope, reuse=True):
   1182       v2 = variable_scope.get_variable("v", [1])
   1183     with variable_scope.variable_scope("new_scope") as new_scope:
   1184       v3 = variable_scope.get_variable("v3", [1])
   1185     with variable_scope.variable_scope(
   1186         new_scope, reuse=True, custom_getter=custom_getter):
   1187       v4 = variable_scope.get_variable("v3", [1])
   1188 
   1189     self.assertEqual(v, v2)
   1190     self.assertEqual(v3, v4)
   1191     self.assertEqual(3, called[0])  # skipped one in the first new_scope
   1192 
   1193   def testCustomGetterWithReuse(self):
   1194     # Custom getter can choose to behave differently on reused variables.
   1195     def custom_getter(getter, *args, **kwargs):
   1196       var = getter(*args, **kwargs)
   1197       if kwargs["reuse"]:
   1198         # This can be used, e.g., for changing the caching device if needed.
   1199         return array_ops.identity(var, name="reused")
   1200       else:
   1201         return array_ops.identity(var, name="not_reused")
   1202 
   1203     with variable_scope.variable_scope(
   1204         "scope", custom_getter=custom_getter) as scope:
   1205       v = variable_scope.get_variable("v", [1])
   1206     with variable_scope.variable_scope(scope, reuse=True):
   1207       v2 = variable_scope.get_variable("v", [1])
   1208 
   1209     self.assertEqual(v.name, "not_reused:0")
   1210     self.assertEqual(v2.name, "reused:0")
   1211 
   1212   def testGetterThatCreatesTwoVariablesAndSumsThem(self):
   1213 
   1214     def custom_getter(getter, name, *args, **kwargs):
   1215       g_0 = getter("%s/0" % name, *args, **kwargs)
   1216       g_1 = getter("%s/1" % name, *args, **kwargs)
   1217       with ops.name_scope("custom_getter"):
   1218         return g_0 + g_1
   1219 
   1220     with variable_scope.variable_scope("scope", custom_getter=custom_getter):
   1221       v = variable_scope.get_variable("v", [1, 2, 3])
   1222 
   1223     self.assertEqual([1, 2, 3], v.get_shape())
   1224     true_vars = variables_lib.trainable_variables()
   1225     self.assertEqual(2, len(true_vars))
   1226     self.assertEqual("scope/v/0:0", true_vars[0].name)
   1227     self.assertEqual("scope/v/1:0", true_vars[1].name)
   1228     self.assertEqual("custom_getter/add:0", v.name)
   1229     with self.test_session() as sess:
   1230       variables_lib.global_variables_initializer().run()
   1231       np_vars, np_v = sess.run([true_vars, v])
   1232       self.assertAllClose(np_v, sum(np_vars))
   1233 
   1234   def testNestedCustomGetters(self):
   1235 
   1236     def sum_getter(getter, name, *args, **kwargs):
   1237       g_0 = getter("%s/sum_0" % name, *args, **kwargs)
   1238       g_1 = getter("%s/sum_1" % name, *args, **kwargs)
   1239       with ops.name_scope("sum_getter"):
   1240         return g_0 + g_1
   1241 
   1242     def prod_getter(getter, name, *args, **kwargs):
   1243       g_0 = getter("%s/prod_0" % name, *args, **kwargs)
   1244       g_1 = getter("%s/prod_1" % name, *args, **kwargs)
   1245       with ops.name_scope("prod_getter"):
   1246         return g_0 * g_1
   1247 
   1248     with variable_scope.variable_scope(
   1249         "prod_scope", custom_getter=prod_getter):
   1250       with variable_scope.variable_scope(
   1251           "sum_scope", custom_getter=sum_getter):
   1252         with variable_scope.variable_scope(
   1253             "inner_sum_scope", custom_getter=sum_getter):
   1254           # take sums of sums of products
   1255           v = variable_scope.get_variable("v", [1, 2, 3])
   1256 
   1257     self.assertEqual([1, 2, 3], v.get_shape())
   1258     true_vars = variables_lib.trainable_variables()
   1259     self.assertEqual(8, len(true_vars))
   1260     template = (
   1261         "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
   1262     self.assertEqual(template % (0, 0, 0), true_vars[0].name)
   1263     self.assertEqual(template % (0, 0, 1), true_vars[1].name)
   1264     self.assertEqual(template % (0, 1, 0), true_vars[2].name)
   1265     self.assertEqual(template % (0, 1, 1), true_vars[3].name)
   1266     self.assertEqual(template % (1, 0, 0), true_vars[4].name)
   1267     self.assertEqual(template % (1, 0, 1), true_vars[5].name)
   1268     self.assertEqual(template % (1, 1, 0), true_vars[6].name)
   1269     self.assertEqual(template % (1, 1, 1), true_vars[7].name)
   1270 
   1271     with self.test_session() as sess:
   1272       variables_lib.global_variables_initializer().run()
   1273       np_vars, np_v = sess.run([true_vars, v])
   1274       # take products of sums of products
   1275       self.assertAllClose(
   1276           np_v,
   1277           (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3]))
   1278            + ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))
   1279 
   1280   def testVariableCreator(self):
   1281 
   1282     variable_names = []
   1283 
   1284     def creator_a(next_creator, **kwargs):
   1285       variable_names.append(kwargs.get("name", ""))
   1286       return next_creator(**kwargs)
   1287 
   1288     def creator_b(next_creator, **kwargs):
   1289       kwargs["name"] = "forced_name"
   1290       return next_creator(**kwargs)
   1291 
   1292     with variable_scope.variable_creator_scope(creator_a):
   1293       with variable_scope.variable_creator_scope(creator_b):
   1294         variable_scope.variable(1.0, name="one_name")
   1295 
   1296     self.assertAllEqual(variable_names, ["forced_name"])
   1297 
   1298 
   1299 class PartitionInfoTest(test.TestCase):
   1300 
   1301   def testConstructorChecks(self):
   1302     # Invalid arg types.
   1303     with self.assertRaises(TypeError):
   1304       variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1])
   1305     with self.assertRaises(TypeError):
   1306       variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None)
   1307     with self.assertRaises(TypeError):
   1308       variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1])
   1309     with self.assertRaises(TypeError):
   1310       variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo")
   1311 
   1312     # full_shape and var_offset must have same length.
   1313     with self.assertRaises(ValueError):
   1314       variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0])
   1315     # Offset must always be less than shape.
   1316     with self.assertRaises(ValueError):
   1317       variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1])
   1318 
   1319   def testSingleOffset(self):
   1320     partition_info = variable_scope._PartitionInfo(
   1321         full_shape=[9, 3], var_offset=[4, 0])
   1322     self.assertEqual(4, partition_info.single_offset([1, 3]))
   1323 
   1324     # Tests when the variable isn't partitioned at all.
   1325     partition_info = variable_scope._PartitionInfo(
   1326         full_shape=[9, 3], var_offset=[0, 0])
   1327     self.assertEqual(0, partition_info.single_offset([9, 3]))
   1328 
   1329   def testSingleSliceDim(self):
   1330     partition_info = variable_scope._PartitionInfo(
   1331         full_shape=[9, 3], var_offset=[4, 0])
   1332     # Invalid shape.
   1333     with self.assertRaises(TypeError):
   1334       partition_info.single_slice_dim(None)
   1335 
   1336     # Rank of shape differs from full_shape.
   1337     with self.assertRaises(ValueError):
   1338       partition_info.single_slice_dim([1, 2, 3])
   1339 
   1340     # Shape is too large given var_offset (4+6 > 9).
   1341     with self.assertRaises(ValueError):
   1342       partition_info.single_slice_dim([6, 3])
   1343 
   1344     # Multiple possible slice dim from shape.
   1345     with self.assertRaises(ValueError):
   1346       partition_info.single_slice_dim([1, 1])
   1347 
   1348     partition_info = variable_scope._PartitionInfo(
   1349         full_shape=[9, 3], var_offset=[0, 0])
   1350     self.assertEqual(1, partition_info.single_slice_dim([9, 2]))
   1351     partition_info = variable_scope._PartitionInfo(
   1352         full_shape=[9, 3], var_offset=[4, 0])
   1353     self.assertEqual(0, partition_info.single_slice_dim([2, 3]))
   1354 
   1355 
# Call `test.main()` only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
  test.main()
   1358