Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the Logistic distribution."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from scipy import stats
     23 from tensorflow.contrib.distributions.python.ops import logistic
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.ops.distributions import distribution
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class LogisticTest(test.TestCase):
     31 
     32   def testReparameterizable(self):
     33     batch_size = 6
     34     np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     35     loc = constant_op.constant(np_loc)
     36     scale = 1.5
     37     dist = logistic.Logistic(loc, scale)
     38     self.assertTrue(
     39         dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED)
     40 
     41   def testLogisticLogProb(self):
     42     with self.test_session():
     43       batch_size = 6
     44       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     45       loc = constant_op.constant(np_loc)
     46       scale = 1.5
     47       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
     48       dist = logistic.Logistic(loc, scale)
     49       expected_log_prob = stats.logistic.logpdf(x, np_loc, scale)
     50 
     51       log_prob = dist.log_prob(x)
     52       self.assertEqual(log_prob.get_shape(), (6,))
     53       self.assertAllClose(log_prob.eval(), expected_log_prob)
     54 
     55       prob = dist.prob(x)
     56       self.assertEqual(prob.get_shape(), (6,))
     57       self.assertAllClose(prob.eval(), np.exp(expected_log_prob))
     58 
     59   def testLogisticCDF(self):
     60     with self.test_session():
     61       batch_size = 6
     62       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     63       loc = constant_op.constant(np_loc)
     64       scale = 1.5
     65 
     66       dist = logistic.Logistic(loc, scale)
     67       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
     68       cdf = dist.cdf(x)
     69       expected_cdf = stats.logistic.cdf(x, np_loc, scale)
     70 
     71       self.assertEqual(cdf.get_shape(), (6,))
     72       self.assertAllClose(cdf.eval(), expected_cdf)
     73 
     74   def testLogisticLogCDF(self):
     75     with self.test_session():
     76       batch_size = 6
     77       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     78       loc = constant_op.constant(np_loc)
     79       scale = 1.5
     80 
     81       dist = logistic.Logistic(loc, scale)
     82       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
     83       logcdf = dist.log_cdf(x)
     84       expected_logcdf = stats.logistic.logcdf(x, np_loc, scale)
     85 
     86       self.assertEqual(logcdf.get_shape(), (6,))
     87       self.assertAllClose(logcdf.eval(), expected_logcdf)
     88 
     89   def testLogisticSurvivalFunction(self):
     90     with self.test_session():
     91       batch_size = 6
     92       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
     93       loc = constant_op.constant(np_loc)
     94       scale = 1.5
     95 
     96       dist = logistic.Logistic(loc, scale)
     97       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
     98       survival_function = dist.survival_function(x)
     99       expected_survival_function = stats.logistic.sf(x, np_loc, scale)
    100 
    101       self.assertEqual(survival_function.get_shape(), (6,))
    102       self.assertAllClose(survival_function.eval(), expected_survival_function)
    103 
    104   def testLogisticLogSurvivalFunction(self):
    105     with self.test_session():
    106       batch_size = 6
    107       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
    108       loc = constant_op.constant(np_loc)
    109       scale = 1.5
    110 
    111       dist = logistic.Logistic(loc, scale)
    112       x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
    113       logsurvival_function = dist.log_survival_function(x)
    114       expected_logsurvival_function = stats.logistic.logsf(x, np_loc, scale)
    115 
    116       self.assertEqual(logsurvival_function.get_shape(), (6,))
    117       self.assertAllClose(logsurvival_function.eval(),
    118                           expected_logsurvival_function)
    119 
    120   def testLogisticMean(self):
    121     with self.test_session():
    122       loc = [2.0, 1.5, 1.0]
    123       scale = 1.5
    124       expected_mean = stats.logistic.mean(loc, scale)
    125       dist = logistic.Logistic(loc, scale)
    126       self.assertAllClose(dist.mean().eval(), expected_mean)
    127 
    128   def testLogisticVariance(self):
    129     with self.test_session():
    130       loc = [2.0, 1.5, 1.0]
    131       scale = 1.5
    132       expected_variance = stats.logistic.var(loc, scale)
    133       dist = logistic.Logistic(loc, scale)
    134       self.assertAllClose(dist.variance().eval(), expected_variance)
    135 
    136   def testLogisticEntropy(self):
    137     with self.test_session():
    138       batch_size = 3
    139       np_loc = np.array([2.0] * batch_size, dtype=np.float32)
    140       loc = constant_op.constant(np_loc)
    141       scale = 1.5
    142       expected_entropy = stats.logistic.entropy(np_loc, scale)
    143       dist = logistic.Logistic(loc, scale)
    144       self.assertAllClose(dist.entropy().eval(), expected_entropy)
    145 
    146   def testLogisticSample(self):
    147     with self.test_session():
    148       loc = [3.0, 4.0, 2.0]
    149       scale = 1.0
    150       dist = logistic.Logistic(loc, scale)
    151       sample = dist.sample(seed=100)
    152       self.assertEqual(sample.get_shape(), (3,))
    153       self.assertAllClose(sample.eval(), [6.22460556, 3.79602098, 2.05084133])
    154 
    155   def testDtype(self):
    156     loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float32)
    157     scale = constant_op.constant(1.0, dtype=dtypes.float32)
    158     dist = logistic.Logistic(loc, scale)
    159     self.assertEqual(dist.dtype, dtypes.float32)
    160     self.assertEqual(dist.loc.dtype, dist.scale.dtype)
    161     self.assertEqual(dist.dtype, dist.sample(5).dtype)
    162     self.assertEqual(dist.dtype, dist.mode().dtype)
    163     self.assertEqual(dist.loc.dtype, dist.mean().dtype)
    164     self.assertEqual(dist.loc.dtype, dist.variance().dtype)
    165     self.assertEqual(dist.loc.dtype, dist.stddev().dtype)
    166     self.assertEqual(dist.loc.dtype, dist.entropy().dtype)
    167     self.assertEqual(dist.loc.dtype, dist.prob(0.2).dtype)
    168     self.assertEqual(dist.loc.dtype, dist.log_prob(0.2).dtype)
    169 
    170     loc = constant_op.constant([0.1, 0.4], dtype=dtypes.float64)
    171     scale = constant_op.constant(1.0, dtype=dtypes.float64)
    172     dist64 = logistic.Logistic(loc, scale)
    173     self.assertEqual(dist64.dtype, dtypes.float64)
    174     self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
    175 
    176 
    177 if __name__ == "__main__":
    178   test.main()
    179