# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import traceback

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


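# Helper: creates (or reuses) a single variable named "dummy" in whatever
# variable scope is active at call time.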
def variable_scoped_function(trainable=True):
  return variable_scope.get_variable(
      "dummy", shape=[1], trainable=trainable,
      initializer=init_ops.zeros_initializer())


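# Helper: like variable_scoped_function, but opens an extra variable scope
# whose name is passed in, so the variable's full name depends on scope_name.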
def internally_variable_scoped_function(scope_name):
  with variable_scope.variable_scope(scope_name):
    return variable_scope.get_variable(
        "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable."""
  variables.Variable(0, trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable."""
  variable_scope.get_variable(name, shape=[1], trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


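# Helper: creates a local variable in addition to the usual "dummy" variable;
# used by the local_variables test below.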
def variable_scoped_function_with_local_variable():
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


class TemplateTest(test.TestCase):

  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      initial_test_loss = sess.run(test_loss)
      sess.run(train_op)
      final_test_loss = sess.run(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

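  # _skip_common_stack_elements is a private helper of the template module; it
  # should drop the stack frames shared by both tracebacks and keep only the
  # frames that differ.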
  @test_util.run_in_graph_and_eager_modes()
  def test_skip_stack_frames(self):
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes()
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

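  # unique_name_ pins the template to an exact variable scope name instead of
  # letting make_template uniquify it, so two templates built with the same
  # unique_name_ collide unless reuse is enabled; it is also rejected under
  # eager execution.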
  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegexp(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      with self.assertRaisesRegexp(
          ValueError,
          "unique_name_ cannot be used when eager exeuction is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertEqual(v1, v2)
    self.assertEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

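  # A template binds to the variable scope that is active when its variables
  # are first created; calling it again under a different scope must reuse the
  # original variables rather than create new ones.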
  @test_util.run_in_graph_and_eager_modes()
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes()
  def test_template_without_name(self):
    with self.assertRaisesRegexp(
        ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes()
  def test_make_template(self):
    # Test that we can call it with both positional and keyword arguments.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes()
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertEqual(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertEqual(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes()
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertEqual(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables.
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, [v1])
      self.assertEqual(nested2.variables, [v2])
      self.assertEqual(nested1.trainable_variables, [v1])
      self.assertEqual(nested2.trainable_variables, [v2])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertEqual([v1, v2], [v3, v4])
    self.assertEqual(tmpl1.variables, [v1, v2])
    self.assertEqual(tmpl1.trainable_variables, [v1, v2])
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual([v1, v2], [v5, v6])
    self.assertSequenceEqual(tmpl2.variables, [v5, v6])
    self.assertSequenceEqual(tmpl2.trainable_variables, [v5, v6])
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

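  # make_template_internal with create_graph_function_=True wraps the templated
  # function in a graph function (defun); variable reuse across calls should
  # behave the same way as for ordinary templates.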
  @test_util.run_in_graph_and_eager_modes()
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      # nested1 and nested2 should not share variables.
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, v1)
      self.assertEqual(nested2.variables, v2)
      self.assertEqual(nested1.trainable_variables, v1)
      self.assertEqual(nested2.trainable_variables, v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertSequenceEqual(v1, v2)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

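  # functools.partial objects do not have a __name__ attribute; this checks
  # that a graph-function template can still wrap a callable that cannot
  # provide a default name.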
  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes()
  def test_immediate_scope_creation(self):
    # Create templates in ctor_scope, then call them in call_scope. By default
    # make_template captures the enclosing variable scope the first time the
    # template is called; with create_scope_now_ (the third argument) set to
    # True it captures the scope at construction time instead.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertNotEqual(inner_imm_var, inner_defer_var)
    self.assertEqual(outer_imm_var, inner_imm_var)
    self.assertEqual(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes()
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertEqual(tc.variable_scope, None)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes()
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter.
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed.
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction.
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

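  # A call that raises (here, because of an unsupported keyword argument) must
  # not leave the template half-initialized; a later, valid call should still
  # create the variables.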
  @test_util.run_in_graph_and_eager_modes()
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      try:
        # Try to connect with a kwarg which is unsupported.
        templatized_function(data, is_training=True)
      except TypeError:
        pass

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes()
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEquals("foo", linear1.variable_scope.name)
    self.assertEquals("foo/w:0", w1.name)
    if context.in_graph_mode():
      self.assertEquals("foo/add:0", outputs_a.name,
                        "First application of template should get "
                        "same name scope as variables.")
      self.assertEquals("foo_1/add:0", outputs_b.name,
                        "Second application of template should get "
                        "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEquals("foo_1", linear2.variable_scope.name,
                      "New template gets a freshly uniquified variable scope "
                      "because 'foo' is already taken.")
    self.assertEquals("foo_1/w:0", w2.name)
    if context.in_graph_mode():
      self.assertEquals("foo_1_1/add:0", outputs_c.name,
                        "First application of template would get "
                        "same name scope as variables, but 'foo_1' is already "
                        "a name scope.")
      self.assertEquals("foo_1_2/add:0", outputs_d.name,
                        "Second application of template should also get "
                        "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes()
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates; the second one also creates an extra variable as
      # a side effect.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.in_eager_mode():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # ta created one variable; tb created two (one of them as a side effect).
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Each template created exactly one trainable variable.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  def test_local_variables(self):
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling, the variables exist.
    ta()
    tb()
    # Only tb creates a local variable.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes()
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables; the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertSequenceEqual(v1, v2)
    self.assertEqual("s1/test/dummy:0", v1[0].name)


if __name__ == "__main__":
  test.main()