# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.layers.base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


class BaseLayerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testLayerProperties(self):
    layer = base_layers.Layer(name='my_layer')
    self.assertEqual(layer.variables, [])
    self.assertEqual(layer.trainable_variables, [])
    self.assertEqual(layer.non_trainable_variables, [])
    if context.in_graph_mode():
      # updates, losses only supported in GRAPH mode
      self.assertEqual(layer.updates, [])
      self.assertEqual(layer.losses, [])
    self.assertEqual(layer.built, False)
    layer = base_layers.Layer(name='my_layer', trainable=False)
    self.assertEqual(layer.trainable, False)

  @test_util.run_in_graph_and_eager_modes()
  def testAddWeight(self):
    layer = base_layers.Layer(name='my_layer')

    # Test basic variable creation.
    variable = layer.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    self.assertEqual(variable.name, 'my_layer/my_var:0')
    self.assertEqual(layer.variables, [variable])
    self.assertEqual(layer.trainable_variables, [variable])
    self.assertEqual(layer.non_trainable_variables, [])
    if context.in_graph_mode():
      self.assertEqual(
          layer.variables,
          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # Test non-trainable variable creation.
    # layer.add_variable should work even outside `build` and `call`.
    variable_2 = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertEqual(layer.variables, [variable, variable_2])
    self.assertEqual(layer.trainable_variables, [variable])
    self.assertEqual(layer.non_trainable_variables, [variable_2])
    if context.in_graph_mode():
      self.assertEqual(
          len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

      # regularizers only supported in GRAPH mode.
      regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
      variable = layer.add_variable(
          'reg_var', [2, 2],
          initializer=init_ops.zeros_initializer(),
          regularizer=regularizer)
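      # The regularization loss created by add_variable should be collected
      # in layer.losses.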
      self.assertEqual(len(layer.losses), 1)

  def testNoEagerActivityRegularizer(self):
    with context.eager_mode():
      with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):
        core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.)

  def testGetVariable(self):
    with self.test_session():

      class MyLayer(base_layers.Layer):

        def build(self, input_shape):
          self.my_var = self.add_variable(
              'my_var', [2, 2], initializer=init_ops.zeros_initializer())

        def call(self, inputs):
          return inputs * 2

      layer = MyLayer(name='my_layer')
      inputs = random_ops.random_uniform((5,), seed=1)
      layer.apply(inputs)
      layer.apply(inputs)
      self.assertEqual([v.name for v in layer.variables],
                       ['my_layer/my_var:0'])

      # Creating a layer with no scope leads to lazy construction of
      # the scope at apply() time.  It uses scope "<current scope>/base_name"
      lazy_layer = MyLayer(_reuse=True)
      with variable_scope.variable_scope('new_scope'):
        with variable_scope.variable_scope('my_layer'):
          variable_scope.get_variable('my_var', [2, 2])

        # Smoke test: it runs.
        lazy_layer.apply(inputs)
        # The variables were created outside of the Layer, and
        # reuse=True, so the Layer does not own them and they are not
        # stored in its collection.
        self.assertEqual(lazy_layer.variables, [])
        self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')

      # Creating a layer with no scope leads to lazy construction of
      # the scope at apply() time. If 'scope' argument is passed to
      # apply(), it uses that scope when accessing variables.
      lazy_layer = MyLayer(_reuse=True)
      with variable_scope.variable_scope('new_scope') as new_scope:
        variable_scope.get_variable('my_var', [2, 2])

        # Smoke test: it runs.
        lazy_layer.apply(inputs, scope=new_scope)
        # The variables were created outside of the Layer, and
        # reuse=True, so the Layer does not own them and they are not
        # stored in its collection.
        self.assertEqual(lazy_layer.variables, [])
        self.assertEqual(lazy_layer._scope.name, 'new_scope')

      # Checking for graph equality is only done in GRAPH mode.
      with ops.Graph().as_default():
        inputs_ng = random_ops.random_uniform((5,), seed=1)
        with self.assertRaisesRegexp(ValueError, r'graph are not the same'):
          layer.apply(inputs_ng)

  @test_util.run_in_graph_and_eager_modes()
  def testCall(self):

    class MyLayer(base_layers.Layer):

      def call(self, inputs):
        return math_ops.square(inputs)

    layer = MyLayer(name='my_layer')
    inputs = random_ops.random_uniform((5,), seed=1)
    outputs = layer.apply(inputs)
    self.assertEqual(layer.built, True)
    if context.in_graph_mode():
      # op is only supported in GRAPH mode
      self.assertEqual(outputs.op.name, 'my_layer/Square')

  def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):
    # Note that this test only runs in graph mode: in eager mode, a new
    # variable can still be created on the second call.

    class MyLayer(base_layers.Layer):

      def build(self, _):
        # Do not mark the layer as built.
        pass

      def call(self, inputs):
        self.my_var = self.add_variable('my_var', [2, 2])
        if self.built:
          # Skip creating on the first call; try to create after it's
          # built.  This is expected to fail.
          self.add_variable('this_will_break_on_second_call', [2, 2])
        return inputs + math_ops.square(self.my_var)

    layer = MyLayer(name='my_layer')
    inputs = random_ops.random_uniform((2,), seed=1)
    outputs = layer.apply(inputs)
    self.assertEqual(layer.built, True)
    self.assertEqual(outputs.op.name, 'my_layer/add')
    self.assertEqual([v.name
                      for v in layer.variables], ['my_layer/my_var:0'])
    with self.assertRaisesRegexp(ValueError,
                                 'my_layer/this_will_break_on_second_call'):
      layer.apply(inputs)
    # The list of variables hasn't changed.
    self.assertEqual([v.name
                      for v in layer.variables], ['my_layer/my_var:0'])

  @test_util.run_in_graph_and_eager_modes()
  def testDeepCopy(self):

    class MyLayer(base_layers.Layer):

      def call(self, inputs):
        return math_ops.square(inputs)

    layer = MyLayer(name='my_layer')
    layer._private_tensor = random_ops.random_uniform(())
    inputs = random_ops.random_uniform((5,), seed=1)
    outputs = layer.apply(inputs)
    self.assertEqual(layer.built, True)
    if context.in_graph_mode():
      # op only supported in GRAPH mode.
      self.assertEqual(outputs.op.name, 'my_layer/Square')

    layer_copy = copy.deepcopy(layer)
    self.assertEqual(layer_copy.name, layer.name)
    self.assertEqual(layer_copy._scope.name, layer._scope.name)
    self.assertEqual(layer_copy._graph, layer._graph)
    self.assertEqual(layer_copy._private_tensor, layer._private_tensor)

  @test_util.run_in_graph_and_eager_modes()
  def testScopeNaming(self):

    class PrivateLayer(base_layers.Layer):

      def call(self, inputs):
        return inputs

    inputs = random_ops.random_uniform((5,))
    default_layer = PrivateLayer()
    _ = default_layer.apply(inputs)
    self.assertEqual(default_layer._scope.name, 'private_layer')
    default_layer1 = PrivateLayer()
    default_layer1.apply(inputs)
    self.assertEqual(default_layer1._scope.name, 'private_layer_1')
    my_layer = PrivateLayer(name='my_layer')
    my_layer.apply(inputs)
    self.assertEqual(my_layer._scope.name, 'my_layer')
    my_layer1 = PrivateLayer(name='my_layer')
    my_layer1.apply(inputs)
    self.assertEqual(my_layer1._scope.name, 'my_layer_1')
    my_layer2 = PrivateLayer(name='my_layer')
    my_layer2.apply(inputs)
    self.assertEqual(my_layer2._scope.name, 'my_layer_2')
    # Name scope shouldn't affect names.
    with ops.name_scope('some_name_scope'):
      default_layer2 = PrivateLayer()
      default_layer2.apply(inputs)
      self.assertEqual(default_layer2._scope.name, 'private_layer_2')
      my_layer3 = PrivateLayer(name='my_layer')
      my_layer3.apply(inputs)
      self.assertEqual(my_layer3._scope.name, 'my_layer_3')
      other_layer = PrivateLayer(name='other_layer')
      other_layer.apply(inputs)
      self.assertEqual(other_layer._scope.name, 'other_layer')
    # Variable scope gets added to scope names.
    with variable_scope.variable_scope('var_scope'):
      default_layer_scoped = PrivateLayer()
      default_layer_scoped.apply(inputs)
      self.assertEqual(default_layer_scoped._scope.name,
                       'var_scope/private_layer')
      my_layer_scoped = PrivateLayer(name='my_layer')
      my_layer_scoped.apply(inputs)
      self.assertEqual(my_layer_scoped._scope.name, 'var_scope/my_layer')
      my_layer_scoped1 = PrivateLayer(name='my_layer')
      my_layer_scoped1.apply(inputs)
      self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1')

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecNdimCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        layer.apply(array_ops.placeholder('int32'))

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected ndim=2'):
      layer.apply(constant_op.constant([1]))

    # Note that we re-create the layer because, in eager mode, input spec
    # checks only happen on the first call.
    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant([[1], [2]]))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecMinNdimCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(min_ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        layer.apply(array_ops.placeholder('int32'))

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected min_ndim=2'):
      layer.apply(constant_op.constant([1]))

    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant([[1], [2]]))

    layer = CustomerLayer()
    layer.apply(constant_op.constant([[[1], [2]]]))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecMaxNdimCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(max_ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        layer.apply(array_ops.placeholder('int32'))

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected max_ndim=2'):
      layer.apply(constant_op.constant([[[1], [2]]]))

    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant([1]))

    layer = CustomerLayer()
    layer.apply(constant_op.constant([[1], [2]]))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecDtypeCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(dtype='float32')

      def call(self, inputs):
        return inputs

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected dtype=float32'):
      layer.apply(constant_op.constant(1, dtype=dtypes.int32))

    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant(1.0, dtype=dtypes.float32))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecAxesCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(axes={-1: 2})

      def call(self, inputs):
        return inputs

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected axis'):
      layer.apply(constant_op.constant([1, 2, 3]))

    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant([1, 2]))
    layer = CustomerLayer()
    layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]]))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecShapeCheck(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(shape=(None, 3))

      def call(self, inputs):
        return inputs

    layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected shape'):
      layer.apply(constant_op.constant([[1, 2]]))

    # Works
    layer = CustomerLayer()
    layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]]))

  @test_util.run_in_graph_and_eager_modes()
  def testNoInputSpec(self):

    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = None

      def call(self, inputs):
        return inputs

    layer = CustomerLayer()

    layer.apply(constant_op.constant(1))

    # Works
    if context.in_graph_mode():
      layer.apply(array_ops.placeholder('int32'))
      layer.apply(array_ops.placeholder('int32', shape=(2, 3)))

  @test_util.run_in_graph_and_eager_modes()
  def test_count_params(self):
    dense = core_layers.Dense(16)
    dense.build((None, 4))
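    # Built on 4-dim inputs: a 4x16 kernel plus a 16-element bias.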
    self.assertEqual(dense.count_params(), 16 * 4 + 16)

    dense = core_layers.Dense(16)
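    # An unbuilt layer has no weights yet, so count_params should raise.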
    with self.assertRaises(ValueError):
      dense.count_params()

  @test_util.run_in_graph_and_eager_modes()
  def testDictInputOutput(self):

    class DictLayer(base_layers.Layer):

      def call(self, inputs):
        return {'l' + key: inputs[key] for key in inputs}

    layer = DictLayer()
    if context.in_graph_mode():
      i1 = array_ops.placeholder('int32')
      i2 = array_ops.placeholder('float32')
      result = layer.apply({'abel': i1, 'ogits': i2})
      self.assertTrue(isinstance(result, dict))
      self.assertEqual(set(['label', 'logits']), set(result.keys()))
    else:
      i1 = constant_op.constant(3)
      i2 = constant_op.constant(4.0)
      result = layer.apply({'abel': i1, 'ogits': i2})
      self.assertTrue(isinstance(result, dict))
      self.assertEqual(set(['label', 'logits']), set(result.keys()))
      self.assertEqual(3, result['label'].numpy())
      self.assertEqual(4.0, result['logits'].numpy())

  def testActivityRegularizer(self):
    regularizer = math_ops.reduce_sum
    layer = base_layers.Layer(activity_regularizer=regularizer)
    x = array_ops.placeholder('int32')
    layer.apply(x)
    self.assertEqual(len(layer.get_losses_for(x)), 1)

  def testNameScopeIsConsistentWithVariableScope(self):
    # GitHub issue 13429.
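    # Ops created in call() and variables created in build() should end up
    # under the same name prefix (e.g. 'my_layer/my_op' and 'my_layer/my_var').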

    class MyLayer(base_layers.Layer):

      def build(self, input_shape):
        self.my_var = self.add_variable('my_var', (), dtypes.float32)
        self.built = True

      def call(self, inputs):
        return math_ops.multiply(inputs, self.my_var, name='my_op')

    def _gen_layer(x, name=None):
      layer = MyLayer(name=name)
      out = layer.apply(x)
      return layer, out

    # unnamed layer
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, (), 'x')
      layer, op = _gen_layer(x)
      layer1, op1 = _gen_layer(op)
      layer2, op2 = _gen_layer(op1)

      self.assertEqual(layer.my_var.name, 'my_layer/my_var:0')
      self.assertEqual(op.name, 'my_layer/my_op:0')
      self.assertEqual(layer1.my_var.name, 'my_layer_1/my_var:0')
      self.assertEqual(op1.name, 'my_layer_1/my_op:0')
      self.assertEqual(layer2.my_var.name, 'my_layer_2/my_var:0')
      self.assertEqual(op2.name, 'my_layer_2/my_op:0')
    # Explicit names; numeric suffixes start from zero.
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, (), 'x')
      layer, op = _gen_layer(x, name='name')
      layer1, op1 = _gen_layer(op, name='name_1')
      layer2, op2 = _gen_layer(op1, name='name_2')

      self.assertEqual(layer.my_var.name, 'name/my_var:0')
      self.assertEqual(op.name, 'name/my_op:0')
      self.assertEqual(layer1.my_var.name, 'name_1/my_var:0')
      self.assertEqual(op1.name, 'name_1/my_op:0')
      self.assertEqual(layer2.my_var.name, 'name_2/my_var:0')
      self.assertEqual(op2.name, 'name_2/my_op:0')
    # Explicit names; numeric suffixes start from one.
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32, (), 'x')
      layer, op = _gen_layer(x, name='name_1')
      layer1, op1 = _gen_layer(op, name='name_2')
      layer2, op2 = _gen_layer(op1, name='name_3')

      self.assertEqual(layer.my_var.name, 'name_1/my_var:0')
      self.assertEqual(op.name, 'name_1/my_op:0')
      self.assertEqual(layer1.my_var.name, 'name_2/my_var:0')
      self.assertEqual(op1.name, 'name_2/my_op:0')
      self.assertEqual(layer2.my_var.name, 'name_3/my_var:0')
      self.assertEqual(op2.name, 'name_3/my_op:0')

  def testVariablesAreLiftedFromFunctionBuildingGraphs(self):
    class MyLayer(base_layers.Layer):

      def build(self, input_shape):
        self.my_var = self.add_variable('my_var', (), dtypes.float32)
        self.built = True

      def call(self, inputs):
        return inputs

    outer_graph = ops.get_default_graph()
    function_building_graph = ops.Graph()
    function_building_graph._building_function = True
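    # Marking the graph as function-building means that variables created
    # inside it are expected to be lifted into the enclosing (outer) graph.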
    with outer_graph.as_default():
      with function_building_graph.as_default():
        layer = MyLayer()
        # Create a variable by invoking build through __call__ and assert that
        # it is both tracked and lifted into the outer graph.
        inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
        layer.apply(inputs)
        self.assertEqual(len(layer.variables), 1)
        self.assertEqual(len(layer.trainable_variables), 1)
        self.assertEqual(layer.variables[0].graph, outer_graph)

  def testGetUpdateFor(self):

    class MyLayer(base_layers.Layer):

      def build(self, input_shape):
        self.a = self.add_variable('a',
                                   (),
                                   dtypes.float32,
                                   trainable=False)
        self.b = self.add_variable('b',
                                   (),
                                   dtypes.float32,
                                   trainable=False)
        self.add_update(state_ops.assign_add(self.a, 1., name='b_update'))
        self.built = True

      def call(self, inputs):
        self.add_update(state_ops.assign_add(self.a, inputs, name='a_update'),
                        inputs=True)
        return inputs + 1

    layer = MyLayer()
    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
    intermediate_inputs = inputs + 1
    outputs = layer.apply(intermediate_inputs)

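    # The update added in build() is unconditional and should be returned by
    # get_updates_for(None); the update added in call() with inputs=True is
    # conditional on the inputs of this apply() call.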
    self.assertEqual(len(layer.updates), 2)
    self.assertEqual(len(layer.get_updates_for(None)), 1)
    self.assertEqual(len(layer.get_updates_for([inputs])), 1)
    self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
    self.assertEqual(len(layer.get_updates_for([outputs])), 0)

    # Call same layer on new input, creating one more conditional update
    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
    intermediate_inputs = inputs + 1
    outputs = layer.apply(intermediate_inputs)

    self.assertEqual(len(layer.updates), 3)
    self.assertEqual(len(layer.get_updates_for(None)), 1)
    # Check that we are successfully filtering out irrelevant updates
    self.assertEqual(len(layer.get_updates_for([inputs])), 1)
    self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1)
    self.assertEqual(len(layer.get_updates_for([outputs])), 0)

  def testGetLossesFor(self):

    class MyLayer(base_layers.Layer):

      def build(self, input_shape):
        self.a = self.add_variable('a',
                                   (),
                                   dtypes.float32,
                                   trainable=False)
        self.b = self.add_variable('b',
                                   (),
                                   dtypes.float32,
                                   trainable=False)
        self.add_loss(self.a)
        self.built = True

      def call(self, inputs):
        self.add_loss(inputs, inputs=True)
        return inputs + 1

    layer = MyLayer()
    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
    intermediate_inputs = inputs + 1
    outputs = layer.apply(intermediate_inputs)

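    # As with updates, the loss added in build() is unconditional while the
    # loss added in call() with inputs=True is conditional on this call's
    # inputs.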
    self.assertEqual(len(layer.losses), 2)
    self.assertEqual(len(layer.get_losses_for(None)), 1)
    self.assertEqual(len(layer.get_losses_for([inputs])), 1)
    self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
    self.assertEqual(len(layer.get_losses_for([outputs])), 0)

    # Call same layer on new input, creating one more conditional loss
    inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
    intermediate_inputs = inputs + 1
    outputs = layer.apply(intermediate_inputs)

    self.assertEqual(len(layer.losses), 3)
    self.assertEqual(len(layer.get_losses_for(None)), 1)
    # Check that we are successfully filtering out irrelevant losses
    self.assertEqual(len(layer.get_losses_for([inputs])), 1)
    self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1)
    self.assertEqual(len(layer.get_losses_for([outputs])), 0)


if __name__ == '__main__':
  test.main()