1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 from __future__ import absolute_import 16 from __future__ import division 17 from __future__ import print_function 18 19 import importlib 20 21 import numpy as np 22 23 from tensorflow.python.client import session 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import random_seed 26 from tensorflow.python.framework import tensor_shape 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.ops import nn_ops 29 from tensorflow.python.ops.distributions import beta as beta_lib 30 from tensorflow.python.ops.distributions import kullback_leibler 31 from tensorflow.python.platform import test 32 from tensorflow.python.platform import tf_logging 33 34 35 def try_import(name): # pylint: disable=invalid-name 36 module = None 37 try: 38 module = importlib.import_module(name) 39 except ImportError as e: 40 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 41 return module 42 43 44 special = try_import("scipy.special") 45 stats = try_import("scipy.stats") 46 47 48 class BetaTest(test.TestCase): 49 50 def testSimpleShapes(self): 51 with self.test_session(): 52 a = np.random.rand(3) 53 b = np.random.rand(3) 54 dist = beta_lib.Beta(a, b) 55 self.assertAllEqual([], dist.event_shape_tensor().eval()) 56 self.assertAllEqual([3], dist.batch_shape_tensor().eval()) 57 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 58 self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) 59 60 def testComplexShapes(self): 61 with self.test_session(): 62 a = np.random.rand(3, 2, 2) 63 b = np.random.rand(3, 2, 2) 64 dist = beta_lib.Beta(a, b) 65 self.assertAllEqual([], dist.event_shape_tensor().eval()) 66 self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) 67 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 68 self.assertEqual( 69 tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) 70 71 def testComplexShapesBroadcast(self): 72 with self.test_session(): 73 a = np.random.rand(3, 2, 2) 74 b = np.random.rand(2, 2) 75 dist = beta_lib.Beta(a, b) 76 self.assertAllEqual([], dist.event_shape_tensor().eval()) 77 self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) 78 self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) 79 self.assertEqual( 80 tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) 81 82 def testAlphaProperty(self): 83 a = [[1., 2, 3]] 84 b = [[2., 4, 3]] 85 with self.test_session(): 86 dist = beta_lib.Beta(a, b) 87 self.assertEqual([1, 3], dist.concentration1.get_shape()) 88 self.assertAllClose(a, dist.concentration1.eval()) 89 90 def testBetaProperty(self): 91 a = [[1., 2, 3]] 92 b = [[2., 4, 3]] 93 with self.test_session(): 94 dist = beta_lib.Beta(a, b) 95 self.assertEqual([1, 3], dist.concentration0.get_shape()) 96 self.assertAllClose(b, dist.concentration0.eval()) 97 98 def testPdfXProper(self): 99 a = [[1., 2, 3]] 100 b = [[2., 4, 3]] 101 with self.test_session(): 102 dist = beta_lib.Beta(a, b, validate_args=True) 103 dist.prob([.1, .3, .6]).eval() 104 dist.prob([.2, .3, .5]).eval() 105 # Either condition can trigger. 106 with self.assertRaisesOpError("sample must be positive"): 107 dist.prob([-1., 0.1, 0.5]).eval() 108 with self.assertRaisesOpError("sample must be positive"): 109 dist.prob([0., 0.1, 0.5]).eval() 110 with self.assertRaisesOpError("sample must be less than `1`"): 111 dist.prob([.1, .2, 1.2]).eval() 112 with self.assertRaisesOpError("sample must be less than `1`"): 113 dist.prob([.1, .2, 1.0]).eval() 114 115 def testPdfTwoBatches(self): 116 with self.test_session(): 117 a = [1., 2] 118 b = [1., 2] 119 x = [.5, .5] 120 dist = beta_lib.Beta(a, b) 121 pdf = dist.prob(x) 122 self.assertAllClose([1., 3. / 2], pdf.eval()) 123 self.assertEqual((2,), pdf.get_shape()) 124 125 def testPdfTwoBatchesNontrivialX(self): 126 with self.test_session(): 127 a = [1., 2] 128 b = [1., 2] 129 x = [.3, .7] 130 dist = beta_lib.Beta(a, b) 131 pdf = dist.prob(x) 132 self.assertAllClose([1, 63. / 50], pdf.eval()) 133 self.assertEqual((2,), pdf.get_shape()) 134 135 def testPdfUniformZeroBatch(self): 136 with self.test_session(): 137 # This is equivalent to a uniform distribution 138 a = 1. 139 b = 1. 140 x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) 141 dist = beta_lib.Beta(a, b) 142 pdf = dist.prob(x) 143 self.assertAllClose([1.] * 5, pdf.eval()) 144 self.assertEqual((5,), pdf.get_shape()) 145 146 def testPdfAlphaStretchedInBroadcastWhenSameRank(self): 147 with self.test_session(): 148 a = [[1., 2]] 149 b = [[1., 2]] 150 x = [[.5, .5], [.3, .7]] 151 dist = beta_lib.Beta(a, b) 152 pdf = dist.prob(x) 153 self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], pdf.eval()) 154 self.assertEqual((2, 2), pdf.get_shape()) 155 156 def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): 157 with self.test_session(): 158 a = [1., 2] 159 b = [1., 2] 160 x = [[.5, .5], [.2, .8]] 161 pdf = beta_lib.Beta(a, b).prob(x) 162 self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], pdf.eval()) 163 self.assertEqual((2, 2), pdf.get_shape()) 164 165 def testPdfXStretchedInBroadcastWhenSameRank(self): 166 with self.test_session(): 167 a = [[1., 2], [2., 3]] 168 b = [[1., 2], [2., 3]] 169 x = [[.5, .5]] 170 pdf = beta_lib.Beta(a, b).prob(x) 171 self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval()) 172 self.assertEqual((2, 2), pdf.get_shape()) 173 174 def testPdfXStretchedInBroadcastWhenLowerRank(self): 175 with self.test_session(): 176 a = [[1., 2], [2., 3]] 177 b = [[1., 2], [2., 3]] 178 x = [.5, .5] 179 pdf = beta_lib.Beta(a, b).prob(x) 180 self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval()) 181 self.assertEqual((2, 2), pdf.get_shape()) 182 183 def testBetaMean(self): 184 with session.Session(): 185 a = [1., 2, 3] 186 b = [2., 4, 1.2] 187 dist = beta_lib.Beta(a, b) 188 self.assertEqual(dist.mean().get_shape(), (3,)) 189 if not stats: 190 return 191 expected_mean = stats.beta.mean(a, b) 192 self.assertAllClose(expected_mean, dist.mean().eval()) 193 194 def testBetaVariance(self): 195 with session.Session(): 196 a = [1., 2, 3] 197 b = [2., 4, 1.2] 198 dist = beta_lib.Beta(a, b) 199 self.assertEqual(dist.variance().get_shape(), (3,)) 200 if not stats: 201 return 202 expected_variance = stats.beta.var(a, b) 203 self.assertAllClose(expected_variance, dist.variance().eval()) 204 205 def testBetaMode(self): 206 with session.Session(): 207 a = np.array([1.1, 2, 3]) 208 b = np.array([2., 4, 1.2]) 209 expected_mode = (a - 1) / (a + b - 2) 210 dist = beta_lib.Beta(a, b) 211 self.assertEqual(dist.mode().get_shape(), (3,)) 212 self.assertAllClose(expected_mode, dist.mode().eval()) 213 214 def testBetaModeInvalid(self): 215 with session.Session(): 216 a = np.array([1., 2, 3]) 217 b = np.array([2., 4, 1.2]) 218 dist = beta_lib.Beta(a, b, allow_nan_stats=False) 219 with self.assertRaisesOpError("Condition x < y.*"): 220 dist.mode().eval() 221 222 a = np.array([2., 2, 3]) 223 b = np.array([1., 4, 1.2]) 224 dist = beta_lib.Beta(a, b, allow_nan_stats=False) 225 with self.assertRaisesOpError("Condition x < y.*"): 226 dist.mode().eval() 227 228 def testBetaModeEnableAllowNanStats(self): 229 with session.Session(): 230 a = np.array([1., 2, 3]) 231 b = np.array([2., 4, 1.2]) 232 dist = beta_lib.Beta(a, b, allow_nan_stats=True) 233 234 expected_mode = (a - 1) / (a + b - 2) 235 expected_mode[0] = np.nan 236 self.assertEqual((3,), dist.mode().get_shape()) 237 self.assertAllClose(expected_mode, dist.mode().eval()) 238 239 a = np.array([2., 2, 3]) 240 b = np.array([1., 4, 1.2]) 241 dist = beta_lib.Beta(a, b, allow_nan_stats=True) 242 243 expected_mode = (a - 1) / (a + b - 2) 244 expected_mode[0] = np.nan 245 self.assertEqual((3,), dist.mode().get_shape()) 246 self.assertAllClose(expected_mode, dist.mode().eval()) 247 248 def testBetaEntropy(self): 249 with session.Session(): 250 a = [1., 2, 3] 251 b = [2., 4, 1.2] 252 dist = beta_lib.Beta(a, b) 253 self.assertEqual(dist.entropy().get_shape(), (3,)) 254 if not stats: 255 return 256 expected_entropy = stats.beta.entropy(a, b) 257 self.assertAllClose(expected_entropy, dist.entropy().eval()) 258 259 def testBetaSample(self): 260 with self.test_session(): 261 a = 1. 262 b = 2. 263 beta = beta_lib.Beta(a, b) 264 n = constant_op.constant(100000) 265 samples = beta.sample(n) 266 sample_values = samples.eval() 267 self.assertEqual(sample_values.shape, (100000,)) 268 self.assertFalse(np.any(sample_values < 0.0)) 269 if not stats: 270 return 271 self.assertLess( 272 stats.kstest( 273 # Beta is a univariate distribution. 274 sample_values, 275 stats.beta(a=1., b=2.).cdf)[0], 276 0.01) 277 # The standard error of the sample mean is 1 / (sqrt(18 * n)) 278 self.assertAllClose( 279 sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) 280 self.assertAllClose( 281 np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) 282 283 # Test that sampling with the same seed twice gives the same results. 284 def testBetaSampleMultipleTimes(self): 285 with self.test_session(): 286 a_val = 1. 287 b_val = 2. 288 n_val = 100 289 290 random_seed.set_random_seed(654321) 291 beta1 = beta_lib.Beta(concentration1=a_val, 292 concentration0=b_val, 293 name="beta1") 294 samples1 = beta1.sample(n_val, seed=123456).eval() 295 296 random_seed.set_random_seed(654321) 297 beta2 = beta_lib.Beta(concentration1=a_val, 298 concentration0=b_val, 299 name="beta2") 300 samples2 = beta2.sample(n_val, seed=123456).eval() 301 302 self.assertAllClose(samples1, samples2) 303 304 def testBetaSampleMultidimensional(self): 305 with self.test_session(): 306 a = np.random.rand(3, 2, 2).astype(np.float32) 307 b = np.random.rand(3, 2, 2).astype(np.float32) 308 beta = beta_lib.Beta(a, b) 309 n = constant_op.constant(100000) 310 samples = beta.sample(n) 311 sample_values = samples.eval() 312 self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) 313 self.assertFalse(np.any(sample_values < 0.0)) 314 if not stats: 315 return 316 self.assertAllClose( 317 sample_values[:, 1, :].mean(axis=0), 318 stats.beta.mean(a, b)[1, :], 319 atol=1e-1) 320 321 def testBetaCdf(self): 322 with self.test_session(): 323 shape = (30, 40, 50) 324 for dt in (np.float32, np.float64): 325 a = 10. * np.random.random(shape).astype(dt) 326 b = 10. * np.random.random(shape).astype(dt) 327 x = np.random.random(shape).astype(dt) 328 actual = beta_lib.Beta(a, b).cdf(x).eval() 329 self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) 330 self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) 331 if not stats: 332 return 333 self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) 334 335 def testBetaLogCdf(self): 336 with self.test_session(): 337 shape = (30, 40, 50) 338 for dt in (np.float32, np.float64): 339 a = 10. * np.random.random(shape).astype(dt) 340 b = 10. * np.random.random(shape).astype(dt) 341 x = np.random.random(shape).astype(dt) 342 actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval() 343 self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) 344 self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) 345 if not stats: 346 return 347 self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) 348 349 def testBetaWithSoftplusConcentration(self): 350 with self.test_session(): 351 a, b = -4.2, -9.1 352 dist = beta_lib.BetaWithSoftplusConcentration(a, b) 353 self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval()) 354 self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval()) 355 356 def testBetaBetaKL(self): 357 with self.test_session() as sess: 358 for shape in [(10,), (4, 5)]: 359 a1 = 6.0 * np.random.random(size=shape) + 1e-4 360 b1 = 6.0 * np.random.random(size=shape) + 1e-4 361 a2 = 6.0 * np.random.random(size=shape) + 1e-4 362 b2 = 6.0 * np.random.random(size=shape) + 1e-4 363 # Take inverse softplus of values to test BetaWithSoftplusConcentration 364 a1_sp = np.log(np.exp(a1) - 1.0) 365 b1_sp = np.log(np.exp(b1) - 1.0) 366 a2_sp = np.log(np.exp(a2) - 1.0) 367 b2_sp = np.log(np.exp(b2) - 1.0) 368 369 d1 = beta_lib.Beta(concentration1=a1, concentration0=b1) 370 d2 = beta_lib.Beta(concentration1=a2, concentration0=b2) 371 d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp, 372 concentration0=b1_sp) 373 d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp, 374 concentration0=b2_sp) 375 376 if not special: 377 return 378 kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) + 379 (a1 - a2) * special.digamma(a1) + 380 (b1 - b2) * special.digamma(b1) + 381 (a2 - a1 + b2 - b1) * special.digamma(a1 + b1)) 382 383 for dist1 in [d1, d1_sp]: 384 for dist2 in [d2, d2_sp]: 385 kl = kullback_leibler.kl_divergence(dist1, dist2) 386 kl_val = sess.run(kl) 387 self.assertEqual(kl.get_shape(), shape) 388 self.assertAllClose(kl_val, kl_expected) 389 390 # Make sure KL(d1||d1) is 0 391 kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1)) 392 self.assertAllClose(kl_same, np.zeros_like(kl_expected)) 393 394 395 if __name__ == "__main__": 396 test.main() 397