1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.contrib.kfac.optimizer.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.kfac.python.ops import layer_collection as lc 24 from tensorflow.contrib.kfac.python.ops import optimizer 25 from tensorflow.python.framework import ops 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import init_ops 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.ops import nn 30 from tensorflow.python.ops import variable_scope 31 from tensorflow.python.ops import variables as tf_variables 32 from tensorflow.python.platform import test 33 34 35 def dummy_layer_collection(): 36 lcoll = lc.LayerCollection() 37 dummy = array_ops.constant([1., 2.]) 38 lcoll.register_categorical_predictive_distribution(logits=dummy) 39 return lcoll 40 41 42 class OptimizerTest(test.TestCase): 43 44 def testOptimizerInitInvalidMomentumRegistration(self): 45 with self.assertRaises(ValueError): 46 optimizer.KfacOptimizer( 47 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') 48 49 def testOptimizerInit(self): 50 with ops.Graph().as_default(): 51 layer_collection = lc.LayerCollection() 52 53 inputs = array_ops.ones((2, 1)) * 2 54 weights_val = np.ones((1, 1), dtype=np.float32) * 3. 55 weights = variable_scope.get_variable( 56 'w', initializer=array_ops.constant(weights_val)) 57 bias = variable_scope.get_variable( 58 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) 59 output = math_ops.matmul(inputs, weights) + bias 60 61 layer_collection.register_fully_connected((weights, bias), inputs, output) 62 63 logits = math_ops.tanh(output) 64 targets = array_ops.constant([[0.], [1.]]) 65 output = math_ops.reduce_mean( 66 nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) 67 68 layer_collection.register_categorical_predictive_distribution(logits) 69 70 optimizer.KfacOptimizer( 71 0.1, 72 0.2, 73 0.3, 74 layer_collection, 75 momentum=0.5, 76 momentum_type='regular') 77 78 def testSquaredFisherNorm(self): 79 with ops.Graph().as_default(), self.test_session() as sess: 80 grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), 81 (array_ops.constant([[2., 3.], [4., 5.]]), None)] 82 pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), 83 (array_ops.constant([[7., 8.], [9., 10.]]), None)] 84 opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) 85 sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) 86 self.assertAlmostEqual(174., sess.run(sq_norm), places=5) 87 88 def testUpdateClipCoeff(self): 89 with ops.Graph().as_default(), self.test_session() as sess: 90 grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), 91 (array_ops.constant([[2., 3.], [4., 5.]]), None)] 92 pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), 93 (array_ops.constant([[7., 8.], [9., 10.]]), None)] 94 lrate = 0.1 95 96 # Note: without rescaling, the squared Fisher norm of the update 97 # is 1.74 98 99 # If the update already satisfies the norm constraint, there should 100 # be no rescaling. 101 opt = optimizer.KfacOptimizer( 102 lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) 103 coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) 104 self.assertAlmostEqual(1., sess.run(coeff), places=5) 105 106 # If the update violates the constraint, it should be rescaled to 107 # be on the constraint boundary. 108 opt = optimizer.KfacOptimizer( 109 lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) 110 coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) 111 sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) 112 sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad 113 self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) 114 115 def testComputeUpdateStepsRegular(self): 116 # TODO(olganw): implement this. 117 pass 118 119 def testComputeUpdateStepsAdam(self): 120 # TODO(olganw): implement this. 121 pass 122 123 def testUpdateVelocities(self): 124 with ops.Graph().as_default(), self.test_session() as sess: 125 layers = lc.LayerCollection() 126 layers.register_categorical_predictive_distribution( 127 array_ops.constant([1.0])) 128 opt = optimizer.KfacOptimizer( 129 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') 130 x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) 131 y = variable_scope.get_variable( 132 'y', initializer=array_ops.ones((2, 2)) * 2) 133 vec1 = array_ops.ones((2, 2)) * 3 134 vec2 = array_ops.ones((2, 2)) * 4 135 136 model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 137 update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) 138 opt_vars = [ 139 v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 140 if v not in model_vars 141 ] 142 143 sess.run(tf_variables.global_variables_initializer()) 144 old_opt_vars = sess.run(opt_vars) 145 146 # Optimizer vars start out at 0. 147 for opt_var in old_opt_vars: 148 self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) 149 150 sess.run(update_op) 151 new_opt_vars = sess.run(opt_vars) 152 # After one update, the velocities are equal to the vectors. 153 for vec, opt_var in zip([vec1, vec2], new_opt_vars): 154 self.assertAllEqual(sess.run(vec), opt_var) 155 156 sess.run(update_op) 157 final_opt_vars = sess.run(opt_vars) 158 for first, second in zip(new_opt_vars, final_opt_vars): 159 self.assertFalse(np.equal(first, second).all()) 160 161 def testApplyGradients(self): 162 with ops.Graph().as_default(), self.test_session() as sess: 163 layer_collection = lc.LayerCollection() 164 165 inputs = array_ops.ones((2, 1)) * 2 166 weights_val = np.ones((1, 1), dtype=np.float32) * 3. 167 weights = variable_scope.get_variable( 168 'w', initializer=array_ops.constant(weights_val)) 169 bias = variable_scope.get_variable( 170 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) 171 output = math_ops.matmul(inputs, weights) + bias 172 173 layer_collection.register_fully_connected((weights, bias), inputs, output) 174 175 logits = math_ops.tanh(output) 176 targets = array_ops.constant([[0.], [1.]]) 177 output = math_ops.reduce_mean( 178 nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) 179 180 layer_collection.register_categorical_predictive_distribution(logits) 181 182 opt = optimizer.KfacOptimizer( 183 0.1, 184 0.2, 185 0.3, 186 layer_collection, 187 momentum=0.5, 188 momentum_type='regular') 189 grads_and_vars = opt.compute_gradients(output, [weights, bias]) 190 all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] 191 192 op = opt.apply_gradients(grads_and_vars) 193 194 sess.run(tf_variables.global_variables_initializer()) 195 old_vars = sess.run(all_vars) 196 sess.run(op) 197 new_vars = sess.run(all_vars) 198 199 for old_var, new_var in zip(old_vars, new_vars): 200 self.assertNotEqual(old_var, new_var) 201 202 203 if __name__ == '__main__': 204 test.main() 205