Home | History | Annotate | Download | only in training
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for sign_decay."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import math
     22 
     23 from tensorflow.contrib.opt.python.training import sign_decay
     24 from tensorflow.python.platform import test
     25 
     26 
     27 def py_linear_decay_fn(decay_steps):
     28 
     29   def linear_decay(step):
     30     step = min(step, decay_steps)
     31     return float(decay_steps - step) / decay_steps
     32 
     33   return linear_decay
     34 
     35 
     36 def py_cosine_decay_fn(decay_steps, num_periods=0.5, zero_after=None):
     37 
     38   def cosine_decay(step):
     39     step = min(step, decay_steps)
     40     fraction = 2.0 * num_periods * step / float(decay_steps)
     41     if zero_after is not None and fraction >= 2 * zero_after:
     42       return 0.0
     43     return 0.5 * (1.0 + math.cos(math.pi * fraction))
     44 
     45   return cosine_decay
     46 
     47 
     48 def py_restart_decay_fn(decay_steps, num_periods=1, zero_after=None):
     49 
     50   def restart_decay(step):
     51     step = min(step, decay_steps)
     52     tmp = num_periods * step / float(decay_steps)
     53     fraction = (
     54         num_periods * step % decay_steps) / float(decay_steps)
     55     if zero_after is not None and tmp >= zero_after:
     56       return 0
     57     return 0.5 * (1.0 + math.cos(math.pi * fraction))
     58 
     59   return restart_decay
     60 
     61 
     62 class SignDecaysTest(test.TestCase):
     63 
     64   def testLinearDecay(self):
     65     num_training_steps = 1000
     66     linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps)
     67 
     68     for step in range(0, 1000, 100):
     69       with self.test_session():
     70         tf_decayed = linear_decay_fn(step).eval()
     71         py_decayed = py_linear_decay_fn(num_training_steps)(step)
     72         self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
     73 
     74   def testCosineDecay(self):
     75     num_training_steps = 1000
     76     cosine_decay_fn = sign_decay.get_cosine_decay_fn(num_training_steps)
     77     cosine_decay_2_fn = sign_decay.get_cosine_decay_fn(
     78         num_training_steps, num_periods=5, zero_after=2)
     79 
     80     for step in range(0, 1000, 100):
     81       with self.test_session():
     82         tf_decayed = cosine_decay_fn(step).eval()
     83         py_decayed = py_cosine_decay_fn(num_training_steps)(step)
     84         self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
     85 
     86         tf_decayed = cosine_decay_2_fn(step).eval()
     87         py_decayed = py_cosine_decay_fn(
     88             num_training_steps, num_periods=5, zero_after=2)(step)
     89         self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
     90 
     91   def testRestartDecay(self):
     92     num_training_steps = 1000
     93     restart_decay_fn = sign_decay.get_restart_decay_fn(num_training_steps)
     94     restart_decay_2_fn = sign_decay.get_restart_decay_fn(
     95         num_training_steps, num_periods=5, zero_after=2)
     96 
     97     for step in range(0, 1000, 100):
     98       with self.test_session():
     99         tf_decayed = restart_decay_fn(step).eval()
    100         py_decayed = py_restart_decay_fn(num_training_steps)(step)
    101         self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
    102 
    103         tf_decayed = restart_decay_2_fn(step).eval()
    104         py_decayed = py_restart_decay_fn(
    105             num_training_steps, num_periods=5, zero_after=2)(step)
    106         self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
    107 
    108 
# When this file is executed directly (rather than imported), hand control
# to the `main` entry point of the imported TensorFlow `test` module.
if __name__ == "__main__":
  test.main()
    111