1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.layers.base.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import copy 22 23 from tensorflow.python.eager import context 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.layers import base as base_layers 29 from tensorflow.python.layers import core as core_layers 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import init_ops 32 from tensorflow.python.ops import math_ops 33 from tensorflow.python.ops import random_ops 34 from tensorflow.python.ops import state_ops 35 from tensorflow.python.ops import variable_scope 36 from tensorflow.python.platform import test 37 38 39 class BaseLayerTest(test.TestCase): 40 41 @test_util.run_in_graph_and_eager_modes() 42 def testLayerProperties(self): 43 layer = base_layers.Layer(name='my_layer') 44 self.assertEqual(layer.variables, []) 45 self.assertEqual(layer.trainable_variables, []) 46 self.assertEqual(layer.non_trainable_variables, []) 47 if context.in_graph_mode(): 48 # updates, losses only supported in GRAPH mode 49 self.assertEqual(layer.updates, []) 50 self.assertEqual(layer.losses, []) 51 self.assertEqual(layer.built, False) 52 layer = base_layers.Layer(name='my_layer', trainable=False) 53 self.assertEqual(layer.trainable, False) 54 55 @test_util.run_in_graph_and_eager_modes() 56 def testAddWeight(self): 57 layer = base_layers.Layer(name='my_layer') 58 59 # Test basic variable creation. 60 variable = layer.add_variable( 61 'my_var', [2, 2], initializer=init_ops.zeros_initializer()) 62 self.assertEqual(variable.name, 'my_layer/my_var:0') 63 self.assertEqual(layer.variables, [variable]) 64 self.assertEqual(layer.trainable_variables, [variable]) 65 self.assertEqual(layer.non_trainable_variables, []) 66 if context.in_graph_mode(): 67 self.assertEqual( 68 layer.variables, 69 ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) 70 71 # Test non-trainable variable creation. 72 # layer.add_variable should work even outside `build` and `call`. 73 variable_2 = layer.add_variable( 74 'non_trainable_var', [2, 2], 75 initializer=init_ops.zeros_initializer(), 76 trainable=False) 77 self.assertEqual(layer.variables, [variable, variable_2]) 78 self.assertEqual(layer.trainable_variables, [variable]) 79 self.assertEqual(layer.non_trainable_variables, [variable_2]) 80 if context.in_graph_mode(): 81 self.assertEqual( 82 len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1) 83 84 # regularizers only supported in GRAPH mode. 85 regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 86 variable = layer.add_variable( 87 'reg_var', [2, 2], 88 initializer=init_ops.zeros_initializer(), 89 regularizer=regularizer) 90 self.assertEqual(len(layer.losses), 1) 91 92 def testNoEagerActivityRegularizer(self): 93 with context.eager_mode(): 94 with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): 95 core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.) 96 97 def testGetVariable(self): 98 with self.test_session(): 99 100 class MyLayer(base_layers.Layer): 101 102 def build(self, input_shape): 103 self.my_var = self.add_variable( 104 'my_var', [2, 2], initializer=init_ops.zeros_initializer()) 105 106 def call(self, inputs): 107 return inputs * 2 108 109 layer = MyLayer(name='my_layer') 110 inputs = random_ops.random_uniform((5,), seed=1) 111 layer.apply(inputs) 112 layer.apply(inputs) 113 self.assertEqual([v.name for v in layer.variables], 114 ['my_layer/my_var:0']) 115 116 # Creating a layer with no scope leads to lazy construction of 117 # the scope at apply() time. It uses scope "<current scope>/base_name" 118 lazy_layer = MyLayer(_reuse=True) 119 with variable_scope.variable_scope('new_scope'): 120 with variable_scope.variable_scope('my_layer'): 121 variable_scope.get_variable('my_var', [2, 2]) 122 123 # Smoke test: it runs. 124 lazy_layer.apply(inputs) 125 # The variables were created outside of the Layer, and 126 # reuse=True, so the Layer does not own them and they are not 127 # stored in its collection. 128 self.assertEqual(lazy_layer.variables, []) 129 self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer') 130 131 # Creating a layer with no scope leads to lazy construction of 132 # the scope at apply() time. If 'scope' argument is passed to 133 # apply(), it uses that scope when accessing variables. 134 lazy_layer = MyLayer(_reuse=True) 135 with variable_scope.variable_scope('new_scope') as new_scope: 136 variable_scope.get_variable('my_var', [2, 2]) 137 138 # Smoke test: it runs. 139 lazy_layer.apply(inputs, scope=new_scope) 140 # The variables were created outside of the Layer, and 141 # reuse=True, so the Layer does not own them and they are not 142 # stored in its collection. 143 self.assertEqual(lazy_layer.variables, []) 144 self.assertEqual(lazy_layer._scope.name, 'new_scope') 145 146 # Checking for graph equality is only done in GRAPH mode. 147 with ops.Graph().as_default(): 148 inputs_ng = random_ops.random_uniform((5,), seed=1) 149 with self.assertRaisesRegexp(ValueError, r'graph are not the same'): 150 layer.apply(inputs_ng) 151 152 @test_util.run_in_graph_and_eager_modes() 153 def testCall(self): 154 155 class MyLayer(base_layers.Layer): 156 157 def call(self, inputs): 158 return math_ops.square(inputs) 159 160 layer = MyLayer(name='my_layer') 161 inputs = random_ops.random_uniform((5,), seed=1) 162 outputs = layer.apply(inputs) 163 self.assertEqual(layer.built, True) 164 if context.in_graph_mode(): 165 # op is only supported in GRAPH mode 166 self.assertEqual(outputs.op.name, 'my_layer/Square') 167 168 def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self): 169 # Note that this test is only run in Graph mode since with EAGER mode we can 170 # still create a new variable on second call. 171 172 class MyLayer(base_layers.Layer): 173 174 def build(self, _): 175 # Do not mark the layer as built. 176 pass 177 178 def call(self, inputs): 179 self.my_var = self.add_variable('my_var', [2, 2]) 180 if self.built: 181 # Skip creating on the first call; try to create after it's 182 # built. This is expected to fail. 183 self.add_variable('this_will_break_on_second_call', [2, 2]) 184 return inputs + math_ops.square(self.my_var) 185 186 layer = MyLayer(name='my_layer') 187 inputs = random_ops.random_uniform((2,), seed=1) 188 outputs = layer.apply(inputs) 189 self.assertEqual(layer.built, True) 190 self.assertEqual(outputs.op.name, 'my_layer/add') 191 self.assertEqual([v.name 192 for v in layer.variables], ['my_layer/my_var:0']) 193 with self.assertRaisesRegexp(ValueError, 194 'my_layer/this_will_break_on_second_call'): 195 layer.apply(inputs) 196 # The list of variables hasn't changed. 197 self.assertEqual([v.name 198 for v in layer.variables], ['my_layer/my_var:0']) 199 200 @test_util.run_in_graph_and_eager_modes() 201 def testDeepCopy(self): 202 203 class MyLayer(base_layers.Layer): 204 205 def call(self, inputs): 206 return math_ops.square(inputs) 207 208 layer = MyLayer(name='my_layer') 209 layer._private_tensor = random_ops.random_uniform(()) 210 inputs = random_ops.random_uniform((5,), seed=1) 211 outputs = layer.apply(inputs) 212 self.assertEqual(layer.built, True) 213 if context.in_graph_mode(): 214 # op only supported in GRAPH mode. 215 self.assertEqual(outputs.op.name, 'my_layer/Square') 216 217 layer_copy = copy.deepcopy(layer) 218 self.assertEqual(layer_copy.name, layer.name) 219 self.assertEqual(layer_copy._scope.name, layer._scope.name) 220 self.assertEqual(layer_copy._graph, layer._graph) 221 self.assertEqual(layer_copy._private_tensor, layer._private_tensor) 222 223 @test_util.run_in_graph_and_eager_modes() 224 def testScopeNaming(self): 225 226 class PrivateLayer(base_layers.Layer): 227 228 def call(self, inputs): 229 return inputs 230 231 inputs = random_ops.random_uniform((5,)) 232 default_layer = PrivateLayer() 233 _ = default_layer.apply(inputs) 234 self.assertEqual(default_layer._scope.name, 'private_layer') 235 default_layer1 = PrivateLayer() 236 default_layer1.apply(inputs) 237 self.assertEqual(default_layer1._scope.name, 'private_layer_1') 238 my_layer = PrivateLayer(name='my_layer') 239 my_layer.apply(inputs) 240 self.assertEqual(my_layer._scope.name, 'my_layer') 241 my_layer1 = PrivateLayer(name='my_layer') 242 my_layer1.apply(inputs) 243 self.assertEqual(my_layer1._scope.name, 'my_layer_1') 244 my_layer2 = PrivateLayer(name='my_layer') 245 my_layer2.apply(inputs) 246 self.assertEqual(my_layer2._scope.name, 'my_layer_2') 247 # Name scope shouldn't affect names. 248 with ops.name_scope('some_name_scope'): 249 default_layer2 = PrivateLayer() 250 default_layer2.apply(inputs) 251 self.assertEqual(default_layer2._scope.name, 'private_layer_2') 252 my_layer3 = PrivateLayer(name='my_layer') 253 my_layer3.apply(inputs) 254 self.assertEqual(my_layer3._scope.name, 'my_layer_3') 255 other_layer = PrivateLayer(name='other_layer') 256 other_layer.apply(inputs) 257 self.assertEqual(other_layer._scope.name, 'other_layer') 258 # Variable scope gets added to scope names. 259 with variable_scope.variable_scope('var_scope'): 260 default_layer_scoped = PrivateLayer() 261 default_layer_scoped.apply(inputs) 262 self.assertEqual(default_layer_scoped._scope.name, 263 'var_scope/private_layer') 264 my_layer_scoped = PrivateLayer(name='my_layer') 265 my_layer_scoped.apply(inputs) 266 self.assertEqual(my_layer_scoped._scope.name, 'var_scope/my_layer') 267 my_layer_scoped1 = PrivateLayer(name='my_layer') 268 my_layer_scoped1.apply(inputs) 269 self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1') 270 271 @test_util.run_in_graph_and_eager_modes() 272 def testInputSpecNdimCheck(self): 273 274 class CustomerLayer(base_layers.Layer): 275 276 def __init__(self): 277 super(CustomerLayer, self).__init__() 278 self.input_spec = base_layers.InputSpec(ndim=2) 279 280 def call(self, inputs): 281 return inputs 282 283 if context.in_graph_mode(): 284 layer = CustomerLayer() 285 with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): 286 layer.apply(array_ops.placeholder('int32')) 287 288 layer = CustomerLayer() 289 with self.assertRaisesRegexp(ValueError, r'expected ndim=2'): 290 layer.apply(constant_op.constant([1])) 291 292 # Note that we re-create the layer since in Eager mode, input spec checks 293 # only happen on first call. 294 # Works 295 layer = CustomerLayer() 296 layer.apply(constant_op.constant([[1], [2]])) 297 298 @test_util.run_in_graph_and_eager_modes() 299 def testInputSpecMinNdimCheck(self): 300 301 class CustomerLayer(base_layers.Layer): 302 303 def __init__(self): 304 super(CustomerLayer, self).__init__() 305 self.input_spec = base_layers.InputSpec(min_ndim=2) 306 307 def call(self, inputs): 308 return inputs 309 310 if context.in_graph_mode(): 311 layer = CustomerLayer() 312 with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): 313 layer.apply(array_ops.placeholder('int32')) 314 315 layer = CustomerLayer() 316 with self.assertRaisesRegexp(ValueError, r'expected min_ndim=2'): 317 layer.apply(constant_op.constant([1])) 318 319 # Works 320 layer = CustomerLayer() 321 layer.apply(constant_op.constant([[1], [2]])) 322 323 layer = CustomerLayer() 324 layer.apply(constant_op.constant([[[1], [2]]])) 325 326 @test_util.run_in_graph_and_eager_modes() 327 def testInputSpecMaxNdimCheck(self): 328 329 class CustomerLayer(base_layers.Layer): 330 331 def __init__(self): 332 super(CustomerLayer, self).__init__() 333 self.input_spec = base_layers.InputSpec(max_ndim=2) 334 335 def call(self, inputs): 336 return inputs 337 338 if context.in_graph_mode(): 339 layer = CustomerLayer() 340 with self.assertRaisesRegexp(ValueError, r'requires a defined rank'): 341 layer.apply(array_ops.placeholder('int32')) 342 343 layer = CustomerLayer() 344 with self.assertRaisesRegexp(ValueError, r'expected max_ndim=2'): 345 layer.apply(constant_op.constant([[[1], [2]]])) 346 347 # Works 348 layer = CustomerLayer() 349 layer.apply(constant_op.constant([1])) 350 351 layer = CustomerLayer() 352 layer.apply(constant_op.constant([[1], [2]])) 353 354 @test_util.run_in_graph_and_eager_modes() 355 def testInputSpecDtypeCheck(self): 356 357 class CustomerLayer(base_layers.Layer): 358 359 def __init__(self): 360 super(CustomerLayer, self).__init__() 361 self.input_spec = base_layers.InputSpec(dtype='float32') 362 363 def call(self, inputs): 364 return inputs 365 366 layer = CustomerLayer() 367 with self.assertRaisesRegexp(ValueError, r'expected dtype=float32'): 368 layer.apply(constant_op.constant(1, dtype=dtypes.int32)) 369 370 # Works 371 layer = CustomerLayer() 372 layer.apply(constant_op.constant(1.0, dtype=dtypes.float32)) 373 374 @test_util.run_in_graph_and_eager_modes() 375 def testInputSpecAxesCheck(self): 376 377 class CustomerLayer(base_layers.Layer): 378 379 def __init__(self): 380 super(CustomerLayer, self).__init__() 381 self.input_spec = base_layers.InputSpec(axes={-1: 2}) 382 383 def call(self, inputs): 384 return inputs 385 386 layer = CustomerLayer() 387 with self.assertRaisesRegexp(ValueError, r'expected axis'): 388 layer.apply(constant_op.constant([1, 2, 3])) 389 390 # Works 391 layer = CustomerLayer() 392 layer.apply(constant_op.constant([1, 2])) 393 layer = CustomerLayer() 394 layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]])) 395 396 @test_util.run_in_graph_and_eager_modes() 397 def testInputSpecShapeCheck(self): 398 399 class CustomerLayer(base_layers.Layer): 400 401 def __init__(self): 402 super(CustomerLayer, self).__init__() 403 self.input_spec = base_layers.InputSpec(shape=(None, 3)) 404 405 def call(self, inputs): 406 return inputs 407 408 layer = CustomerLayer() 409 with self.assertRaisesRegexp(ValueError, r'expected shape'): 410 layer.apply(constant_op.constant([[1, 2]])) 411 412 # Works 413 layer = CustomerLayer() 414 layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]])) 415 416 @test_util.run_in_graph_and_eager_modes() 417 def testNoInputSpec(self): 418 419 class CustomerLayer(base_layers.Layer): 420 421 def __init__(self): 422 super(CustomerLayer, self).__init__() 423 self.input_spec = None 424 425 def call(self, inputs): 426 return inputs 427 428 layer = CustomerLayer() 429 430 layer.apply(constant_op.constant(1)) 431 432 # Works 433 if context.in_graph_mode(): 434 layer.apply(array_ops.placeholder('int32')) 435 layer.apply(array_ops.placeholder('int32', shape=(2, 3))) 436 437 @test_util.run_in_graph_and_eager_modes() 438 def test_count_params(self): 439 dense = core_layers.Dense(16) 440 dense.build((None, 4)) 441 self.assertEqual(dense.count_params(), 16 * 4 + 16) 442 443 dense = core_layers.Dense(16) 444 with self.assertRaises(ValueError): 445 dense.count_params() 446 447 @test_util.run_in_graph_and_eager_modes() 448 def testDictInputOutput(self): 449 450 class DictLayer(base_layers.Layer): 451 452 def call(self, inputs): 453 return {'l' + key: inputs[key] for key in inputs} 454 455 layer = DictLayer() 456 if context.in_graph_mode(): 457 i1 = array_ops.placeholder('int32') 458 i2 = array_ops.placeholder('float32') 459 result = layer.apply({'abel': i1, 'ogits': i2}) 460 self.assertTrue(isinstance(result, dict)) 461 self.assertEqual(set(['label', 'logits']), set(result.keys())) 462 else: 463 i1 = constant_op.constant(3) 464 i2 = constant_op.constant(4.0) 465 result = layer.apply({'abel': i1, 'ogits': i2}) 466 self.assertTrue(isinstance(result, dict)) 467 self.assertEqual(set(['label', 'logits']), set(result.keys())) 468 self.assertEqual(3, result['label'].numpy()) 469 self.assertEqual(4.0, result['logits'].numpy()) 470 471 def testActivityRegularizer(self): 472 regularizer = math_ops.reduce_sum 473 layer = base_layers.Layer(activity_regularizer=regularizer) 474 x = array_ops.placeholder('int32') 475 layer.apply(x) 476 self.assertEqual(len(layer.get_losses_for(x)), 1) 477 478 def testNameScopeIsConsistentWithVariableScope(self): 479 # Github issue 13429. 480 481 class MyLayer(base_layers.Layer): 482 483 def build(self, input_shape): 484 self.my_var = self.add_variable('my_var', (), dtypes.float32) 485 self.built = True 486 487 def call(self, inputs): 488 return math_ops.multiply(inputs, self.my_var, name='my_op') 489 490 def _gen_layer(x, name=None): 491 layer = MyLayer(name=name) 492 out = layer.apply(x) 493 return layer, out 494 495 # unnamed layer 496 with ops.Graph().as_default(): 497 x = array_ops.placeholder(dtypes.float32, (), 'x') 498 layer, op = _gen_layer(x) 499 layer1, op1 = _gen_layer(op) 500 layer2, op2 = _gen_layer(op1) 501 502 self.assertEqual(layer.my_var.name, 'my_layer/my_var:0') 503 self.assertEqual(op.name, 'my_layer/my_op:0') 504 self.assertEqual(layer1.my_var.name, 'my_layer_1/my_var:0') 505 self.assertEqual(op1.name, 'my_layer_1/my_op:0') 506 self.assertEqual(layer2.my_var.name, 'my_layer_2/my_var:0') 507 self.assertEqual(op2.name, 'my_layer_2/my_op:0') 508 # name starts from zero 509 with ops.Graph().as_default(): 510 x = array_ops.placeholder(dtypes.float32, (), 'x') 511 layer, op = _gen_layer(x, name='name') 512 layer1, op1 = _gen_layer(op, name='name_1') 513 layer2, op2 = _gen_layer(op1, name='name_2') 514 515 self.assertEqual(layer.my_var.name, 'name/my_var:0') 516 self.assertEqual(op.name, 'name/my_op:0') 517 self.assertEqual(layer1.my_var.name, 'name_1/my_var:0') 518 self.assertEqual(op1.name, 'name_1/my_op:0') 519 self.assertEqual(layer2.my_var.name, 'name_2/my_var:0') 520 self.assertEqual(op2.name, 'name_2/my_op:0') 521 # name starts from one 522 with ops.Graph().as_default(): 523 x = array_ops.placeholder(dtypes.float32, (), 'x') 524 layer, op = _gen_layer(x, name='name_1') 525 layer1, op1 = _gen_layer(op, name='name_2') 526 layer2, op2 = _gen_layer(op1, name='name_3') 527 528 self.assertEqual(layer.my_var.name, 'name_1/my_var:0') 529 self.assertEqual(op.name, 'name_1/my_op:0') 530 self.assertEqual(layer1.my_var.name, 'name_2/my_var:0') 531 self.assertEqual(op1.name, 'name_2/my_op:0') 532 self.assertEqual(layer2.my_var.name, 'name_3/my_var:0') 533 self.assertEqual(op2.name, 'name_3/my_op:0') 534 535 def testVariablesAreLiftedFromFunctionBuildingGraphs(self): 536 class MyLayer(base_layers.Layer): 537 538 def build(self, input_shape): 539 self.my_var = self.add_variable('my_var', (), dtypes.float32) 540 self.built = True 541 542 def call(self, inputs): 543 return inputs 544 545 outer_graph = ops.get_default_graph() 546 function_building_graph = ops.Graph() 547 function_building_graph._building_function = True 548 with outer_graph.as_default(): 549 with function_building_graph.as_default(): 550 layer = MyLayer() 551 # Create a variable by invoking build through __call__ and assert that 552 # it is both tracked and lifted into the outer graph. 553 inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') 554 layer.apply(inputs) 555 self.assertEqual(len(layer.variables), 1) 556 self.assertEqual(len(layer.trainable_variables), 1) 557 self.assertEqual(layer.variables[0].graph, outer_graph) 558 559 def testGetUpdateFor(self): 560 561 class MyLayer(base_layers.Layer): 562 563 def build(self, input_shape): 564 self.a = self.add_variable('a', 565 (), 566 dtypes.float32, 567 trainable=False) 568 self.b = self.add_variable('b', 569 (), 570 dtypes.float32, 571 trainable=False) 572 self.add_update(state_ops.assign_add(self.a, 1., name='b_update')) 573 self.built = True 574 575 def call(self, inputs): 576 self.add_update(state_ops.assign_add(self.a, inputs, name='a_update'), 577 inputs=True) 578 return inputs + 1 579 580 layer = MyLayer() 581 inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') 582 intermediate_inputs = inputs + 1 583 outputs = layer.apply(intermediate_inputs) 584 585 self.assertEqual(len(layer.updates), 2) 586 self.assertEqual(len(layer.get_updates_for(None)), 1) 587 self.assertEqual(len(layer.get_updates_for([inputs])), 1) 588 self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1) 589 self.assertEqual(len(layer.get_updates_for([outputs])), 0) 590 591 # Call same layer on new input, creating one more conditional update 592 inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') 593 intermediate_inputs = inputs + 1 594 outputs = layer.apply(intermediate_inputs) 595 596 self.assertEqual(len(layer.updates), 3) 597 self.assertEqual(len(layer.get_updates_for(None)), 1) 598 # Check that we are successfully filtering out irrelevant updates 599 self.assertEqual(len(layer.get_updates_for([inputs])), 1) 600 self.assertEqual(len(layer.get_updates_for([intermediate_inputs])), 1) 601 self.assertEqual(len(layer.get_updates_for([outputs])), 0) 602 603 def testGetLossesFor(self): 604 605 class MyLayer(base_layers.Layer): 606 607 def build(self, input_shape): 608 self.a = self.add_variable('a', 609 (), 610 dtypes.float32, 611 trainable=False) 612 self.b = self.add_variable('b', 613 (), 614 dtypes.float32, 615 trainable=False) 616 self.add_loss(self.a) 617 self.built = True 618 619 def call(self, inputs): 620 self.add_loss(inputs, inputs=True) 621 return inputs + 1 622 623 layer = MyLayer() 624 inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') 625 intermediate_inputs = inputs + 1 626 outputs = layer.apply(intermediate_inputs) 627 628 self.assertEqual(len(layer.losses), 2) 629 self.assertEqual(len(layer.get_losses_for(None)), 1) 630 self.assertEqual(len(layer.get_losses_for([inputs])), 1) 631 self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1) 632 self.assertEqual(len(layer.get_losses_for([outputs])), 0) 633 634 # Call same layer on new input, creating one more conditional loss 635 inputs = array_ops.placeholder(dtypes.float32, (), 'inputs') 636 intermediate_inputs = inputs + 1 637 outputs = layer.apply(intermediate_inputs) 638 639 self.assertEqual(len(layer.losses), 3) 640 self.assertEqual(len(layer.get_losses_for(None)), 1) 641 # Check that we are successfully filtering out irrelevant losses 642 self.assertEqual(len(layer.get_losses_for([inputs])), 1) 643 self.assertEqual(len(layer.get_losses_for([intermediate_inputs])), 1) 644 self.assertEqual(len(layer.get_losses_for([outputs])), 0) 645 646 647 if __name__ == '__main__': 648 test.main() 649