Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for training_util."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import dtypes
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.framework import test_util
     24 from tensorflow.python.ops import variables
     25 from tensorflow.python.platform import test
     26 from tensorflow.python.training import monitored_session
     27 from tensorflow.python.training import training_util
     28 
     29 
     30 @test_util.run_v1_only('b/120545219')
     31 class GlobalStepTest(test.TestCase):
     32 
     33   def _assert_global_step(self, global_step, expected_dtype=dtypes.int64):
     34     self.assertEqual('%s:0' % ops.GraphKeys.GLOBAL_STEP, global_step.name)
     35     self.assertEqual(expected_dtype, global_step.dtype.base_dtype)
     36     self.assertEqual([], global_step.get_shape().as_list())
     37 
     38   def test_invalid_dtype(self):
     39     with ops.Graph().as_default() as g:
     40       self.assertIsNone(training_util.get_global_step())
     41       variables.Variable(
     42           0.0,
     43           trainable=False,
     44           dtype=dtypes.float32,
     45           name=ops.GraphKeys.GLOBAL_STEP)
     46       self.assertRaisesRegexp(TypeError, 'does not have integer type',
     47                               training_util.get_global_step)
     48     self.assertRaisesRegexp(TypeError, 'does not have integer type',
     49                             training_util.get_global_step, g)
     50 
     51   def test_invalid_shape(self):
     52     with ops.Graph().as_default() as g:
     53       self.assertIsNone(training_util.get_global_step())
     54       variables.VariableV1(
     55           [0],
     56           trainable=False,
     57           dtype=dtypes.int32,
     58           name=ops.GraphKeys.GLOBAL_STEP)
     59       self.assertRaisesRegexp(TypeError, 'not scalar',
     60                               training_util.get_global_step)
     61     self.assertRaisesRegexp(TypeError, 'not scalar',
     62                             training_util.get_global_step, g)
     63 
     64   def test_create_global_step(self):
     65     self.assertIsNone(training_util.get_global_step())
     66     with ops.Graph().as_default() as g:
     67       global_step = training_util.create_global_step()
     68       self._assert_global_step(global_step)
     69       self.assertRaisesRegexp(ValueError, 'already exists',
     70                               training_util.create_global_step)
     71       self.assertRaisesRegexp(ValueError, 'already exists',
     72                               training_util.create_global_step, g)
     73       self._assert_global_step(training_util.create_global_step(ops.Graph()))
     74 
     75   def test_get_global_step(self):
     76     with ops.Graph().as_default() as g:
     77       self.assertIsNone(training_util.get_global_step())
     78       variables.VariableV1(
     79           0,
     80           trainable=False,
     81           dtype=dtypes.int32,
     82           name=ops.GraphKeys.GLOBAL_STEP)
     83       self._assert_global_step(
     84           training_util.get_global_step(), expected_dtype=dtypes.int32)
     85     self._assert_global_step(
     86         training_util.get_global_step(g), expected_dtype=dtypes.int32)
     87 
     88   def test_get_or_create_global_step(self):
     89     with ops.Graph().as_default() as g:
     90       self.assertIsNone(training_util.get_global_step())
     91       self._assert_global_step(training_util.get_or_create_global_step())
     92       self._assert_global_step(training_util.get_or_create_global_step(g))
     93 
     94 
     95 @test_util.run_v1_only('b/120545219')
     96 class GlobalStepReadTest(test.TestCase):
     97 
     98   def test_global_step_read_is_none_if_there_is_no_global_step(self):
     99     with ops.Graph().as_default():
    100       self.assertIsNone(training_util._get_or_create_global_step_read())
    101       training_util.create_global_step()
    102       self.assertIsNotNone(training_util._get_or_create_global_step_read())
    103 
    104   def test_reads_from_cache(self):
    105     with ops.Graph().as_default():
    106       training_util.create_global_step()
    107       first = training_util._get_or_create_global_step_read()
    108       second = training_util._get_or_create_global_step_read()
    109       self.assertEqual(first, second)
    110 
    111   def test_reads_before_increments(self):
    112     with ops.Graph().as_default():
    113       training_util.create_global_step()
    114       read_tensor = training_util._get_or_create_global_step_read()
    115       inc_op = training_util._increment_global_step(1)
    116       inc_three_op = training_util._increment_global_step(3)
    117       with monitored_session.MonitoredTrainingSession() as sess:
    118         read_value, _ = sess.run([read_tensor, inc_op])
    119         self.assertEqual(0, read_value)
    120         read_value, _ = sess.run([read_tensor, inc_three_op])
    121         self.assertEqual(1, read_value)
    122         read_value = sess.run(read_tensor)
    123         self.assertEqual(4, read_value)
    124 
    125 
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner (`test.main`).
  test.main()
    128