# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
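"""Tests for the Beta distribution."""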
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import beta as beta_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    tf_logging.warning("Could not import %s: %s" % (name, str(e)))
  return module


special = try_import("scipy.special")
stats = try_import("scipy.stats")
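# `special` and `stats` are None when scipy is not installed. The tests below
# always run their shape checks but return early before comparing against
# scipy's closed-form values, e.g.:
#
#   if not stats:
#     return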


class BetaTest(test.TestCase):

  def testSimpleShapes(self):
    with self.test_session():
      a = np.random.rand(3)
      b = np.random.rand(3)
      dist = beta_lib.Beta(a, b)
      self.assertAllEqual([], dist.event_shape_tensor().eval())
      self.assertAllEqual([3], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
      self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)

  def testComplexShapes(self):
    with self.test_session():
      a = np.random.rand(3, 2, 2)
      b = np.random.rand(3, 2, 2)
      dist = beta_lib.Beta(a, b)
      self.assertAllEqual([], dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
      self.assertEqual(
          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testComplexShapesBroadcast(self):
    with self.test_session():
      a = np.random.rand(3, 2, 2)
      b = np.random.rand(2, 2)
      dist = beta_lib.Beta(a, b)
      self.assertAllEqual([], dist.event_shape_tensor().eval())
      self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
      self.assertEqual(
          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)

  def testAlphaProperty(self):
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    with self.test_session():
      dist = beta_lib.Beta(a, b)
      self.assertEqual([1, 3], dist.concentration1.get_shape())
      self.assertAllClose(a, dist.concentration1.eval())

  def testBetaProperty(self):
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    with self.test_session():
      dist = beta_lib.Beta(a, b)
      self.assertEqual([1, 3], dist.concentration0.get_shape())
      self.assertAllClose(b, dist.concentration0.eval())

  def testPdfXProper(self):
    a = [[1., 2, 3]]
    b = [[2., 4, 3]]
    with self.test_session():
      dist = beta_lib.Beta(a, b, validate_args=True)
      dist.prob([.1, .3, .6]).eval()
      dist.prob([.2, .3, .5]).eval()
      # The support is the open interval (0, 1): non-positive samples and
      # samples >= 1 are both rejected when validate_args=True.
      with self.assertRaisesOpError("sample must be positive"):
        dist.prob([-1., 0.1, 0.5]).eval()
      with self.assertRaisesOpError("sample must be positive"):
        dist.prob([0., 0.1, 0.5]).eval()
      with self.assertRaisesOpError("sample must be less than `1`"):
        dist.prob([.1, .2, 1.2]).eval()
      with self.assertRaisesOpError("sample must be less than `1`"):
        dist.prob([.1, .2, 1.0]).eval()

  def testPdfTwoBatches(self):
    with self.test_session():
      a = [1., 2]
      b = [1., 2]
      x = [.5, .5]
      dist = beta_lib.Beta(a, b)
      pdf = dist.prob(x)
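      # Expected values follow from the Beta density
      # p(x; a, b) = x**(a - 1) * (1 - x)**(b - 1) / B(a, b):
      # Beta(1, 1) is uniform so p(0.5) = 1, and for Beta(2, 2),
      # B(2, 2) = 1/6, so p(0.5) = 6 * 0.5 * 0.5 = 3/2.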
      self.assertAllClose([1., 3. / 2], pdf.eval())
      self.assertEqual((2,), pdf.get_shape())

  def testPdfTwoBatchesNontrivialX(self):
    with self.test_session():
      a = [1., 2]
      b = [1., 2]
      x = [.3, .7]
      dist = beta_lib.Beta(a, b)
      pdf = dist.prob(x)
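      # Beta(1, 1) gives p(0.3) = 1; Beta(2, 2) gives
      # p(0.7) = 6 * 0.7 * 0.3 = 63/50.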
      self.assertAllClose([1, 63. / 50], pdf.eval())
      self.assertEqual((2,), pdf.get_shape())

  def testPdfUniformZeroBatch(self):
    with self.test_session():
      # This is equivalent to a uniform distribution
      a = 1.
      b = 1.
      x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
      dist = beta_lib.Beta(a, b)
      pdf = dist.prob(x)
      self.assertAllClose([1.] * 5, pdf.eval())
      self.assertEqual((5,), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      a = [[1., 2]]
      b = [[1., 2]]
      x = [[.5, .5], [.3, .7]]
      dist = beta_lib.Beta(a, b)
      pdf = dist.prob(x)
      self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], pdf.eval())
      self.assertEqual((2, 2), pdf.get_shape())

  def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      a = [1., 2]
      b = [1., 2]
      x = [[.5, .5], [.2, .8]]
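      # The shape-[2] parameters broadcast against the [2, 2] batch of x,
      # so prob returns a [2, 2] batch of densities.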
      pdf = beta_lib.Beta(a, b).prob(x)
      self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], pdf.eval())
      self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenSameRank(self):
    with self.test_session():
      a = [[1., 2], [2., 3]]
      b = [[1., 2], [2., 3]]
      x = [[.5, .5]]
      pdf = beta_lib.Beta(a, b).prob(x)
      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
      self.assertEqual((2, 2), pdf.get_shape())

  def testPdfXStretchedInBroadcastWhenLowerRank(self):
    with self.test_session():
      a = [[1., 2], [2., 3]]
      b = [[1., 2], [2., 3]]
      x = [.5, .5]
      pdf = beta_lib.Beta(a, b).prob(x)
      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
      self.assertEqual((2, 2), pdf.get_shape())

  def testBetaMean(self):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.mean().get_shape(), (3,))
      if not stats:
        return
      expected_mean = stats.beta.mean(a, b)
      self.assertAllClose(expected_mean, dist.mean().eval())

  def testBetaVariance(self):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.variance().get_shape(), (3,))
      if not stats:
        return
      expected_variance = stats.beta.var(a, b)
      self.assertAllClose(expected_variance, dist.variance().eval())

  def testBetaMode(self):
    with session.Session():
      a = np.array([1.1, 2, 3])
      b = np.array([2., 4, 1.2])
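      # The mode of Beta(a, b) is (a - 1) / (a + b - 2), which is well
      # defined here because every entry of a and b exceeds 1.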
      expected_mode = (a - 1) / (a + b - 2)
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.mode().get_shape(), (3,))
      self.assertAllClose(expected_mode, dist.mode().eval())

  def testBetaModeInvalid(self):
    with session.Session():
      a = np.array([1., 2, 3])
      b = np.array([2., 4, 1.2])
      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
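      # The mode is undefined unless concentration1 > 1 and concentration0 > 1
      # (here a[0] == 1), so with allow_nan_stats=False evaluating mode()
      # raises.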
      with self.assertRaisesOpError("Condition x < y.*"):
        dist.mode().eval()

      a = np.array([2., 2, 3])
      b = np.array([1., 4, 1.2])
      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
      with self.assertRaisesOpError("Condition x < y.*"):
        dist.mode().eval()

  def testBetaModeEnableAllowNanStats(self):
    with session.Session():
      a = np.array([1., 2, 3])
      b = np.array([2., 4, 1.2])
      dist = beta_lib.Beta(a, b, allow_nan_stats=True)

      expected_mode = (a - 1) / (a + b - 2)
      expected_mode[0] = np.nan
      self.assertEqual((3,), dist.mode().get_shape())
      self.assertAllClose(expected_mode, dist.mode().eval())

      a = np.array([2., 2, 3])
      b = np.array([1., 4, 1.2])
      dist = beta_lib.Beta(a, b, allow_nan_stats=True)

      expected_mode = (a - 1) / (a + b - 2)
      expected_mode[0] = np.nan
      self.assertEqual((3,), dist.mode().get_shape())
      self.assertAllClose(expected_mode, dist.mode().eval())

  def testBetaEntropy(self):
    with session.Session():
      a = [1., 2, 3]
      b = [2., 4, 1.2]
      dist = beta_lib.Beta(a, b)
      self.assertEqual(dist.entropy().get_shape(), (3,))
      if not stats:
        return
      expected_entropy = stats.beta.entropy(a, b)
      self.assertAllClose(expected_entropy, dist.entropy().eval())

  def testBetaSample(self):
    with self.test_session():
      a = 1.
      b = 2.
      beta = beta_lib.Beta(a, b)
      n = constant_op.constant(100000)
      samples = beta.sample(n)
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (100000,))
      self.assertFalse(np.any(sample_values < 0.0))
      if not stats:
        return
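      # The Kolmogorov-Smirnov statistic (the largest gap between the
      # empirical CDF of the samples and the Beta(1, 2) CDF) should be small
      # for 100k draws.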
      self.assertLess(
          stats.kstest(
              # Beta is a univariate distribution.
              sample_values,
              stats.beta(a=1., b=2.).cdf)[0],
          0.01)
      # Beta(1, 2) has variance 1/18, so the standard error of the sample
      # mean is 1 / sqrt(18 * n).
      self.assertAllClose(
          sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
      self.assertAllClose(
          np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)

  # Test that sampling with the same seed twice gives the same results.
  def testBetaSampleMultipleTimes(self):
    with self.test_session():
      a_val = 1.
      b_val = 2.
      n_val = 100

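      # With both the graph-level seed (set_random_seed) and the op-level
      # seed passed to sample() fixed, the two draws below are deterministic
      # and therefore identical.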
      random_seed.set_random_seed(654321)
      beta1 = beta_lib.Beta(concentration1=a_val,
                            concentration0=b_val,
                            name="beta1")
      samples1 = beta1.sample(n_val, seed=123456).eval()

      random_seed.set_random_seed(654321)
      beta2 = beta_lib.Beta(concentration1=a_val,
                            concentration0=b_val,
                            name="beta2")
      samples2 = beta2.sample(n_val, seed=123456).eval()

      self.assertAllClose(samples1, samples2)

  def testBetaSampleMultidimensional(self):
    with self.test_session():
      a = np.random.rand(3, 2, 2).astype(np.float32)
      b = np.random.rand(3, 2, 2).astype(np.float32)
      beta = beta_lib.Beta(a, b)
      n = constant_op.constant(100000)
      samples = beta.sample(n)
      sample_values = samples.eval()
      self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
      self.assertFalse(np.any(sample_values < 0.0))
      if not stats:
        return
      self.assertAllClose(
          sample_values[:, 1, :].mean(axis=0),
          stats.beta.mean(a, b)[1, :],
          atol=1e-1)

  def testBetaCdf(self):
    with self.test_session():
      shape = (30, 40, 50)
      for dt in (np.float32, np.float64):
        a = 10. * np.random.random(shape).astype(dt)
        b = 10. * np.random.random(shape).astype(dt)
        x = np.random.random(shape).astype(dt)
        actual = beta_lib.Beta(a, b).cdf(x).eval()
        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
        if not stats:
          return
        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)

  def testBetaLogCdf(self):
    with self.test_session():
      shape = (30, 40, 50)
      for dt in (np.float32, np.float64):
        a = 10. * np.random.random(shape).astype(dt)
        b = 10. * np.random.random(shape).astype(dt)
        x = np.random.random(shape).astype(dt)
        actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval()
        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
        if not stats:
          return
        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)

  def testBetaWithSoftplusConcentration(self):
    with self.test_session():
      a, b = -4.2, -9.1
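      # BetaWithSoftplusConcentration passes its inputs through
      # softplus(x) = log(1 + exp(x)), so even negative raw values become
      # valid (positive) concentrations.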
      dist = beta_lib.BetaWithSoftplusConcentration(a, b)
      self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval())
      self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval())

  def testBetaBetaKL(self):
    with self.test_session() as sess:
      for shape in [(10,), (4, 5)]:
        a1 = 6.0 * np.random.random(size=shape) + 1e-4
        b1 = 6.0 * np.random.random(size=shape) + 1e-4
        a2 = 6.0 * np.random.random(size=shape) + 1e-4
        b2 = 6.0 * np.random.random(size=shape) + 1e-4
        # Take the inverse softplus, log(exp(x) - 1), so that
        # BetaWithSoftplusConcentration recovers the same concentrations.
        a1_sp = np.log(np.exp(a1) - 1.0)
        b1_sp = np.log(np.exp(b1) - 1.0)
        a2_sp = np.log(np.exp(a2) - 1.0)
        b2_sp = np.log(np.exp(b2) - 1.0)

        d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
        d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
        d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
                                                       concentration0=b1_sp)
        d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
                                                       concentration0=b2_sp)

        if not special:
          return
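        # Closed form: KL(Beta(a1, b1) || Beta(a2, b2)) =
        #   ln B(a2, b2) - ln B(a1, b1) + (a1 - a2) * psi(a1)
        #   + (b1 - b2) * psi(b1) + (a2 - a1 + b2 - b1) * psi(a1 + b1),
        # where B is the beta function and psi is the digamma function.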
        kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
                       (a1 - a2) * special.digamma(a1) +
                       (b1 - b2) * special.digamma(b1) +
                       (a2 - a1 + b2 - b1) * special.digamma(a1 + b1))

        for dist1 in [d1, d1_sp]:
          for dist2 in [d2, d2_sp]:
            kl = kullback_leibler.kl_divergence(dist1, dist2)
            kl_val = sess.run(kl)
            self.assertEqual(kl.get_shape(), shape)
            self.assertAllClose(kl_val, kl_expected)

        # Make sure KL(d1||d1) is 0
        kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1))
        self.assertAllClose(kl_same, np.zeros_like(kl_expected))


if __name__ == "__main__":
  test.main()