# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.kfac.optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import test


def dummy_layer_collection():
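  """Returns a LayerCollection with a single dummy loss registered."""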
  lcoll = lc.LayerCollection()
  dummy = array_ops.constant([1., 2.])
  lcoll.register_categorical_predictive_distribution(logits=dummy)
  return lcoll


class OptimizerTest(test.TestCase):

  def testOptimizerInitInvalidMomentumRegistration(self):
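    # An unrecognized momentum_type should be rejected at construction time.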
    with self.assertRaises(ValueError):
      optimizer.KfacOptimizer(
          0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')

  def testOptimizerInit(self):
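    # Builds a one-layer linear model, registers it with a LayerCollection,
    # and checks that a KfacOptimizer can be constructed without error.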
    with ops.Graph().as_default():
      layer_collection = lc.LayerCollection()

      inputs = array_ops.ones((2, 1)) * 2
      weights_val = np.ones((1, 1), dtype=np.float32) * 3.
      weights = variable_scope.get_variable(
          'w', initializer=array_ops.constant(weights_val))
      bias = variable_scope.get_variable(
          'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
      output = math_ops.matmul(inputs, weights) + bias

      layer_collection.register_fully_connected((weights, bias), inputs, output)

      logits = math_ops.tanh(output)
      targets = array_ops.constant([[0.], [1.]])
      output = math_ops.reduce_mean(
          nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))

      layer_collection.register_categorical_predictive_distribution(logits)

      optimizer.KfacOptimizer(
          0.1,
          0.2,
          0.3,
          layer_collection,
          momentum=0.5,
          momentum_type='regular')

  def testSquaredFisherNorm(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
                        (array_ops.constant([[2., 3.], [4., 5.]]), None)]
      pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
                         (array_ops.constant([[7., 8.], [9., 10.]]), None)]
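      # The squared Fisher norm is the sum of the elementwise products
      # grad * pgrad: (1*3 + 2*4 + 3*5 + 4*6) + (2*7 + 3*8 + 4*9 + 5*10)
      # = 50 + 124 = 174.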
      opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
      sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
      self.assertAlmostEqual(174., sess.run(sq_norm), places=5)

  def testUpdateClipCoeff(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
                        (array_ops.constant([[2., 3.], [4., 5.]]), None)]
      pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
                         (array_ops.constant([[7., 8.], [9., 10.]]), None)]
      lrate = 0.1
      # Note: without rescaling, the squared Fisher norm of the update
      # is lrate**2 * 174 = 1.74 (see testSquaredFisherNorm above).

      # If the update already satisfies the norm constraint, there should
      # be no rescaling.
      opt = optimizer.KfacOptimizer(
          lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
      coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
      self.assertAlmostEqual(1., sess.run(coeff), places=5)

      # If the update violates the constraint, it should be rescaled to
      # be on the constraint boundary.
      opt = optimizer.KfacOptimizer(
          lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
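      # The expected clip coefficient is sqrt(0.5 / 1.74), which puts the
      # rescaled update exactly on the constraint boundary.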
      coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
      sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
      sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
      self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)

  def testComputeUpdateStepsRegular(self):
    # TODO(olganw): implement this.
    pass

  def testComputeUpdateStepsAdam(self):
    # TODO(olganw): implement this.
    pass

  def testUpdateVelocities(self):
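    # Checks the momentum buffers. This assumes _update_velocities computes
    # velocity = decay * velocity + vec for each (vec, var) pair, starting
    # from zero-initialized velocities.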
    with ops.Graph().as_default(), self.test_session() as sess:
      layers = lc.LayerCollection()
      layers.register_categorical_predictive_distribution(
          array_ops.constant([1.0]))
      opt = optimizer.KfacOptimizer(
          0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
      x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
      y = variable_scope.get_variable(
          'y', initializer=array_ops.ones((2, 2)) * 2)
      vec1 = array_ops.ones((2, 2)) * 3
      vec2 = array_ops.ones((2, 2)) * 4

      model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
      opt_vars = [
          v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
          if v not in model_vars
      ]

      sess.run(tf_variables.global_variables_initializer())
      old_opt_vars = sess.run(opt_vars)

      # Optimizer vars start out at 0.
      for opt_var in old_opt_vars:
        self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)

      sess.run(update_op)
      new_opt_vars = sess.run(opt_vars)
      # After one update from zero, each velocity equals its vector
      # (decay * 0 + vec = vec).
      for vec, opt_var in zip([vec1, vec2], new_opt_vars):
        self.assertAllEqual(sess.run(vec), opt_var)

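      # A second update should change the velocities, since
      # decay * vec + vec != vec.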
      sess.run(update_op)
      final_opt_vars = sess.run(opt_vars)
      for first, second in zip(new_opt_vars, final_opt_vars):
        self.assertFalse(np.equal(first, second).all())

  def testApplyGradients(self):
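    # End-to-end check: compute_gradients followed by apply_gradients
    # should change every variable being trained.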
    with ops.Graph().as_default(), self.test_session() as sess:
      layer_collection = lc.LayerCollection()

      inputs = array_ops.ones((2, 1)) * 2
      weights_val = np.ones((1, 1), dtype=np.float32) * 3.
      weights = variable_scope.get_variable(
          'w', initializer=array_ops.constant(weights_val))
      bias = variable_scope.get_variable(
          'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
      output = math_ops.matmul(inputs, weights) + bias

      layer_collection.register_fully_connected((weights, bias), inputs, output)

      logits = math_ops.tanh(output)
      targets = array_ops.constant([[0.], [1.]])
      output = math_ops.reduce_mean(
          nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))

      layer_collection.register_categorical_predictive_distribution(logits)

      opt = optimizer.KfacOptimizer(
          0.1,
          0.2,
          0.3,
          layer_collection,
          momentum=0.5,
          momentum_type='regular')
      grads_and_vars = opt.compute_gradients(output, [weights, bias])
      all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]

      op = opt.apply_gradients(grads_and_vars)

      sess.run(tf_variables.global_variables_initializer())
      old_vars = sess.run(all_vars)
      sess.run(op)
      new_vars = sess.run(all_vars)

      for old_var, new_var in zip(old_vars, new_vars):
        self.assertNotEqual(old_var, new_var)


if __name__ == '__main__':
  test.main()