1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for initializers.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 from scipy import stats 23 from tensorflow.contrib.distributions.python.ops import logistic 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.ops.distributions import distribution 27 from tensorflow.python.platform import test 28 29 30 class LogisticTest(test.TestCase): 31 32 def testReparameterizable(self): 33 batch_size = 6 34 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 35 loc = constant_op.constant(np_loc) 36 scale = 1.5 37 dist = logistic.Logistic(loc, scale) 38 self.assertTrue( 39 dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED) 40 41 def testLogisticLogProb(self): 42 with self.test_session(): 43 batch_size = 6 44 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 45 loc = constant_op.constant(np_loc) 46 scale = 1.5 47 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 48 dist = logistic.Logistic(loc, scale) 49 expected_log_prob = stats.logistic.logpdf(x, np_loc, scale) 50 51 log_prob = dist.log_prob(x) 52 self.assertEqual(log_prob.get_shape(), (6,)) 53 self.assertAllClose(log_prob.eval(), expected_log_prob) 54 55 prob = dist.prob(x) 56 self.assertEqual(prob.get_shape(), (6,)) 57 self.assertAllClose(prob.eval(), np.exp(expected_log_prob)) 58 59 def testLogisticCDF(self): 60 with self.test_session(): 61 batch_size = 6 62 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 63 loc = constant_op.constant(np_loc) 64 scale = 1.5 65 66 dist = logistic.Logistic(loc, scale) 67 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 68 cdf = dist.cdf(x) 69 expected_cdf = stats.logistic.cdf(x, np_loc, scale) 70 71 self.assertEqual(cdf.get_shape(), (6,)) 72 self.assertAllClose(cdf.eval(), expected_cdf) 73 74 def testLogisticLogCDF(self): 75 with self.test_session(): 76 batch_size = 6 77 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 78 loc = constant_op.constant(np_loc) 79 scale = 1.5 80 81 dist = logistic.Logistic(loc, scale) 82 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 83 logcdf = dist.log_cdf(x) 84 expected_logcdf = stats.logistic.logcdf(x, np_loc, scale) 85 86 self.assertEqual(logcdf.get_shape(), (6,)) 87 self.assertAllClose(logcdf.eval(), expected_logcdf) 88 89 def testLogisticSurvivalFunction(self): 90 with self.test_session(): 91 batch_size = 6 92 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 93 loc = constant_op.constant(np_loc) 94 scale = 1.5 95 96 dist = logistic.Logistic(loc, scale) 97 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 98 survival_function = dist.survival_function(x) 99 expected_survival_function = stats.logistic.sf(x, np_loc, scale) 100 101 self.assertEqual(survival_function.get_shape(), (6,)) 102 self.assertAllClose(survival_function.eval(), expected_survival_function) 103 104 def testLogisticLogSurvivalFunction(self): 105 with self.test_session(): 106 batch_size = 6 107 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 108 loc = constant_op.constant(np_loc) 109 scale = 1.5 110 111 dist = logistic.Logistic(loc, scale) 112 x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) 113 logsurvival_function = dist.log_survival_function(x) 114 expected_logsurvival_function = stats.logistic.logsf(x, np_loc, scale) 115 116 self.assertEqual(logsurvival_function.get_shape(), (6,)) 117 self.assertAllClose(logsurvival_function.eval(), 118 expected_logsurvival_function) 119 120 def testLogisticMean(self): 121 with self.test_session(): 122 loc = [2.0, 1.5, 1.0] 123 scale = 1.5 124 expected_mean = stats.logistic.mean(loc, scale) 125 dist = logistic.Logistic(loc, scale) 126 self.assertAllClose(dist.mean().eval(), expected_mean) 127 128 def testLogisticVariance(self): 129 with self.test_session(): 130 loc = [2.0, 1.5, 1.0] 131 scale = 1.5 132 expected_variance = stats.logistic.var(loc, scale) 133 dist = logistic.Logistic(loc, scale) 134 self.assertAllClose(dist.variance().eval(), expected_variance) 135 136 def testLogisticEntropy(self): 137 with self.test_session(): 138 batch_size = 3 139 np_loc = np.array([2.0] * batch_size, dtype=np.float32) 140 loc = constant_op.constant(np_loc) 141 scale = 1.5 142 expected_entropy = stats.logistic.entropy(np_loc, scale) 143 dist = logistic.Logistic(loc, scale) 144 self.assertAllClose(dist.entropy().eval(), expected_entropy) 145 146 def testLogisticSample(self): 147 with self.test_session(): 148 loc = [3.0, 4.0, 2.0] 149 scale = 1.0 150 dist = logistic.Logistic(loc, scale) 151 sample = dist.sample(seed=100) 152 self.assertEqual(sample.get_shape(), (3,)) 153 self.assertAllClose(sample.eval(), [6.22460556, 3.79602098, 2.05084133]) 154 155 def testDtype(self): 156 loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float32) 157 scale = constant_op.constant(1.0, dtype=dtypes.float32) 158 dist = logistic.Logistic(loc, scale) 159 self.assertEqual(dist.dtype, dtypes.float32) 160 self.assertEqual(dist.loc.dtype, dist.scale.dtype) 161 self.assertEqual(dist.dtype, dist.sample(5).dtype) 162 self.assertEqual(dist.dtype, dist.mode().dtype) 163 self.assertEqual(dist.loc.dtype, dist.mean().dtype) 164 self.assertEqual(dist.loc.dtype, dist.variance().dtype) 165 self.assertEqual(dist.loc.dtype, dist.stddev().dtype) 166 self.assertEqual(dist.loc.dtype, dist.entropy().dtype) 167 self.assertEqual(dist.loc.dtype, dist.prob(0.2).dtype) 168 self.assertEqual(dist.loc.dtype, dist.log_prob(0.2).dtype) 169 170 loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float64) 171 scale = constant_op.constant(1.0, dtype=dtypes.float64) 172 dist64 = logistic.Logistic(loc, scale) 173 self.assertEqual(dist64.dtype, dtypes.float64) 174 self.assertEqual(dist64.dtype, dist64.sample(5).dtype) 175 176 177 if __name__ == "__main__": 178 test.main() 179