# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RMSProp optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop


class RmspropTest(XLATestCase):
  """Exercises `rmsprop.RMSPropOptimizer` inside the XLA test scope."""

  def testBasic(self):
    """Runs three RMSProp steps on two resource variables per float type.

    For every dtype in `self.float_types`, two 2-element resource variables
    receive constant gradients. The optimizer uses learning rate 3.0 and
    otherwise its default settings. The variables are compared against
    precomputed values after three applications of the update op.
    """
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        var_a = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var_b = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
        grad_a = constant_op.constant([0.1, 0.1], dtype=dtype)
        grad_b = constant_op.constant([0.01, 0.01], dtype=dtype)

        optimizer = rmsprop.RMSPropOptimizer(3.0)
        grads_and_vars = zip([grad_a, grad_b], [var_a, var_b])
        train_op = optimizer.apply_gradients(grads_and_vars)
        variables.global_variables_initializer().run()

        # Confirm initialization before any optimizer step is taken.
        self.assertAllClose([1.0, 2.0], var_a.eval())
        self.assertAllClose([3.0, 4.0], var_b.eval())

        # The gradients are constants, so every step feeds the same values.
        for _step in range(3):
          train_op.run()

        # Precomputed parameter values after the three steps above; the
        # tolerance is chosen by assertAllCloseAccordingToType per dtype.
        expected_after_three_steps = (
            (var_a, np.array([2.91705132e-04, 1.00029182e+00])),
            (var_b, np.array([2.89990854, 3.89990854])),
        )
        for var, expected in expected_after_three_steps:
          self.assertAllCloseAccordingToType(expected, var.eval())


if __name__ == "__main__":
  test.main()