1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for sign_decay.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import math 22 23 from tensorflow.contrib.opt.python.training import sign_decay 24 from tensorflow.python.platform import test 25 26 27 def py_linear_decay_fn(decay_steps): 28 29 def linear_decay(step): 30 step = min(step, decay_steps) 31 return float(decay_steps - step) / decay_steps 32 33 return linear_decay 34 35 36 def py_cosine_decay_fn(decay_steps, num_periods=0.5, zero_after=None): 37 38 def cosine_decay(step): 39 step = min(step, decay_steps) 40 fraction = 2.0 * num_periods * step / float(decay_steps) 41 if zero_after is not None and fraction >= 2 * zero_after: 42 return 0.0 43 return 0.5 * (1.0 + math.cos(math.pi * fraction)) 44 45 return cosine_decay 46 47 48 def py_restart_decay_fn(decay_steps, num_periods=1, zero_after=None): 49 50 def restart_decay(step): 51 step = min(step, decay_steps) 52 tmp = num_periods * step / float(decay_steps) 53 fraction = ( 54 num_periods * step % decay_steps) / float(decay_steps) 55 if zero_after is not None and tmp >= zero_after: 56 return 0 57 return 0.5 * (1.0 + math.cos(math.pi * fraction)) 58 59 return restart_decay 60 61 62 class SignDecaysTest(test.TestCase): 63 64 def testLinearDecay(self): 65 num_training_steps = 1000 66 linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps) 67 68 for step in range(0, 1000, 100): 69 with self.test_session(): 70 tf_decayed = linear_decay_fn(step).eval() 71 py_decayed = py_linear_decay_fn(num_training_steps)(step) 72 self.assertAlmostEqual(tf_decayed, py_decayed, places=4) 73 74 def testCosineDecay(self): 75 num_training_steps = 1000 76 cosine_decay_fn = sign_decay.get_cosine_decay_fn(num_training_steps) 77 cosine_decay_2_fn = sign_decay.get_cosine_decay_fn( 78 num_training_steps, num_periods=5, zero_after=2) 79 80 for step in range(0, 1000, 100): 81 with self.test_session(): 82 tf_decayed = cosine_decay_fn(step).eval() 83 py_decayed = py_cosine_decay_fn(num_training_steps)(step) 84 self.assertAlmostEqual(tf_decayed, py_decayed, places=4) 85 86 tf_decayed = cosine_decay_2_fn(step).eval() 87 py_decayed = py_cosine_decay_fn( 88 num_training_steps, num_periods=5, zero_after=2)(step) 89 self.assertAlmostEqual(tf_decayed, py_decayed, places=4) 90 91 def testRestartDecay(self): 92 num_training_steps = 1000 93 restart_decay_fn = sign_decay.get_restart_decay_fn(num_training_steps) 94 restart_decay_2_fn = sign_decay.get_restart_decay_fn( 95 num_training_steps, num_periods=5, zero_after=2) 96 97 for step in range(0, 1000, 100): 98 with self.test_session(): 99 tf_decayed = restart_decay_fn(step).eval() 100 py_decayed = py_restart_decay_fn(num_training_steps)(step) 101 self.assertAlmostEqual(tf_decayed, py_decayed, places=4) 102 103 tf_decayed = restart_decay_2_fn(step).eval() 104 py_decayed = py_restart_decay_fn( 105 num_training_steps, num_periods=5, zero_after=2)(step) 106 self.assertAlmostEqual(tf_decayed, py_decayed, places=4) 107 108 109 if __name__ == "__main__": 110 test.main() 111