# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for losses util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test


class LossesUtilTest(test.TestCase):
  """Tests for `util.get_regularization_loss`."""

  def testGetRegularizationLoss(self):
    """Checks summation and scope filtering of REGULARIZATION_LOSSES."""
    # Empty regularization collection should evaluate to 0.0.
    with self.test_session():
      self.assertEqual(0.0, util.get_regularization_loss().eval())

    # Losses in the collection should be summed.
    ops.add_to_collection(
        ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
    ops.add_to_collection(
        ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with self.test_session():
      self.assertEqual(5.0, util.get_regularization_loss().eval())

    # Check scope capture mechanism.
    with ops.name_scope('scope1'):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
    with self.test_session():
      # Filtering by 'scope1' must pick up only the loss created there.
      self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
      # Without a scope, every collected loss is summed: 2.0 + 3.0 - 1.0.
      self.assertEqual(4.0, util.get_regularization_loss().eval())
      # A scope matching no losses should behave like the empty case above.
      self.assertEqual(0.0, util.get_regularization_loss('scope2').eval())


if __name__ == '__main__':
  test.main()