1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Keras initializers.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.keras._impl import keras 24 from tensorflow.python.ops import init_ops 25 from tensorflow.python.platform import test 26 27 28 class KerasInitializersTest(test.TestCase): 29 30 def _runner(self, init, shape, target_mean=None, target_std=None, 31 target_max=None, target_min=None): 32 variable = keras.backend.variable(init(shape)) 33 output = keras.backend.get_value(variable) 34 lim = 3e-2 35 if target_std is not None: 36 self.assertGreater(lim, abs(output.std() - target_std)) 37 if target_mean is not None: 38 self.assertGreater(lim, abs(output.mean() - target_mean)) 39 if target_max is not None: 40 self.assertGreater(lim, abs(output.max() - target_max)) 41 if target_min is not None: 42 self.assertGreater(lim, abs(output.min() - target_min)) 43 44 # Test serialization (assumes deterministic behavior). 45 config = init.get_config() 46 reconstructed_init = init.__class__.from_config(config) 47 variable = keras.backend.variable(reconstructed_init(shape)) 48 output_2 = keras.backend.get_value(variable) 49 self.assertAllClose(output, output_2, atol=1e-4) 50 51 def test_uniform(self): 52 tensor_shape = (9, 6, 7) 53 with self.test_session(): 54 self._runner(keras.initializers.RandomUniform(minval=-1, 55 maxval=1, 56 seed=124), 57 tensor_shape, 58 target_mean=0., target_max=1, target_min=-1) 59 60 def test_normal(self): 61 tensor_shape = (8, 12, 99) 62 with self.test_session(): 63 self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153), 64 tensor_shape, 65 target_mean=0., target_std=1) 66 67 def test_truncated_normal(self): 68 tensor_shape = (12, 99, 7) 69 with self.test_session(): 70 self._runner(keras.initializers.TruncatedNormal(mean=0, 71 stddev=1, 72 seed=126), 73 tensor_shape, 74 target_mean=0., target_std=None, target_max=2) 75 76 def test_constant(self): 77 tensor_shape = (5, 6, 4) 78 with self.test_session(): 79 self._runner(keras.initializers.Constant(2), tensor_shape, 80 target_mean=2, target_max=2, target_min=2) 81 82 def test_lecun_uniform(self): 83 tensor_shape = (5, 6, 4, 2) 84 with self.test_session(): 85 fan_in, _ = init_ops._compute_fans(tensor_shape) 86 scale = np.sqrt(3. / fan_in) 87 self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape, 88 target_mean=0., target_max=scale, target_min=-scale) 89 90 def test_glorot_uniform(self): 91 tensor_shape = (5, 6, 4, 2) 92 with self.test_session(): 93 fan_in, fan_out = init_ops._compute_fans(tensor_shape) 94 scale = np.sqrt(6. / (fan_in + fan_out)) 95 self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape, 96 target_mean=0., target_max=scale, target_min=-scale) 97 98 def test_he_uniform(self): 99 tensor_shape = (5, 6, 4, 2) 100 with self.test_session(): 101 fan_in, _ = init_ops._compute_fans(tensor_shape) 102 scale = np.sqrt(6. / fan_in) 103 self._runner(keras.initializers.he_uniform(seed=123), tensor_shape, 104 target_mean=0., target_max=scale, target_min=-scale) 105 106 def test_lecun_normal(self): 107 tensor_shape = (5, 6, 4, 2) 108 with self.test_session(): 109 fan_in, _ = init_ops._compute_fans(tensor_shape) 110 scale = np.sqrt(1. / fan_in) 111 self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape, 112 target_mean=0., target_std=None, target_max=2 * scale) 113 114 def test_glorot_normal(self): 115 tensor_shape = (5, 6, 4, 2) 116 with self.test_session(): 117 fan_in, fan_out = init_ops._compute_fans(tensor_shape) 118 scale = np.sqrt(2. / (fan_in + fan_out)) 119 self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape, 120 target_mean=0., target_std=None, target_max=2 * scale) 121 122 def test_he_normal(self): 123 tensor_shape = (5, 6, 4, 2) 124 with self.test_session(): 125 fan_in, _ = init_ops._compute_fans(tensor_shape) 126 scale = np.sqrt(2. / fan_in) 127 self._runner(keras.initializers.he_normal(seed=123), tensor_shape, 128 target_mean=0., target_std=None, target_max=2 * scale) 129 130 def test_orthogonal(self): 131 tensor_shape = (20, 20) 132 with self.test_session(): 133 self._runner(keras.initializers.orthogonal(seed=123), tensor_shape, 134 target_mean=0.) 135 136 def test_identity(self): 137 with self.test_session(): 138 tensor_shape = (3, 4, 5) 139 with self.assertRaises(ValueError): 140 self._runner(keras.initializers.identity(), tensor_shape, 141 target_mean=1. / tensor_shape[0], target_max=1.) 142 143 tensor_shape = (3, 3) 144 self._runner(keras.initializers.identity(), tensor_shape, 145 target_mean=1. / tensor_shape[0], target_max=1.) 146 147 def test_zero(self): 148 tensor_shape = (4, 5) 149 with self.test_session(): 150 self._runner(keras.initializers.zeros(), tensor_shape, 151 target_mean=0., target_max=0.) 152 153 def test_one(self): 154 tensor_shape = (4, 5) 155 with self.test_session(): 156 self._runner(keras.initializers.ones(), tensor_shape, 157 target_mean=1., target_max=1.) 158 159 160 if __name__ == '__main__': 161 test.main() 162