# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import traceback

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


def variable_scoped_function(trainable=True):
  """Gets (or creates) a zero-initialized shape-[1] variable named "dummy".

  Args:
    trainable: Passed through to `get_variable`.

  Returns:
    The "dummy" variable in the current variable scope.
  """
  return variable_scope.get_variable(
      "dummy", shape=[1], trainable=trainable,
      initializer=init_ops.zeros_initializer())


def internally_variable_scoped_function(scope_name):
  """Gets (or creates) "dummy" inside a nested variable scope `scope_name`."""
  with variable_scope.variable_scope(scope_name):
    return variable_scope.get_variable(
        "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable.

  A fresh `tf.Variable` is created on every call (it is not shared through
  `get_variable`); the returned "dummy" variable uses the default trainability.
  """
  variables.Variable(0, trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable.

  Calling with a different `name` creates a different side variable; the
  returned "dummy" variable uses the default trainability.
  """
  variable_scope.get_variable(name, shape=[1], trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def variable_scoped_function_with_local_variable():
  """Creates a local variable "local" as a side effect, then gets "dummy"."""
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


class TemplateTest(test.TestCase):
  """Tests for `template.make_template` and `template.make_template_internal`."""

  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      initial_test_loss = sess.run(test_loss)
      sess.run(train_op)
      final_test_loss = sess.run(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  @test_util.run_in_graph_and_eager_modes()
  def test_skip_stack_frames(self):
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    # Only the differing (last) frame survives; the shared prefix is dropped.
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes()
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegexp(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      # NOTE(review): "exeuction" is misspelled; presumably this mirrors the
      # message raised by template.make_template verbatim -- confirm before
      # changing either side.
      with self.assertRaisesRegexp(
          ValueError,
          "unique_name_ cannot be used when eager exeuction is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    # With reuse enabled, a second template with the same unique_name_ picks
    # up the existing variable instead of raising.
    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertEqual(v1, v2)
    self.assertEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    # A later call that would create a brand-new variable must be rejected.
    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes()
  def test_template_without_name(self):
    with self.assertRaisesRegexp(
        ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes()
  def test_make_template(self):
    # Test both that we can call it with positional and keywords.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    # The second call creates another trainable tf.Variable, which is
    # disallowed once the template's variables exist.
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes()
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    # A differently-named side variable is a new trainable variable.
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertEqual(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertEqual(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes()
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertEqual(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, [v1])
      self.assertEqual(nested2.variables, [v2])
      self.assertEqual(nested1.trainable_variables, [v1])
      self.assertEqual(nested2.trainable_variables, [v2])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertEqual([v1, v2], [v3, v4])
    self.assertEqual(tmpl1.variables, [v1, v2])
    self.assertEqual(tmpl1.trainable_variables, [v1, v2])
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual([v1, v2], [v5, v6])
    self.assertSequenceEqual(tmpl2.variables, [v5, v6])
    self.assertSequenceEqual(tmpl2.trainable_variables, [v5, v6])
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      # nested1 and nested2 should not share variables
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, v1)
      self.assertEqual(nested2.variables, v2)
      self.assertEqual(nested1.trainable_variables, v1)
      self.assertEqual(nested2.trainable_variables, v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertSequenceEqual(v1, v2)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      # functools.partial objects have no __name__; presumably this checks
      # that the graph-function path copes with an unnamed callable.
      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes()
  def test_immediate_scope_creation(self):
    # Create templates in scope a then call in scope b. By default
    # make_template captures the scope the first time it is called; passing
    # True as the third positional argument (create_scope_now_) captures the
    # scope at construction time instead.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertNotEqual(inner_imm_var, inner_defer_var)
    self.assertEqual(outer_imm_var, inner_imm_var)
    self.assertEqual(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertIsNone(tc.variable_scope)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes()
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes()
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      # Try to connect with a kwarg which is unsupported. The call must fail;
      # a try/except-pass here would let the test pass even if it did not.
      with self.assertRaises(TypeError):
        templatized_function(data, is_training=True)

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes()
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    # Op names are only meaningful when building a graph.
    if context.in_graph_mode():
      self.assertEqual("foo/add:0", outputs_a.name,
                       "First application of template should get "
                       "same name scope as variables.")
      self.assertEqual("foo_1/add:0", outputs_b.name,
                       "Second application of template should get "
                       "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual("foo_1", linear2.variable_scope.name,
                     "New template gets a freshly uniquified variable scope "
                     "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if context.in_graph_mode():
      self.assertEqual("foo_1_1/add:0", outputs_c.name,
                       "First application of template would get "
                       "same name scope as variables, but 'foo_1' is already "
                       "a name scope.")
      self.assertEqual("foo_1_2/add:0", outputs_d.name,
                       "Second application of template should also get "
                       "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes()
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # ta creates only "dummy"; tb also creates a non-trainable side
      # variable, via get_variable in eager mode and tf.Variable in graph mode.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.in_eager_mode():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # ta owns "dummy"; tb owns its side variable plus "dummy".
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Each template owns exactly one trainable variable.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  def test_local_variables(self):
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Only tb creates a local variable.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables, the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertSequenceEqual(v1, v2)
    self.assertEqual("s1/test/dummy:0", v1[0].name)


if __name__ == "__main__":
  test.main()