# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adadelta Optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adadelta


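# The expected values in these tests come from a NumPy reimplementation of
# the Adadelta update rule (Zeiler, 2012), sketched here for reference and
# mirrored by the per-step computation in doTestBasic below:
#
#   accum        <- rho * accum + (1 - rho) * grad^2
#   update       <- sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
#   accum_update <- rho * accum_update + (1 - rho) * update^2
#   var          <- var - lr * update
#
# "accum" and "accum_update" are the optimizer's two slot variables, both
# starting at zero.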
class AdadeltaOptimizerTest(test.TestCase):

  def doTestBasic(self, use_resource=False):
    num_updates = 4  # number of ADADELTA steps to perform
    for dtype in [dtypes.half, dtypes.float32]:
      for grad in [0.2, 0.1, 0.01]:
        for lr in [1.0, 0.5, 0.1]:
          with self.test_session():
            var0_init = [1.0, 2.0]
            var1_init = [3.0, 4.0]
            if use_resource:
              var0 = resource_variable_ops.ResourceVariable(
                  var0_init, dtype=dtype)
              var1 = resource_variable_ops.ResourceVariable(
                  var1_init, dtype=dtype)
            else:
              var0 = variables.Variable(var0_init, dtype=dtype)
              var1 = variables.Variable(var1_init, dtype=dtype)

            grads = constant_op.constant([grad, grad], dtype=dtype)

            accum = 0.0
            accum_update = 0.0

            # ADADELTA gradient optimizer
            rho = 0.95
            epsilon = 1e-8
            adadelta_opt = adadelta.AdadeltaOptimizer(lr, rho, epsilon)
            adadelta_update = adadelta_opt.apply_gradients(
                zip([grads, grads], [var0, var1]))

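            # Adadelta creates two non-trainable slot variables per trainable
            # variable ("accum" and "accum_update"), so the two variables here
            # yield four optimizer variables, each named after its primary.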
            opt_vars = adadelta_opt.variables()
            self.assertStartsWith(opt_vars[0].name, var0._shared_name)
            self.assertStartsWith(opt_vars[1].name, var0._shared_name)
            self.assertStartsWith(opt_vars[2].name, var1._shared_name)
            self.assertStartsWith(opt_vars[3].name, var1._shared_name)
            self.assertEqual(4, len(opt_vars))

            variables.global_variables_initializer().run()

            # Fetch the slot variables and validate their shapes
            slot = [None] * 2
            slot_update = [None] * 2
            self.assertEqual(["accum", "accum_update"],
                             adadelta_opt.get_slot_names())
            slot[0] = adadelta_opt.get_slot(var0, "accum")
            self.assertEqual(slot[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot[0], variables.trainable_variables())

            slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
            self.assertEqual(slot_update[0].get_shape(), var0.get_shape())
            self.assertNotIn(slot_update[0], variables.trainable_variables())

            slot[1] = adadelta_opt.get_slot(var1, "accum")
            self.assertEqual(slot[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot[1], variables.trainable_variables())

            slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
            self.assertEqual(slot_update[1].get_shape(), var1.get_shape())
            self.assertNotIn(slot_update[1], variables.trainable_variables())

            # Fetch params to validate initial values
            self.assertAllClose(var0_init, var0.eval())
            self.assertAllClose(var1_init, var1.eval())

            update = [None] * num_updates
            tot_update = 0
            for step in range(num_updates):
              # Run adadelta update for comparison
              adadelta_update.run()

              # Compute the expected values for this step with the NumPy
              # reference implementation of the Adadelta update rule
              accum = accum * rho + (grad**2) * (1 - rho)
              update[step] = (np.sqrt(accum_update + epsilon) *
                              (1. / np.sqrt(accum + epsilon)) * grad)
              accum_update = (accum_update * rho + (update[step]**2) *
                              (1.0 - rho))
              tot_update += update[step] * lr

              # Check that the accumulators have been updated
              for slot_idx in range(2):
                self.assertAllCloseAccordingToType(
                    np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
                    slot[slot_idx].eval(),
                    rtol=1e-5)

                self.assertAllCloseAccordingToType(
                    np.array(
                        [accum_update, accum_update],
                        dtype=dtype.as_numpy_dtype()),
                    slot_update[slot_idx].eval(),
                    rtol=1e-5)

              # Check that the parameters have been updated
              self.assertAllCloseAccordingToType(
                  np.array(
                      [var0_init[0] - tot_update, var0_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  var0.eval(),
                  rtol=1e-5)

              self.assertAllCloseAccordingToType(
                  np.array(
                      [var1_init[0] - tot_update, var1_init[1] - tot_update],
                      dtype=dtype.as_numpy_dtype()),
                  var1.eval(),
                  rtol=1e-5)

  def testBasic(self):
    self.doTestBasic(use_resource=False)

  def testResourceBasic(self):
    self.doTestBasic(use_resource=True)

  def testMinimizeSparseResourceVariable(self):
    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
      with self.test_session():
        var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
        x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
        pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
        loss = pred * pred
        sgd_op = adadelta.AdadeltaOptimizer(
            1.0, 1.0, 1.0).minimize(loss)
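        # With learning_rate=1.0, rho=1.0 and epsilon=1.0, the (1 - rho)
        # terms vanish, so both accumulators stay at zero and the update
        # becomes sqrt(0 + 1) / sqrt(0 + 1) * grad = grad; Adadelta then
        # degenerates to plain SGD with a unit step, hence the name sgd_op.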
        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
        # Run 1 step of sgd
        sgd_op.run()
        # Validate updated params: pred = 14, grad = 2 * pred * [4, 5]
        # = [112, 140], so var0 becomes [1 - 112, 2 - 140]
        self.assertAllCloseAccordingToType(
            [[-111, -138]], var0.eval())


if __name__ == "__main__":
  test.main()