# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util


@test_util.run_v1_only('b/120545219')
class GlobalStepTest(test.TestCase):
  """Tests for creating, looking up and validating the global step."""

  def _assert_global_step(self, global_step, expected_dtype=dtypes.int64):
    """Asserts that `global_step` has the expected name, dtype and shape.

    Args:
      global_step: The global step object under test.
      expected_dtype: The base dtype `global_step` is required to have.
    """
    self.assertEqual('%s:0' % ops.GraphKeys.GLOBAL_STEP, global_step.name)
    self.assertEqual(expected_dtype, global_step.dtype.base_dtype)
    # The global step must be a scalar (empty shape).
    self.assertEqual([], global_step.get_shape().as_list())

  def test_invalid_dtype(self):
    """A float variable named like the global step is rejected."""
    with ops.Graph().as_default() as g:
      self.assertIsNone(training_util.get_global_step())
      # VariableV1, like the other tests in this class, so every fake global
      # step is created through the same v1 constructor.
      variables.VariableV1(
          0.0,
          trainable=False,
          dtype=dtypes.float32,
          name=ops.GraphKeys.GLOBAL_STEP)
      # Lookup must fail both via the default graph and an explicit graph.
      with self.assertRaisesRegex(TypeError, 'does not have integer type'):
        training_util.get_global_step()
      with self.assertRaisesRegex(TypeError, 'does not have integer type'):
        training_util.get_global_step(g)

  def test_invalid_shape(self):
    """A non-scalar variable named like the global step is rejected."""
    with ops.Graph().as_default() as g:
      self.assertIsNone(training_util.get_global_step())
      variables.VariableV1(
          [0],
          trainable=False,
          dtype=dtypes.int32,
          name=ops.GraphKeys.GLOBAL_STEP)
      with self.assertRaisesRegex(TypeError, 'not scalar'):
        training_util.get_global_step()
      with self.assertRaisesRegex(TypeError, 'not scalar'):
        training_util.get_global_step(g)

  def test_create_global_step(self):
    """create_global_step creates one global step per graph."""
    with ops.Graph().as_default() as g:
      # Checked on the fresh graph so the result does not depend on whatever
      # the ambient default graph contains.
      self.assertIsNone(training_util.get_global_step())
      global_step = training_util.create_global_step()
      self._assert_global_step(global_step)
      # A second creation in the same graph is an error...
      with self.assertRaisesRegex(ValueError, 'already exists'):
        training_util.create_global_step()
      with self.assertRaisesRegex(ValueError, 'already exists'):
        training_util.create_global_step(g)
      # ...but a different graph can get its own global step.
      self._assert_global_step(training_util.create_global_step(ops.Graph()))

  def test_get_global_step(self):
    """An int32 scalar variable is accepted as the global step."""
    with ops.Graph().as_default() as g:
      self.assertIsNone(training_util.get_global_step())
      variables.VariableV1(
          0,
          trainable=False,
          dtype=dtypes.int32,
          name=ops.GraphKeys.GLOBAL_STEP)
      self._assert_global_step(
          training_util.get_global_step(), expected_dtype=dtypes.int32)
      self._assert_global_step(
          training_util.get_global_step(g), expected_dtype=dtypes.int32)

  def test_get_or_create_global_step(self):
    """get_or_create_global_step creates on first use, then reuses."""
    with ops.Graph().as_default() as g:
      self.assertIsNone(training_util.get_global_step())
      self._assert_global_step(training_util.get_or_create_global_step())
      self._assert_global_step(training_util.get_or_create_global_step(g))


@test_util.run_v1_only('b/120545219')
class GlobalStepReadTest(test.TestCase):
  """Tests for the cached global-step read tensor."""

  def test_global_step_read_is_none_if_there_is_no_global_step(self):
    """No read tensor is produced until a global step exists."""
    with ops.Graph().as_default():
      self.assertIsNone(training_util._get_or_create_global_step_read())
      training_util.create_global_step()
      self.assertIsNotNone(training_util._get_or_create_global_step_read())

  def test_reads_from_cache(self):
    """Repeated calls return the very same cached read tensor."""
    with ops.Graph().as_default():
      training_util.create_global_step()
      first = training_util._get_or_create_global_step_read()
      second = training_util._get_or_create_global_step_read()
      # Identity, not value equality: the point is that the tensor is reused.
      self.assertIs(first, second)

  def test_reads_before_increments(self):
    """Within one run, the read tensor sees the pre-increment value."""
    with ops.Graph().as_default():
      training_util.create_global_step()
      read_tensor = training_util._get_or_create_global_step_read()
      inc_op = training_util._increment_global_step(1)
      inc_three_op = training_util._increment_global_step(3)
      with monitored_session.MonitoredTrainingSession() as sess:
        # 0 is read, then the step becomes 1.
        read_value, _ = sess.run([read_tensor, inc_op])
        self.assertEqual(0, read_value)
        # 1 is read, then the step becomes 4.
        read_value, _ = sess.run([read_tensor, inc_three_op])
        self.assertEqual(1, read_value)
        read_value = sess.run(read_tensor)
        self.assertEqual(4, read_value)


if __name__ == '__main__':
  test.main()