1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for variable store.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import gc 22 23 import numpy 24 25 from tensorflow.python.eager import context 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import errors 29 from tensorflow.python.framework import ops 30 from tensorflow.python.framework import test_util 31 from tensorflow.python.ops import array_ops 32 from tensorflow.python.ops import control_flow_ops 33 from tensorflow.python.ops import init_ops 34 from tensorflow.python.ops import math_ops 35 from tensorflow.python.ops import resource_variable_ops 36 from tensorflow.python.ops import state_ops 37 from tensorflow.python.ops import variable_scope 38 from tensorflow.python.ops import variables as variables_lib 39 from tensorflow.python.platform import test 40 41 42 class VariableScopeTest(test.TestCase): 43 44 def tearDown(self): 45 gc.collect() 46 # This will only contain uncollectable garbage, i.e. reference cycles 47 # involving objects with __del__ defined. 
48 self.assertEqual(0, len(gc.garbage)) 49 50 def testGetVar(self): 51 vs = variable_scope._get_default_variable_store() 52 v = vs.get_variable("v", [1]) 53 v1 = vs.get_variable("v", [1]) 54 self.assertEqual(v, v1) 55 56 @test_util.run_in_graph_and_eager_modes() 57 def testResource(self): 58 vs = variable_scope._get_default_variable_store() 59 v1 = vs.get_variable("v", [1], use_resource=True) 60 self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable)) 61 62 def testNameExists(self): 63 vs = variable_scope._get_default_variable_store() 64 # No check by default, so we can both create and get existing names. 65 v = vs.get_variable("v", [1]) 66 v1 = vs.get_variable("v", [1]) 67 self.assertEqual(v, v1) 68 69 # When reuse is False, we fail when variables are already there. 70 vs.get_variable("w", [1], reuse=False) # That's ok. 71 with self.assertRaises(ValueError): 72 vs.get_variable("v", [1], reuse=False) # That fails. 73 # When reuse is True, we fail when variables are new. 74 vs.get_variable("v", [1], reuse=True) # That's ok. 75 with self.assertRaises(ValueError): 76 vs.get_variable("u", [1], reuse=True) # That fails. 
77 78 def testNamelessStore(self): 79 vs = variable_scope._get_default_variable_store() 80 vs.get_variable("v1", [2]) 81 vs.get_variable("v2", [2]) 82 expected_names = ["%s:0" % name for name in ["v1", "v2"]] 83 self.assertEqual( 84 set(expected_names), set([v.name for v in vs._vars.values()])) 85 86 @test_util.run_in_graph_and_eager_modes() 87 def testVarScopeInitializer(self): 88 init = init_ops.constant_initializer(0.3) 89 with variable_scope.variable_scope("tower0") as tower: 90 with variable_scope.variable_scope("foo", initializer=init): 91 v = variable_scope.get_variable("v", []) 92 self.evaluate(variables_lib.variables_initializer([v])) 93 self.assertAllClose(self.evaluate(v.value()), 0.3) 94 with variable_scope.variable_scope(tower, initializer=init): 95 w = variable_scope.get_variable("w", []) 96 self.evaluate(variables_lib.variables_initializer([w])) 97 self.assertAllClose(self.evaluate(w.value()), 0.3) 98 99 @test_util.run_in_graph_and_eager_modes() 100 def testVarScopeConstraint(self): 101 constraint = lambda x: 0. 
* x 102 with variable_scope.variable_scope("tower1") as tower: 103 with variable_scope.variable_scope("foo", constraint=constraint): 104 v = variable_scope.get_variable("v", []) 105 self.assertEqual(v.constraint, constraint) 106 with variable_scope.variable_scope(tower, constraint=constraint): 107 w = variable_scope.get_variable("w", []) 108 self.assertEqual(w.constraint, constraint) 109 110 @test_util.run_in_graph_and_eager_modes() 111 def testVarScopeDType(self): 112 with variable_scope.variable_scope("tower2") as tower: 113 with variable_scope.variable_scope("foo", dtype=dtypes.float16): 114 v = variable_scope.get_variable("v", []) 115 self.assertEqual(v.dtype.base_dtype, dtypes.float16) 116 with variable_scope.variable_scope(tower, dtype=dtypes.float16): 117 w = variable_scope.get_variable("w", []) 118 self.assertEqual(w.dtype.base_dtype, dtypes.float16) 119 120 def testEagerVariableStore(self): 121 with context.eager_mode(): 122 store = variable_scope.EagerVariableStore() 123 with store.as_default(): 124 v = variable_scope.get_variable("v", shape=(), trainable=True) 125 w = variable_scope.get_variable("w", shape=(), trainable=False) 126 127 self.assertTrue(v in store.variables()) 128 self.assertTrue(w in store.variables()) 129 self.assertTrue(v in store.trainable_variables()) 130 self.assertFalse(w in store.trainable_variables()) 131 self.assertFalse(v in store.non_trainable_variables()) 132 self.assertTrue(w in store.non_trainable_variables()) 133 134 # Test copying. 
135 new_store = store.copy() 136 with new_store.as_default(): 137 new_v = variable_scope.get_variable("v") 138 new_w = variable_scope.get_variable("w") 139 self.assertEqual(new_v.numpy(), v.numpy()) 140 self.assertEqual(new_w.numpy(), w.numpy()) 141 self.assertTrue(new_v in new_store.variables()) 142 self.assertTrue(new_w in new_store.variables()) 143 self.assertTrue(new_v in new_store.trainable_variables()) 144 self.assertFalse(new_w in new_store.trainable_variables()) 145 self.assertFalse(new_v in new_store.non_trainable_variables()) 146 self.assertTrue(new_w in new_store.non_trainable_variables()) 147 148 # Check that variables are separate instances. 149 for v in store.variables(): 150 v.assign(-1) 151 for v in new_store.variables(): 152 v.assign(1) 153 for v in store.variables(): 154 self.assertEqual(v.numpy(), -1) 155 for v in new_store.variables(): 156 self.assertEqual(v.numpy(), 1) 157 158 @test_util.run_in_graph_and_eager_modes() 159 def testInitFromNonTensorValue(self): 160 v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32) 161 self.evaluate(variables_lib.variables_initializer([v])) 162 self.assertAllClose(self.evaluate(v.value()), 4) 163 164 w = variable_scope.get_variable( 165 "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64) 166 self.evaluate(variables_lib.variables_initializer([w])) 167 self.assertAllClose(self.evaluate(w.value()), [1, 2, 3]) 168 169 if context.in_graph_mode(): 170 with self.assertRaises(TypeError): 171 variable_scope.get_variable("x4", initializer={}) 172 else: 173 with self.assertRaises(ValueError): 174 variable_scope.get_variable("x4", initializer={}) 175 176 @test_util.run_in_graph_and_eager_modes() 177 def testInitFromNonInitializer(self): 178 # Test various dtypes with zeros initializer as following: 179 types = [ 180 dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32, 181 dtypes.int64, dtypes.bool 182 ] 183 184 # Use different variable_name to distinguish various dtypes 185 
for (i, dtype) in enumerate(types): 186 x = variable_scope.get_variable( 187 name="xx%d" % i, shape=(3, 4), dtype=dtype) 188 y = variable_scope.get_variable( 189 name="yy%d" % i, 190 shape=(3, 4), 191 dtype=dtype, 192 initializer=init_ops.zeros_initializer(dtype=dtype)) 193 194 self.evaluate(variables_lib.global_variables_initializer()) 195 self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value())) 196 197 # TODO(alive): support variable partitioning/caching in eager mode. 198 def testVarScopeCachingDevice(self): 199 with self.test_session(): 200 caching_device = "/job:moo" 201 with variable_scope.variable_scope("tower"): 202 with variable_scope.variable_scope( 203 "caching", caching_device=caching_device): 204 v = variable_scope.get_variable("v", []) 205 self.assertTrue(v.value().device.startswith(caching_device)) 206 207 with variable_scope.variable_scope("child"): 208 v2 = variable_scope.get_variable("v", []) 209 self.assertTrue(v2.value().device.startswith(caching_device)) 210 211 with variable_scope.variable_scope("not_cached", caching_device=""): 212 v2_not_cached = variable_scope.get_variable("v", []) 213 self.assertFalse(v2_not_cached.value().device.startswith( 214 caching_device)) 215 216 with variable_scope.variable_scope( 217 "not_cached_identity_device", 218 caching_device=lambda op: op.device): 219 v2_identity_device = variable_scope.get_variable("v", []) 220 self.assertFalse(v2_identity_device.value().device.startswith( 221 caching_device)) 222 223 with variable_scope.variable_scope("we_will_do_it_live") as vs_live: 224 vs_live.set_caching_device("/job:live") 225 v_live = variable_scope.get_variable("v", []) 226 self.assertTrue(v_live.value().device.startswith("/job:live")) 227 228 v_tower = variable_scope.get_variable("v", []) 229 self.assertFalse(v_tower.value().device.startswith(caching_device)) 230 231 @test_util.run_in_graph_and_eager_modes() 232 def testVarScopeRegularizer(self): 233 init = init_ops.constant_initializer(0.3) 234 235 
def regularizer1(v): 236 return math_ops.reduce_mean(v) + 0.1 237 238 def regularizer2(v): 239 return math_ops.reduce_mean(v) + 0.2 240 241 with variable_scope.variable_scope( 242 "tower3", regularizer=regularizer1) as tower: 243 with variable_scope.variable_scope("foo", initializer=init): 244 v = variable_scope.get_variable("v", []) 245 self.evaluate(variables_lib.variables_initializer([v])) 246 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 247 self.assertEqual(1, len(losses)) 248 self.assertAllClose(self.evaluate(losses[0]), 0.4) 249 with variable_scope.variable_scope(tower, initializer=init) as vs: 250 u = variable_scope.get_variable("u", []) 251 vs.set_regularizer(regularizer2) 252 w = variable_scope.get_variable("w", []) 253 # Next 3 variable not regularized to test disabling regularization. 254 x = variable_scope.get_variable( 255 "x", [], regularizer=variable_scope.no_regularizer) 256 with variable_scope.variable_scope( 257 "baz", regularizer=variable_scope.no_regularizer): 258 y = variable_scope.get_variable("y", []) 259 vs.set_regularizer(variable_scope.no_regularizer) 260 z = variable_scope.get_variable("z", []) 261 # Check results. 262 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 263 self.assertEqual(3, len(losses)) 264 self.evaluate(variables_lib.variables_initializer([u, w, x, y, z])) 265 self.assertAllClose(self.evaluate(losses[0]), 0.4) 266 self.assertAllClose(self.evaluate(losses[1]), 0.4) 267 self.assertAllClose(self.evaluate(losses[2]), 0.5) 268 with variable_scope.variable_scope("foo", reuse=True): 269 # reuse=True is for now only supported when eager execution is disabled. 270 if context.in_graph_mode(): 271 v = variable_scope.get_variable("v", 272 []) # "v" is alredy there, reused 273 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 274 self.assertEqual(3, len(losses)) # No new loss added. 
275 276 @test_util.run_in_graph_and_eager_modes() 277 def testInitializeFromValue(self): 278 init = constant_op.constant(0.1) 279 w = variable_scope.get_variable("v", initializer=init) 280 self.evaluate(variables_lib.variables_initializer([w])) 281 self.assertAllClose(self.evaluate(w.value()), 0.1) 282 283 with self.assertRaisesRegexp(ValueError, "shape"): 284 # We disallow explicit shape specification when initializer is constant. 285 variable_scope.get_variable("u", [1], initializer=init) 286 287 with variable_scope.variable_scope("foo", initializer=init): 288 # Constant initializer can be passed through scopes if needed. 289 v = variable_scope.get_variable("v") 290 self.evaluate(variables_lib.variables_initializer([v])) 291 self.assertAllClose(self.evaluate(v.value()), 0.1) 292 293 # Check that non-float32 initializer creates a non-float32 variable. 294 init = constant_op.constant(1, dtype=dtypes.int32) 295 t = variable_scope.get_variable("t", initializer=init) 296 self.assertEqual(t.dtype.base_dtype, dtypes.int32) 297 298 # Raise error if `initializer` dtype and `dtype` are not identical. 299 with self.assertRaisesRegexp(ValueError, "don't match"): 300 variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64) 301 302 def testControlDeps(self): 303 with self.test_session() as sess: 304 v0 = variable_scope.get_variable( 305 "v0", [1], initializer=init_ops.constant_initializer(0)) 306 with ops.control_dependencies([v0.value()]): 307 v1 = variable_scope.get_variable( 308 "v1", [1], initializer=init_ops.constant_initializer(1)) 309 add = v1 + v0 310 # v0 should be uninitialized. 311 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 312 sess.run(v0) 313 # We should be able to initialize and run v1 without initializing 314 # v0, even if the variable was created with a control dep on v0. 315 sess.run(v1.initializer) 316 self.assertEqual(1, sess.run(v1)) 317 # v0 should still be uninitialized. 
318 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 319 sess.run(v0) 320 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 321 sess.run(add) 322 # If we initialize v0 we should be able to run 'add'. 323 sess.run(v0.initializer) 324 sess.run(add) 325 326 def testControlFlow(self): 327 with self.test_session() as sess: 328 v0 = variable_scope.get_variable( 329 "v0", [], initializer=init_ops.constant_initializer(0)) 330 var_dict = {} 331 332 # Call get_variable in each of the cond clauses. 333 def var_in_then_clause(): 334 v1 = variable_scope.get_variable( 335 "v1", [1], initializer=init_ops.constant_initializer(1)) 336 var_dict["v1"] = v1 337 return v1 + v0 338 339 def var_in_else_clause(): 340 v2 = variable_scope.get_variable( 341 "v2", [1], initializer=init_ops.constant_initializer(2)) 342 var_dict["v2"] = v2 343 return v2 + v0 344 345 add = control_flow_ops.cond( 346 math_ops.less(v0, 10), var_in_then_clause, var_in_else_clause) 347 v1 = var_dict["v1"] 348 v2 = var_dict["v2"] 349 # We should be able to initialize and run v1 and v2 without initializing 350 # v0, even if the variable was created with a control dep on v0. 351 sess.run(v1.initializer) 352 self.assertEqual([1], sess.run(v1)) 353 sess.run(v2.initializer) 354 self.assertEqual([2], sess.run(v2)) 355 # v0 should still be uninitialized. 356 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 357 sess.run(v0) 358 # We should not be able to run 'add' yet. 359 with self.assertRaisesRegexp(errors.OpError, "uninitialized"): 360 sess.run(add) 361 # If we initialize v0 we should be able to run 'add'. 362 sess.run(v0.initializer) 363 sess.run(add) 364 365 @test_util.run_in_graph_and_eager_modes() 366 def testGetVariableScope(self): 367 # Test the get_variable_scope() function and setting properties of result. 
368 init = init_ops.constant_initializer(0.3) 369 with variable_scope.variable_scope("bar"): 370 new_init1 = variable_scope.get_variable_scope().initializer 371 self.assertEqual(new_init1, None) 372 # Check that we can set initializer like this. 373 variable_scope.get_variable_scope().set_initializer(init) 374 v = variable_scope.get_variable("v", []) 375 self.evaluate(variables_lib.variables_initializer([v])) 376 self.assertAllClose(self.evaluate(v.value()), 0.3) 377 if context.in_graph_mode(): 378 # Check that we can set reuse. 379 variable_scope.get_variable_scope().reuse_variables() 380 with self.assertRaises(ValueError): # Fail, w does not exist yet. 381 variable_scope.get_variable("w", [1]) 382 # Check that the set initializer goes away. 383 new_init = variable_scope.get_variable_scope().initializer 384 self.assertEqual(new_init, None) 385 386 @test_util.run_in_graph_and_eager_modes() 387 def testVarScope(self): 388 with variable_scope.variable_scope("tower4") as tower: 389 self.assertEqual(tower.name, "tower4") 390 with ops.name_scope("scope") as sc: 391 self.assertEqual(sc, "tower4/scope/") 392 393 with variable_scope.variable_scope("tower5"): 394 with variable_scope.variable_scope("bar") as bar: 395 self.assertEqual(bar.name, "tower5/bar") 396 with ops.name_scope("scope") as sc: 397 self.assertEqual(sc, "tower5/bar/scope/") 398 399 with variable_scope.variable_scope("tower6"): 400 with variable_scope.variable_scope(tower, reuse=True) as tower_shared: 401 self.assertEqual(tower_shared.name, "tower4") 402 with ops.name_scope("scope") as sc: 403 self.assertEqual(sc, "tower6/tower4/scope/") 404 405 @test_util.run_in_graph_and_eager_modes() 406 def testVarScopeNameScope(self): 407 with ops.name_scope("testVarScopeNameScope1"): 408 with variable_scope.variable_scope("tower") as tower: 409 with ops.name_scope("scope2") as sc2: 410 self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/") 411 if context.in_graph_mode(): 412 with variable_scope.variable_scope( 
413 tower): # Re-entering acts like another "tower". 414 with ops.name_scope("scope2") as sc2: 415 self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/") 416 with variable_scope.variable_scope( 417 "tower"): # Re-entering by string acts the same. 418 with ops.name_scope("scope2") as sc2: 419 self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/") 420 421 with ops.name_scope("testVarScopeNameScope2"): 422 with variable_scope.variable_scope("tower"): 423 with ops.name_scope("scope2") as sc2: 424 self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/") 425 if context.in_graph_mode(): 426 with variable_scope.variable_scope(tower): 427 with ops.name_scope("scope2") as sc2: 428 self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/") 429 430 root_var_scope = variable_scope.get_variable_scope() 431 with ops.name_scope("testVarScopeNameScope3"): 432 with variable_scope.variable_scope(root_var_scope): 433 with ops.name_scope("scope2") as sc2: 434 self.assertEqual(sc2, "testVarScopeNameScope3/scope2/") 435 436 def testVarScopeOriginalNameScope(self): 437 with self.test_session(): 438 with ops.name_scope("scope1"): 439 with variable_scope.variable_scope("tower") as tower: 440 self.assertEqual(tower.original_name_scope, "scope1/tower/") 441 with ops.name_scope("scope2") as sc2: 442 self.assertEqual(sc2, "scope1/tower/scope2/") 443 with ops.name_scope("scope2"): 444 with variable_scope.variable_scope(tower) as tower1: 445 # Re-entering preserves original name scope. 446 self.assertEqual(tower1.original_name_scope, "scope1/tower/") 447 with ops.name_scope("foo") as sc2: 448 self.assertEqual(sc2, "scope2/tower/foo/") 449 # Test re-entering original name scope. 
450 with ops.name_scope(tower.original_name_scope): 451 with ops.name_scope("bar") as sc3: 452 self.assertEqual(sc3, "scope1/tower/bar/") 453 with ops.name_scope("scope2"): 454 with variable_scope.variable_scope(tower): 455 with ops.name_scope(tower.original_name_scope): 456 with ops.name_scope("bar") as sc3: 457 self.assertEqual(sc3, "scope1/tower/bar_1/") 458 459 def testVarScopeObjectReuse(self): 460 with self.test_session(): 461 vs = None 462 with variable_scope.variable_scope("jump", reuse=True) as scope: 463 vs = scope 464 465 with variable_scope.variable_scope(vs) as jump: 466 self.assertTrue(jump.reuse) 467 468 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 469 self.assertTrue(jump_reuse.reuse) 470 471 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 472 self.assertTrue(jump_no_reuse.reuse) # Inherited, cannot be undone. 473 474 with variable_scope.variable_scope("jump", reuse=False) as scope: 475 vs = scope 476 477 with variable_scope.variable_scope(vs) as jump: 478 self.assertFalse(jump.reuse) 479 480 with variable_scope.variable_scope(vs, reuse=True) as jump_reuse: 481 self.assertTrue(jump_reuse.reuse) 482 483 with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse: 484 self.assertFalse(jump_no_reuse.reuse) 485 486 def testVarScopeGetOrCreateReuse(self): 487 with self.test_session(): 488 def test_value(value): 489 x = constant_op.constant(value) 490 with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar", 491 reuse=variable_scope.AUTO_REUSE): 492 _ = state_ops.assign(variable_scope.get_variable("var", []), x) 493 with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar", 494 reuse=variable_scope.AUTO_REUSE): 495 _ = variable_scope.get_variable("var", []) 496 self.assertEqual(value, x.eval()) 497 test_value(42.) # Variable is created. 498 test_value(13.) # Variable is reused hereafter. 499 test_value(17.) 
500 501 def testVarOpScope(self): 502 with self.test_session(): 503 with ops.name_scope("testVarOpScope1"): 504 with variable_scope.variable_scope("tower", "default", []): 505 self.assertEqual( 506 variable_scope.get_variable("w", []).name, "tower/w:0") 507 with ops.name_scope("testVarOpScope2") as sc2: 508 self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/") 509 with variable_scope.variable_scope("tower", "default", []): 510 with self.assertRaises(ValueError): 511 variable_scope.get_variable("w", []) 512 with ops.name_scope("testVarOpScope2") as sc2: 513 self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/") 514 515 with ops.name_scope("testVarOpScope2"): 516 with variable_scope.variable_scope(None, "default", []): 517 self.assertEqual( 518 variable_scope.get_variable("w", []).name, "default/w:0") 519 with ops.name_scope("testVarOpScope2") as sc2: 520 self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/") 521 with variable_scope.variable_scope(None, "default", []): 522 self.assertEqual( 523 variable_scope.get_variable("w", []).name, "default_1/w:0") 524 with ops.name_scope("testVarOpScope2") as sc2: 525 self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/") 526 527 def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self): 528 with self.test_session(): 529 with variable_scope.variable_scope(None, "defaultScope1"): 530 with variable_scope.variable_scope(None, "layer"): 531 self.assertEqual( 532 variable_scope.get_variable("w", []).name, 533 "defaultScope1/layer/w:0") 534 with variable_scope.variable_scope(None, "defaultScope1"): 535 with variable_scope.variable_scope(None, "layer"): 536 self.assertEqual( 537 variable_scope.get_variable("w", []).name, 538 "defaultScope1_1/layer/w:0") 539 with variable_scope.variable_scope(None, "defaultScope"): 540 with variable_scope.variable_scope(None, "layer"): 541 self.assertEqual( 542 variable_scope.get_variable("w", []).name, 543 "defaultScope/layer/w:0") 544 with 
variable_scope.variable_scope(None, "defaultScope1"): 545 with variable_scope.variable_scope(None, "layer"): 546 self.assertEqual( 547 variable_scope.get_variable("w", []).name, 548 "defaultScope1_2/layer/w:0") 549 550 def testVarOpScopeUniqueNamesWithJump(self): 551 with self.test_session(): 552 with variable_scope.variable_scope("default") as default: 553 with variable_scope.variable_scope(None, "layer"): 554 self.assertEqual( 555 variable_scope.get_variable("w", []).name, 556 "default/layer/w:0") 557 with variable_scope.variable_scope(None, "layer"): 558 self.assertEqual( 559 variable_scope.get_variable("w", []).name, 560 "default/layer_1/w:0") 561 with variable_scope.variable_scope(default): 562 pass 563 # No matter the jump in the middle, unique numbering continues. 564 with variable_scope.variable_scope(None, "layer"): 565 self.assertEqual( 566 variable_scope.get_variable("w", []).name, 567 "default/layer_2/w:0") 568 569 def testVarOpScopeReuse(self): 570 with self.test_session(): 571 with variable_scope.variable_scope("outer") as outer: 572 with variable_scope.variable_scope("tower", "default", []): 573 self.assertEqual( 574 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 575 with ops.name_scope("scope2") as sc2: 576 self.assertEqual(sc2, "outer/tower/scope2/") 577 with variable_scope.variable_scope(None, "default", []): 578 self.assertEqual( 579 variable_scope.get_variable("w", []).name, "outer/default/w:0") 580 with ops.name_scope("scope2") as sc2: 581 self.assertEqual(sc2, "outer/default/scope2/") 582 583 with variable_scope.variable_scope(outer, reuse=True) as outer: 584 with variable_scope.variable_scope("tower", "default", []): 585 self.assertEqual( 586 variable_scope.get_variable("w", []).name, "outer/tower/w:0") 587 with ops.name_scope("scope2") as sc2: 588 self.assertEqual(sc2, "outer_1/tower/scope2/") 589 with variable_scope.variable_scope(None, "default", []): 590 self.assertEqual( 591 variable_scope.get_variable("w", []).name, 
"outer/default/w:0") 592 with ops.name_scope("scope2") as sc2: 593 self.assertEqual(sc2, "outer_1/default/scope2/") 594 595 def testVarScopeGetVar(self): 596 with self.test_session(): 597 with variable_scope.variable_scope("root"): 598 with variable_scope.variable_scope("towerA") as tower_a: 599 va = variable_scope.get_variable("v", [1]) 600 self.assertEqual(va.name, "root/towerA/v:0") 601 602 with variable_scope.variable_scope(tower_a, reuse=True): 603 va2 = variable_scope.get_variable("v", [1]) 604 self.assertEqual(va2, va) 605 606 with variable_scope.variable_scope("towerB"): 607 vb = variable_scope.get_variable("v", [1]) 608 self.assertEqual(vb.name, "root/towerB/v:0") 609 610 with self.assertRaises(ValueError): 611 with variable_scope.variable_scope("towerA"): 612 va2 = variable_scope.get_variable("v", [1]) 613 614 with variable_scope.variable_scope("towerA", reuse=True): 615 va2 = variable_scope.get_variable("v", [1]) 616 self.assertEqual(va2, va) 617 618 with variable_scope.variable_scope("foo"): 619 with variable_scope.variable_scope("bar"): 620 v = variable_scope.get_variable("v", [1]) 621 self.assertEqual(v.name, "root/foo/bar/v:0") 622 with variable_scope.variable_scope(tower_a, reuse=True): 623 va3 = variable_scope.get_variable("v", [1]) 624 self.assertEqual(va, va3) 625 626 with self.assertRaises(ValueError): 627 with variable_scope.variable_scope(tower_a, reuse=True): 628 with variable_scope.variable_scope("baz"): 629 variable_scope.get_variable("v", [1]) 630 631 with self.assertRaises(ValueError) as exc: 632 with variable_scope.variable_scope(tower_a, reuse=True): 633 variable_scope.get_variable("v", [2]) # Different shape. 
634 self.assertEqual("shape" in str(exc.exception), True) 635 636 with self.assertRaises(ValueError) as exc: 637 with variable_scope.variable_scope(tower_a, reuse=True): 638 variable_scope.get_variable("v", [1], dtype=dtypes.int32) 639 self.assertEqual("dtype" in str(exc.exception), True) 640 641 def testVarScopeOuterScope(self): 642 with self.test_session(): 643 with variable_scope.variable_scope("outer") as outer: 644 pass 645 with variable_scope.variable_scope(outer): 646 self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0") 647 with ops.name_scope("scope2") as sc2: 648 self.assertEqual(sc2, "outer_1/scope2/") 649 with variable_scope.variable_scope("default"): 650 self.assertEqual( 651 variable_scope.get_variable("w", []).name, "outer/default/w:0") 652 with ops.name_scope("scope2") as sc2: 653 self.assertEqual(sc2, "outer_1/default/scope2/") 654 655 with variable_scope.variable_scope(outer, reuse=True): 656 self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0") 657 with ops.name_scope("scope2") as sc2: 658 self.assertEqual(sc2, "outer_2/scope2/") 659 with variable_scope.variable_scope("default", reuse=True): 660 self.assertEqual( 661 variable_scope.get_variable("w", []).name, "outer/default/w:0") 662 with ops.name_scope("scope2") as sc2: 663 self.assertEqual(sc2, "outer_2/default/scope2/") 664 665 def testVarScopeNestedOuterScope(self): 666 with self.test_session(): 667 with variable_scope.variable_scope("outer") as outer: 668 with variable_scope.variable_scope(outer): 669 self.assertEqual( 670 variable_scope.get_variable("w", []).name, "outer/w:0") 671 with ops.name_scope("scope2") as sc2: 672 self.assertEqual(sc2, "outer/outer/scope2/") 673 with variable_scope.variable_scope("default"): 674 self.assertEqual( 675 variable_scope.get_variable("w", []).name, "outer/default/w:0") 676 with ops.name_scope("scope2") as sc2: 677 self.assertEqual(sc2, "outer/default/scope2/") 678 679 with variable_scope.variable_scope(outer, 
reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer_1/scope2/")
        with variable_scope.variable_scope("default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default_1/scope2/")

  def testVarOpScopeReuseParam(self):
    """Re-entering a captured scope reuses variables under a new name scope.

    Variable names stay `outer/...` on the second pass, while op name scopes
    are uniquified to `outer_1/...`. Reuse is requested once via the `reuse`
    argument and once via `reuse_variables()` on the captured scope.
    """
    with self.test_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope("tower", "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/tower/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      with variable_scope.variable_scope(outer) as outer:
        with variable_scope.variable_scope("tower", "default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/tower/scope2/")
        # Switch the re-entered scope to reuse mode so the default-named child
        # below returns the existing `outer/default/w` instead of creating one.
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

  def testVarOpScopeReuseError(self):
    """reuse=True combined with name_or_scope=None is expected to fail."""
    with self.test_session():
      with self.assertRaises(ValueError):
        # If the scope raises on entry, the assertion below never runs; it is
        # only reached if the ValueError comes from get_variable instead.
        with variable_scope.variable_scope(None, "default", reuse=True):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/tower/w:0")

  def testVarOpScopeOuterScope(self):
    """Re-entering a closed scope object keeps variable names stable.

    Each re-entry from the root opens a fresh uniquified op name scope
    (`outer_1`, then `outer_2`), but variables remain under `outer/`.
    """
    with self.test_session():
      with variable_scope.variable_scope("outer") as outer:
        pass
      with variable_scope.variable_scope(outer, "default", []):
        self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_2/scope2/")
        outer.reuse_variables()
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_2/default/scope2/")

  def testVarOpScopeNestedOuterScope(self):
    """Re-entering a scope object from inside itself nests only the name scope.

    Inside `outer`, re-entering it gives op names `outer/outer/...` while the
    variable is still `outer/w`. Re-entering later from the root gives
    `outer_1/...`.
    """
    with self.test_session():
      with variable_scope.variable_scope("outer") as outer:
        with variable_scope.variable_scope(outer, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/outer/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer/default/scope2/")

      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(variable_scope.get_variable("w", []).name, "outer/w:0")
        with ops.name_scope("scope2") as sc2:
          self.assertEqual(sc2, "outer_1/scope2/")
        with variable_scope.variable_scope(None, "default", []):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          with ops.name_scope("scope2") as sc2:
            self.assertEqual(sc2, "outer_1/default/scope2/")

  def testBasicWhenAuxiliaryNameScopeIsFalse(self):
    """auxiliary_name_scope=False prefixes variables but not ops.

    The scope's `original_name_scope` is the enclosing name scope ("" at the
    root, "outer/" when nested). Constants keep that enclosing prefix.
    """
    with self.test_session():
      with variable_scope.variable_scope(
          "scope", auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(variable_scope.get_variable("w", []).name, "scope/w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
      with variable_scope.variable_scope(scope, auxiliary_name_scope=False):
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w1", []).name, "scope/w1:0")
        self.assertEqual(constant_op.constant([], name="c1").name, "c1:0")
      # Recheck: new name scope is NOT created before
      # (opening "scope" now yields "scope/" rather than a uniquified
      # "scope_1/", so no "scope" name scope was consumed above).
      with ops.name_scope("scope"):
        self.assertEqual(constant_op.constant([], name="c").name, "scope/c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            "inner", auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
          self.assertEqual(constant_op.constant([], name="c").name, "outer/c:0")
        with variable_scope.variable_scope(
            inner, auxiliary_name_scope=False) as inner1:
          self.assertEqual(inner1.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
          self.assertEqual(
              constant_op.constant([], name="c1").name, "outer/c1:0")
        # Recheck: new name scope is NOT created before
        with ops.name_scope("inner"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/inner/c:0")

  def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
    """Same checks as above, for a scope named through `default_name`."""
    with self.test_session():
      with variable_scope.variable_scope(
          None, default_name="default", auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(
            variable_scope.get_variable("w", []).name, "default/w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")
      # Recheck: new name scope is NOT created before
      with ops.name_scope("default"):
        self.assertEqual(constant_op.constant([], name="c").name, "default/c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            None, default_name="default", auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "outer/")
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/default/w:0")
          self.assertEqual(constant_op.constant([], name="c").name, "outer/c:0")
        # Recheck: new name scope is NOT created before
        with ops.name_scope("default"):
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/default/c:0")

  def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
    """Re-entering the root variable scope drops the variable prefix only.

    Inside "outer", variables created through the re-entered root scope are
    unprefixed, while constants still pick up the active `outer/` name scope.
    """
    with self.test_session():
      root_scope = variable_scope.get_variable_scope()
      with variable_scope.variable_scope(
          root_scope, auxiliary_name_scope=False) as scope:
        self.assertEqual(scope.original_name_scope, "")
        self.assertEqual(variable_scope.get_variable("w", []).name, "w:0")
        self.assertEqual(constant_op.constant([], name="c").name, "c:0")

      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope(
            root_scope, auxiliary_name_scope=False) as inner:
          self.assertEqual(inner.original_name_scope, "")
          self.assertEqual(variable_scope.get_variable("w1", []).name, "w1:0")
          self.assertEqual(
              constant_op.constant([], name="c1").name, "outer/c1:0")

  def testAuxiliaryNameScopeIsInvalid(self):
    """A string auxiliary_name_scope raises TypeError mentioning the argument.

    Checked for all three ways of opening a scope: default name, string name
    and an existing scope object.
    """
    with self.test_session():
      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
        with variable_scope.variable_scope(
            None, default_name="scope", auxiliary_name_scope="invalid"):
          pass

      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
        with variable_scope.variable_scope(
            "scope", auxiliary_name_scope="invalid"):
          pass

      with variable_scope.variable_scope("scope") as scope:
        pass
      with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
        with variable_scope.variable_scope(
            scope, auxiliary_name_scope="invalid"):
          pass

  def testReuseScopeWithoutNameScopeCollision(self):
    """Re-enter a scope's original op name scope without creating a new one.

    Combining auxiliary_name_scope=False with an explicit
    `ops.name_scope(scope.original_name_scope)` places ops back under
    `outer/inner/`. A separately opened "inner" name scope is unaffected.
    """
    # Github issue: #13429
    with self.test_session():
      with variable_scope.variable_scope("outer"):
        with variable_scope.variable_scope("inner") as inner:
          pass

      with variable_scope.variable_scope(
          inner, auxiliary_name_scope=False) as scope:
        with ops.name_scope(scope.original_name_scope):
          self.assertEqual(
              variable_scope.get_variable("w", []).name, "outer/inner/w:0")
          self.assertEqual(
              constant_op.constant([], name="c").name, "outer/inner/c:0")
        with ops.name_scope("inner"):
          self.assertEqual(constant_op.constant([], name="c").name, "inner/c:0")

      with variable_scope.variable_scope("another"):
        with variable_scope.variable_scope(
            inner, auxiliary_name_scope=False) as scope1:
          with ops.name_scope(scope1.original_name_scope):
            self.assertEqual(
                variable_scope.get_variable("w1", []).name, "outer/inner/w1:0")
            self.assertEqual(
                constant_op.constant([], name="c1").name, "outer/inner/c1:0")
          with ops.name_scope("inner"):
            self.assertEqual(
                constant_op.constant([], name="c").name, "another/inner/c:0")

  @test_util.run_in_graph_and_eager_modes()
  def testGetLocalVar(self):
    """get_local_variable honours scope naming, collections and reuse.

    The collection and reuse checks only run in graph mode.
    """
    # Check that local variable respects naming.
    with variable_scope.variable_scope("outer") as outer:
      with variable_scope.variable_scope(outer, "default", []):
        local_var = variable_scope.get_local_variable(
            "w", [], collections=["foo"])
        self.assertEqual(local_var.name, "outer/w:0")

    # Since variable is local, it should be in the local variable collection
    # but not the trainable collection.
    if context.in_graph_mode():
      self.assertIn(local_var,
                    ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      self.assertIn(local_var, ops.get_collection("foo"))
      self.assertNotIn(local_var,
                       ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # Check that local variable respects `reuse`.
    if context.in_graph_mode():
      with variable_scope.variable_scope(outer, "default", reuse=True):
        self.assertEqual(
            variable_scope.get_local_variable("w", []).name, "outer/w:0")

  def testGetVarWithDevice(self):
    """A device function observes each variable op's name and dtype attr."""
    g = ops.Graph()
    varname_type = []

    def device_func(op):
      # Record (name, dtype) for variable-creating ops only; every op is
      # still assigned the same device string.
      if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
        varname_type.append((op.name, op.get_attr("dtype")))
      return "/device:GPU:0"

    with g.as_default():
      with ops.device(device_func):
        _ = variable_scope.get_variable("x", (100, 200))
        _ = variable_scope.get_variable(
            "y", dtype=dtypes.int64, initializer=numpy.arange(73))
    self.assertEqual(varname_type[0], ("x", dtypes.float32))
    self.assertEqual(varname_type[1], ("y", dtypes.int64))

  def testGetCollection(self):
    """VariableScope.get_collection returns only variables under that scope.

    The scope names `testGetCollection_foo_` and `testGetCollection_foo`
    share a string prefix. The assertions check that neither scope picks up
    the other's variables. The root scope sees everything, in creation order.
    """
    with self.test_session():
      _ = variable_scope.get_variable("testGetCollection_a", [])
      _ = variable_scope.get_variable("testGetCollection_b", [],
                                      trainable=False)
      with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable("testGetCollection_b", [],
                                        trainable=False)
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo_/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo_/testGetCollection_a:0",
            "testGetCollection_foo_/testGetCollection_b:0"
        ])
      # "testGetCollection_foo" is a string prefix of the scope above; its
      # collections must still exclude the "testGetCollection_foo_/" variables.
      with variable_scope.variable_scope("testGetCollection_foo") as scope2:
        _ = variable_scope.get_variable("testGetCollection_a", [])
        _ = variable_scope.get_variable("testGetCollection_b", [],
                                        trainable=False)
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        ], ["testGetCollection_foo/testGetCollection_a:0"])
        self.assertEqual([
            v.name
            for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        ], [
            "testGetCollection_foo/testGetCollection_a:0",
            "testGetCollection_foo/testGetCollection_b:0"
        ])
      scope = variable_scope.get_variable_scope()
      self.assertEqual([
          v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      ], [
          "testGetCollection_a:0", "testGetCollection_b:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_b:0",
          "testGetCollection_foo/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_b:0"
      ])
      self.assertEqual([
          v.name
          for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      ], [
          "testGetCollection_a:0",
          "testGetCollection_foo_/testGetCollection_a:0",
          "testGetCollection_foo/testGetCollection_a:0"
      ])

  def testGetTrainableVariables(self):
    """trainable_variables() lists only trainable variables in the scope."""
    with self.test_session():
      _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
      with variable_scope.variable_scope(
          "testGetTrainableVariables_foo") as scope:
        _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
        _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
                                        trainable=False)
        self.assertEqual([v.name
                          for v in scope.trainable_variables()],
                         ["testGetTrainableVariables_foo/"
                          "testGetTrainableVariables_b:0"])

  def testGetGlobalVariables(self):
    """global_variables() excludes variables created outside the scope."""
    with self.test_session():
      _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
      with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
        _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
        self.assertEqual([v.name
                          for v in scope.global_variables()],
                         ["testGetGlobalVariables_foo/"
                          "testGetGlobalVariables_b:0"])

  def testGetLocalVariables(self):
    """local_variables() lists only LOCAL_VARIABLES members in the scope."""
    with self.test_session():
      _ = variable_scope.get_variable(
          "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
      with variable_scope.variable_scope("foo") as scope:
        _ = variable_scope.get_variable(
            "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
        _ = variable_scope.get_variable(
            "c", [])
        self.assertEqual([v.name
                          for v in scope.local_variables()], ["foo/b:0"])

  def testGetVariableWithRefDtype(self):
    """get_variable accepts the dtype taken from an existing variable."""
    v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
    # Ensure it is possible to do get_variable with a _ref dtype passed in.
    _ = variable_scope.get_variable("w", shape=[5, 6], dtype=v.dtype)

  def testTwoGraphs(self):
    """A scope named "_" is rejected with two graphs stacked as defaults."""

    def f():
      g1 = ops.Graph()
      g2 = ops.Graph()
      with g1.as_default():
        with g2.as_default():
          with variable_scope.variable_scope("_"):
            pass

    self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)


def axis0_into1_partitioner(shape=None, **unused_kwargs):
  """Partitioner returning one partition per axis (i.e. no split).

  Args:
    shape: the variable shape. Must be given; the `None` default would make
      `len(shape)` raise.
    **unused_kwargs: any other partitioner arguments, ignored.

  Returns:
    A list of 1s with one entry per dimension of `shape`.
  """
  part = [1] * len(shape)
  return part


def axis0_into2_partitioner(shape=None, **unused_kwargs):
  """Partitioner splitting axis 0 into 2 parts and leaving other axes whole.

  Args:
    shape: the variable shape. Must be given; the `None` default would make
      `len(shape)` raise.
    **unused_kwargs: any other partitioner arguments, ignored.

  Returns:
    A list with 2 at index 0 and 1 for every other dimension.
  """
  part = [1] * len(shape)
  part[0] = 2
  return part


def axis0_into3_partitioner(shape=None, **unused_kwargs):
  """Partitioner splitting axis 0 into 3 parts and leaving other axes whole.

  Args:
    shape: the variable shape. Must be given; the `None` default would make
      `len(shape)` raise.
    **unused_kwargs: any other partitioner arguments, ignored.

  Returns:
    A list with 3 at index 0 and 1 for every other dimension.
  """
  part = [1] * len(shape)
  part[0] = 3
  return part


class VariableScopeWithPartitioningTest(test.TestCase):
  """Tests for variable scopes configured with a `partitioner`."""

  def testResultNameMatchesRequested(self):
    """The partitioned variable keeps the requested name; shards are part_N."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      v = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(v.name, "scope0/name0")
      v_concat = v.as_tensor()
      self.assertEqual(v_concat.name, "scope0/name0:0")
      variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertIn("scope0/name0/part_0:0", [x.name for x in variables])
      self.assertIn("scope0/name0/part_1:0", [x.name for x in variables])
      self.assertNotIn("scope0/name0/part_2:0", [x.name for x in variables])

  def testBreaksIfPartitioningChanges(self):
    """Reusing with a partitioner giving a different shard count raises."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      variable_scope.get_variable("name0", shape=(3, 1, 1))

    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into3_partitioner, reuse=True):
      with self.assertRaisesRegexp(
          ValueError,
          "Trying to reuse partitioned variable .* but specified partitions .* "
          "and found partitions .*"):
        variable_scope.get_variable("name0", shape=(3, 1, 1))

    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into1_partitioner, reuse=True):
      with self.assertRaisesRegexp(
          ValueError,
          "Trying to reuse partitioned variable .* but specified partitions .* "
          "and found partitions .*"):
        variable_scope.get_variable("name0", shape=(3, 1, 1))

  def testReturnsExistingConcatenatedValueIfReuse(self):
    """With reuse enabled, get_variable returns an equal partitioned variable."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      v_concat = variable_scope.get_variable("name0", shape=(3, 1, 1))
      variable_scope.get_variable_scope().reuse_variables()
      v_concat_2 = variable_scope.get_variable("name0", shape=(3, 1, 1))
      self.assertEqual(v_concat, v_concat_2)

  def testAllowsReuseWithoutPartitioner(self):
    """A partitioned variable can be reused with no partitioner or shape."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      v = variable_scope.get_variable("name0", shape=(3, 1, 1))
    with variable_scope.variable_scope("scope0", reuse=True):
      v_reused = variable_scope.get_variable("name0")
    self.assertEqual(v, v_reused)

  def testPropagatePartitionerOnReopening(self):
    """Reopening a scope object keeps the scope's partitioner."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner) as vs:
      self.assertEqual(axis0_into2_partitioner, vs.partitioner)
      with variable_scope.variable_scope(vs) as vs1:
        self.assertEqual(axis0_into2_partitioner, vs1.partitioner)

  def testScalarIgnoresPartitioner(self):
    """A scalar (shape ()) variable is created unpartitioned."""
    with variable_scope.variable_scope(
        "scope0", partitioner=axis0_into2_partitioner):
      v = variable_scope.get_variable("name0", shape=())
      self.assertEqual(v.name, "scope0/name0:0")
      variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertIn("scope0/name0:0", [x.name for x in variables])

  def _testPartitionConcatenatesAlongCorrectAxis(self, use_resource):
    """Shared body: shards are split along the axis the partitioner selects.

    Args:
      use_resource: forwarded to `variable_scope` as `use_resource`.
    """

    # Fixed partitioners: 2 parts along axis 0, or 2 parts along axis 1.
    def _part_axis_0(**unused_kwargs):
      return (2, 1, 1)

    def _part_axis_1(**unused_kwargs):
      return (1, 2, 1)

    with variable_scope.variable_scope("root", use_resource=use_resource):
      v0 = variable_scope.get_variable(
          "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
      v1 = variable_scope.get_variable(
          "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

    # The partitioned variable reports the full, unsplit shape.
    self.assertEqual(v0.get_shape(), (2, 2, 2))
    self.assertEqual(v1.get_shape(), (2, 2, 2))

    # Iterating a partitioned variable yields its shards; each one halves the
    # partitioned axis.
    n0_0 = list(v0)[0]
    n0_1 = list(v0)[1]
    self.assertEqual(n0_0.get_shape(), (1, 2, 2))
    self.assertEqual(n0_1.get_shape(), (1, 2, 2))

    n1_0 = list(v1)[0]
    n1_1 = list(v1)[1]
    self.assertEqual(n1_0.get_shape(), (2, 1, 2))
    self.assertEqual(n1_1.get_shape(), (2, 1, 2))

  def testPartitionConcatenatesAlongCorrectAxis(self):
    """Partition-axis check for reference (non-resource) variables."""
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=False)

  def testPartitionConcatenatesAlongCorrectAxisResource(self):
    """Partition-axis check for resource variables."""
    self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)


class VariableScopeWithCustomGetterTest(test.TestCase):
  """Tests for `custom_getter` and `variable_creator_scope` hooks."""

  def testNonCallableGetterFails(self):
    """A non-callable custom_getter raises ValueError.

    This holds whether the getter is set on the scope or passed directly to
    get_variable.
    """
    with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
      with variable_scope.variable_scope("scope0", custom_getter=3):
        variable_scope.get_variable("name0")
    with self.assertRaisesRegexp(ValueError, r"custom_getter .* not callable:"):
      variable_scope.get_variable("name0", custom_getter=3)

  def testNoSideEffectsWithIdentityCustomGetter(self):
    """A pass-through getter leaves results unchanged.

    The getter is only invoked in scopes that set it.
    """
    # Mutable cell so the nested getter can increment a counter.
    called = [0]

    def custom_getter(getter, *args, **kwargs):
      called[0] += 1
      return getter(*args, **kwargs)

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope("new_scope") as new_scope:
      v3 = variable_scope.get_variable("v3", [1])
    with variable_scope.variable_scope(
        new_scope, reuse=True, custom_getter=custom_getter):
      v4 = variable_scope.get_variable("v3", [1])

    self.assertEqual(v, v2)
    self.assertEqual(v3, v4)
    self.assertEqual(3, called[0])  # skipped one in the first new_scope

  def testCustomGetterWithReuse(self):
    """The custom getter receives `reuse` in kwargs and can branch on it."""
    # Custom getter can choose to behave differently on reused variables.
    def custom_getter(getter, *args, **kwargs):
      var = getter(*args, **kwargs)
      if kwargs["reuse"]:
        # This can be used, e.g., for changing the caching device if needed.
        return array_ops.identity(var, name="reused")
      else:
        return array_ops.identity(var, name="not_reused")

    with variable_scope.variable_scope(
        "scope", custom_getter=custom_getter) as scope:
      v = variable_scope.get_variable("v", [1])
    with variable_scope.variable_scope(scope, reuse=True):
      v2 = variable_scope.get_variable("v", [1])

    self.assertEqual(v.name, "not_reused:0")
    self.assertEqual(v2.name, "reused:0")

  def testGetterThatCreatesTwoVariablesAndSumsThem(self):
    """A getter may build several variables and return a derived tensor."""

    def custom_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/0" % name, *args, **kwargs)
      g_1 = getter("%s/1" % name, *args, **kwargs)
      with ops.name_scope("custom_getter"):
        return g_0 + g_1

    with variable_scope.variable_scope("scope", custom_getter=custom_getter):
      v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(2, len(true_vars))
    self.assertEqual("scope/v/0:0", true_vars[0].name)
    self.assertEqual("scope/v/1:0", true_vars[1].name)
    self.assertEqual("custom_getter/add:0", v.name)
    with self.test_session() as sess:
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = sess.run([true_vars, v])
      self.assertAllClose(np_v, sum(np_vars))

  def testNestedCustomGetters(self):
    """Getters from nested scopes compose, innermost scope's getter first.

    The expected names append `sum_i` (inner_sum_scope), then `sum_j`
    (sum_scope), then `prod_k` (prod_scope). So each getter's `getter`
    argument chains out to the enclosing scope's getter.
    """

    def sum_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/sum_0" % name, *args, **kwargs)
      g_1 = getter("%s/sum_1" % name, *args, **kwargs)
      with ops.name_scope("sum_getter"):
        return g_0 + g_1

    def prod_getter(getter, name, *args, **kwargs):
      g_0 = getter("%s/prod_0" % name, *args, **kwargs)
      g_1 = getter("%s/prod_1" % name, *args, **kwargs)
      with ops.name_scope("prod_getter"):
        return g_0 * g_1

    with variable_scope.variable_scope(
        "prod_scope", custom_getter=prod_getter):
      with variable_scope.variable_scope(
          "sum_scope", custom_getter=sum_getter):
        with variable_scope.variable_scope(
            "inner_sum_scope", custom_getter=sum_getter):
          # take sums of sums of products
          v = variable_scope.get_variable("v", [1, 2, 3])

    self.assertEqual([1, 2, 3], v.get_shape())
    true_vars = variables_lib.trainable_variables()
    self.assertEqual(8, len(true_vars))
    template = (
        "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
    self.assertEqual(template % (0, 0, 0), true_vars[0].name)
    self.assertEqual(template % (0, 0, 1), true_vars[1].name)
    self.assertEqual(template % (0, 1, 0), true_vars[2].name)
    self.assertEqual(template % (0, 1, 1), true_vars[3].name)
    self.assertEqual(template % (1, 0, 0), true_vars[4].name)
    self.assertEqual(template % (1, 0, 1), true_vars[5].name)
    self.assertEqual(template % (1, 1, 0), true_vars[6].name)
    self.assertEqual(template % (1, 1, 1), true_vars[7].name)

    with self.test_session() as sess:
      variables_lib.global_variables_initializer().run()
      np_vars, np_v = sess.run([true_vars, v])
      # take sums of sums of products (outer sum over the inner sums of the
      # four pairwise products), matching the getter nesting above
      self.assertAllClose(
          np_v,
          (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3]))
           + ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))

  def testVariableCreator(self):
    """The most recently entered creator scope runs first.

    creator_b (inner) renames the variable before delegating to creator_a
    (outer), so creator_a records "forced_name" rather than "one_name".
    """

    variable_names = []

    def creator_a(next_creator, **kwargs):
      variable_names.append(kwargs.get("name", ""))
      return next_creator(**kwargs)

    def creator_b(next_creator, **kwargs):
      kwargs["name"] = "forced_name"
      return next_creator(**kwargs)

    with variable_scope.variable_creator_scope(creator_a):
      with variable_scope.variable_creator_scope(creator_b):
        variable_scope.variable(1.0, name="one_name")

    self.assertAllEqual(variable_names, ["forced_name"])


class PartitionInfoTest(test.TestCase):
  """Tests for the private `variable_scope._PartitionInfo` helper."""

  def testConstructorChecks(self):
    """Bad argument types raise TypeError.

    Inconsistent lengths or out-of-range offsets raise ValueError.
    """
    # Invalid arg types.
    with self.assertRaises(TypeError):
      variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1])
    with self.assertRaises(TypeError):
      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None)
    with self.assertRaises(TypeError):
      variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1])
    with self.assertRaises(TypeError):
      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo")

    # full_shape and var_offset must have same length.
    with self.assertRaises(ValueError):
      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0])
    # Offset must always be less than shape.
    with self.assertRaises(ValueError):
      variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1])

  def testSingleOffset(self):
    """single_offset returns the shard's offset along its partitioned axis."""
    partition_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(4, partition_info.single_offset([1, 3]))

    # Tests when the variable isn't partitioned at all.
    partition_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(0, partition_info.single_offset([9, 3]))

  def testSingleSliceDim(self):
    """single_slice_dim validates a shard shape and returns its split axis."""
    partition_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    # Invalid shape.
    with self.assertRaises(TypeError):
      partition_info.single_slice_dim(None)

    # Rank of shape differs from full_shape.
    with self.assertRaises(ValueError):
      partition_info.single_slice_dim([1, 2, 3])

    # Shape is too large given var_offset (4+6 > 9).
    with self.assertRaises(ValueError):
      partition_info.single_slice_dim([6, 3])

    # Multiple possible slice dim from shape.
    with self.assertRaises(ValueError):
      partition_info.single_slice_dim([1, 1])

    partition_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[0, 0])
    self.assertEqual(1, partition_info.single_slice_dim([9, 2]))
    partition_info = variable_scope._PartitionInfo(
        full_shape=[9, 3], var_offset=[4, 0])
    self.assertEqual(0, partition_info.single_slice_dim([2, 3]))


if __name__ == "__main__":
  test.main()