1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Student t distribution.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import importlib 22 import math 23 24 import numpy as np 25 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import random_seed 28 from tensorflow.python.ops import math_ops 29 from tensorflow.python.ops import nn_ops 30 from tensorflow.python.ops.distributions import student_t 31 from tensorflow.python.platform import test 32 from tensorflow.python.platform import tf_logging 33 34 35 def try_import(name): # pylint: disable=invalid-name 36 module = None 37 try: 38 module = importlib.import_module(name) 39 except ImportError as e: 40 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 41 return module 42 43 44 stats = try_import("scipy.stats") 45 46 47 class StudentTTest(test.TestCase): 48 49 def testStudentPDFAndLogPDF(self): 50 with self.test_session(): 51 batch_size = 6 52 df = constant_op.constant([3.] * batch_size) 53 mu = constant_op.constant([7.] * batch_size) 54 sigma = constant_op.constant([8.] * batch_size) 55 df_v = 3. 56 mu_v = 7. 57 sigma_v = 8. 58 t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) 59 student = student_t.StudentT(df, loc=mu, scale=-sigma) 60 61 log_pdf = student.log_prob(t) 62 self.assertEquals(log_pdf.get_shape(), (6,)) 63 log_pdf_values = log_pdf.eval() 64 pdf = student.prob(t) 65 self.assertEquals(pdf.get_shape(), (6,)) 66 pdf_values = pdf.eval() 67 68 if not stats: 69 return 70 71 expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) 72 expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) 73 self.assertAllClose(expected_log_pdf, log_pdf_values) 74 self.assertAllClose(np.log(expected_pdf), log_pdf_values) 75 self.assertAllClose(expected_pdf, pdf_values) 76 self.assertAllClose(np.exp(expected_log_pdf), pdf_values) 77 78 def testStudentLogPDFMultidimensional(self): 79 with self.test_session(): 80 batch_size = 6 81 df = constant_op.constant([[1.5, 7.2]] * batch_size) 82 mu = constant_op.constant([[3., -3.]] * batch_size) 83 sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] * 84 batch_size) 85 df_v = np.array([1.5, 7.2]) 86 mu_v = np.array([3., -3.]) 87 sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) 88 t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T 89 student = student_t.StudentT(df, loc=mu, scale=sigma) 90 log_pdf = student.log_prob(t) 91 log_pdf_values = log_pdf.eval() 92 self.assertEqual(log_pdf.get_shape(), (6, 2)) 93 pdf = student.prob(t) 94 pdf_values = pdf.eval() 95 self.assertEqual(pdf.get_shape(), (6, 2)) 96 97 if not stats: 98 return 99 expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) 100 expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) 101 self.assertAllClose(expected_log_pdf, log_pdf_values) 102 self.assertAllClose(np.log(expected_pdf), log_pdf_values) 103 self.assertAllClose(expected_pdf, pdf_values) 104 self.assertAllClose(np.exp(expected_log_pdf), pdf_values) 105 106 def testStudentCDFAndLogCDF(self): 107 with self.test_session(): 108 batch_size = 6 109 df = constant_op.constant([3.] * batch_size) 110 mu = constant_op.constant([7.] * batch_size) 111 sigma = constant_op.constant([-8.] * batch_size) 112 df_v = 3. 113 mu_v = 7. 114 sigma_v = 8. 115 t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) 116 student = student_t.StudentT(df, loc=mu, scale=sigma) 117 118 log_cdf = student.log_cdf(t) 119 self.assertEquals(log_cdf.get_shape(), (6,)) 120 log_cdf_values = log_cdf.eval() 121 cdf = student.cdf(t) 122 self.assertEquals(cdf.get_shape(), (6,)) 123 cdf_values = cdf.eval() 124 125 if not stats: 126 return 127 expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v) 128 expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v) 129 self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5) 130 self.assertAllClose( 131 np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5) 132 self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5) 133 self.assertAllClose( 134 np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5) 135 136 def testStudentEntropy(self): 137 df_v = np.array([[2., 3., 7.]]) # 1x3 138 mu_v = np.array([[1., -1, 0]]) # 1x3 139 sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1 140 with self.test_session(): 141 student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) 142 ent = student.entropy() 143 ent_values = ent.eval() 144 145 # Help scipy broadcast to 3x3 146 ones = np.array([[1, 1, 1]]) 147 sigma_bc = np.abs(sigma_v) * ones 148 mu_bc = ones.T * mu_v 149 df_bc = ones.T * df_v 150 if not stats: 151 return 152 expected_entropy = stats.t.entropy( 153 np.reshape(df_bc, [-1]), 154 loc=np.reshape(mu_bc, [-1]), 155 scale=np.reshape(sigma_bc, [-1])) 156 expected_entropy = np.reshape(expected_entropy, df_bc.shape) 157 self.assertAllClose(expected_entropy, ent_values) 158 159 def testStudentSample(self): 160 with self.test_session(): 161 df = constant_op.constant(4.) 162 mu = constant_op.constant(3.) 163 sigma = constant_op.constant(-math.sqrt(10.)) 164 df_v = 4. 165 mu_v = 3. 166 sigma_v = np.sqrt(10.) 167 n = constant_op.constant(200000) 168 student = student_t.StudentT(df=df, loc=mu, scale=sigma) 169 samples = student.sample(n, seed=123456) 170 sample_values = samples.eval() 171 n_val = 200000 172 self.assertEqual(sample_values.shape, (n_val,)) 173 self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0) 174 self.assertAllClose( 175 sample_values.var(), 176 sigma_v**2 * df_v / (df_v - 2), 177 rtol=1e-2, 178 atol=0) 179 self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) 180 181 # Test that sampling with the same seed twice gives the same results. 182 def testStudentSampleMultipleTimes(self): 183 with self.test_session(): 184 df = constant_op.constant(4.) 185 mu = constant_op.constant(3.) 186 sigma = constant_op.constant(math.sqrt(10.)) 187 n = constant_op.constant(100) 188 189 random_seed.set_random_seed(654321) 190 student = student_t.StudentT( 191 df=df, loc=mu, scale=sigma, name="student_t1") 192 samples1 = student.sample(n, seed=123456).eval() 193 194 random_seed.set_random_seed(654321) 195 student2 = student_t.StudentT( 196 df=df, loc=mu, scale=sigma, name="student_t2") 197 samples2 = student2.sample(n, seed=123456).eval() 198 199 self.assertAllClose(samples1, samples2) 200 201 def testStudentSampleSmallDfNoNan(self): 202 with self.test_session(): 203 df_v = [1e-1, 1e-5, 1e-10, 1e-20] 204 df = constant_op.constant(df_v) 205 n = constant_op.constant(200000) 206 student = student_t.StudentT(df=df, loc=1., scale=1.) 207 samples = student.sample(n, seed=123456) 208 sample_values = samples.eval() 209 n_val = 200000 210 self.assertEqual(sample_values.shape, (n_val, 4)) 211 self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) 212 213 def testStudentSampleMultiDimensional(self): 214 with self.test_session(): 215 batch_size = 7 216 df = constant_op.constant([[3., 7.]] * batch_size) 217 mu = constant_op.constant([[3., -3.]] * batch_size) 218 sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * 219 batch_size) 220 df_v = [3., 7.] 221 mu_v = [3., -3.] 222 sigma_v = [np.sqrt(10.), np.sqrt(15.)] 223 n = constant_op.constant(200000) 224 student = student_t.StudentT(df=df, loc=mu, scale=sigma) 225 samples = student.sample(n, seed=123456) 226 sample_values = samples.eval() 227 self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) 228 self.assertAllClose( 229 sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0) 230 self.assertAllClose( 231 sample_values[:, 0, 0].var(), 232 sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), 233 rtol=1e-1, 234 atol=0) 235 self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) 236 self.assertAllClose( 237 sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0) 238 self.assertAllClose( 239 sample_values[:, 0, 1].var(), 240 sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), 241 rtol=1e-1, 242 atol=0) 243 self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1]) 244 245 def _checkKLApprox(self, df, mu, sigma, samples): 246 n = samples.size 247 np.random.seed(137) 248 if not stats: 249 return 250 sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n) 251 covg = 0.99 252 r = stats.t.interval(covg, df, loc=mu, scale=sigma) 253 bins = 100 254 hist, _ = np.histogram(samples, bins=bins, range=r) 255 hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r) 256 self.assertGreater(hist.sum(), n * (covg - .01)) 257 self.assertGreater(hist_scipy.sum(), n * (covg - .01)) 258 hist_min1 = hist + 1. # put at least one item in each bucket 259 hist_norm = hist_min1 / hist_min1.sum() 260 hist_scipy_min1 = hist_scipy + 1. # put at least one item in each bucket 261 hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum() 262 kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm) 263 self.assertLess(kl_appx, 1) 264 265 def testBroadcastingParams(self): 266 267 def _check(student): 268 self.assertEqual(student.mean().get_shape(), (3,)) 269 self.assertEqual(student.variance().get_shape(), (3,)) 270 self.assertEqual(student.entropy().get_shape(), (3,)) 271 self.assertEqual(student.log_prob(2.).get_shape(), (3,)) 272 self.assertEqual(student.prob(2.).get_shape(), (3,)) 273 self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,)) 274 275 _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) 276 _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) 277 _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,])) 278 279 def testBroadcastingPdfArgs(self): 280 281 def _assert_shape(student, arg, shape): 282 self.assertEqual(student.log_prob(arg).get_shape(), shape) 283 self.assertEqual(student.prob(arg).get_shape(), shape) 284 285 def _check(student): 286 _assert_shape(student, 2., (3,)) 287 xs = np.array([2., 3., 4.], dtype=np.float32) 288 _assert_shape(student, xs, (3,)) 289 xs = np.array([xs]) 290 _assert_shape(student, xs, (1, 3)) 291 xs = xs.T 292 _assert_shape(student, xs, (3, 3)) 293 294 _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.)) 295 _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.)) 296 _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,])) 297 298 def _check2d(student): 299 _assert_shape(student, 2., (1, 3)) 300 xs = np.array([2., 3., 4.], dtype=np.float32) 301 _assert_shape(student, xs, (1, 3)) 302 xs = np.array([xs]) 303 _assert_shape(student, xs, (1, 3)) 304 xs = xs.T 305 _assert_shape(student, xs, (3, 3)) 306 307 _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.)) 308 _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.)) 309 _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]])) 310 311 def _check2d_rows(student): 312 _assert_shape(student, 2., (3, 1)) 313 xs = np.array([2., 3., 4.], dtype=np.float32) # (3,) 314 _assert_shape(student, xs, (3, 3)) 315 xs = np.array([xs]) # (1,3) 316 _assert_shape(student, xs, (3, 3)) 317 xs = xs.T # (3,1) 318 _assert_shape(student, xs, (3, 1)) 319 320 _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.)) 321 _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.)) 322 _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]])) 323 324 def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): 325 with self.test_session(): 326 mu = [1., 3.3, 4.4] 327 student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) 328 mean = student.mean().eval() 329 self.assertAllClose([1., 3.3, 4.4], mean) 330 331 def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): 332 with self.test_session(): 333 mu = [1., 3.3, 4.4] 334 student = student_t.StudentT( 335 df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], 336 allow_nan_stats=False) 337 with self.assertRaisesOpError("x < y"): 338 student.mean().eval() 339 340 def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): 341 with self.test_session(): 342 mu = [-2, 0., 1., 3.3, 4.4] 343 sigma = [5., 4., 3., 2., 1.] 344 student = student_t.StudentT( 345 df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, 346 allow_nan_stats=True) 347 mean = student.mean().eval() 348 self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) 349 350 def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): 351 with self.test_session(): 352 # df = 0.5 ==> undefined mean ==> undefined variance. 353 # df = 1.5 ==> infinite variance. 354 df = [0.5, 1.5, 3., 5., 7.] 355 mu = [-2, 0., 1., 3.3, 4.4] 356 sigma = [5., 4., 3., 2., 1.] 357 student = student_t.StudentT( 358 df=df, loc=mu, scale=sigma, allow_nan_stats=True) 359 var = student.variance().eval() 360 ## scipy uses inf for variance when the mean is undefined. When mean is 361 # undefined we say variance is undefined as well. So test the first 362 # member of var, making sure it is NaN, then replace with inf and compare 363 # to scipy. 364 self.assertTrue(np.isnan(var[0])) 365 var[0] = np.inf 366 367 if not stats: 368 return 369 expected_var = [ 370 stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) 371 ] 372 self.assertAllClose(expected_var, var) 373 374 def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers( 375 self): 376 with self.test_session(): 377 # df = 1.5 ==> infinite variance. 378 df = [1.5, 3., 5., 7.] 379 mu = [0., 1., 3.3, 4.4] 380 sigma = [4., 3., 2., 1.] 381 student = student_t.StudentT(df=df, loc=mu, scale=sigma) 382 var = student.variance().eval() 383 384 if not stats: 385 return 386 expected_var = [ 387 stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) 388 ] 389 self.assertAllClose(expected_var, var) 390 391 def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): 392 with self.test_session(): 393 # df <= 1 ==> variance not defined 394 student = student_t.StudentT( 395 df=1., loc=0., scale=1., allow_nan_stats=False) 396 with self.assertRaisesOpError("x < y"): 397 student.variance().eval() 398 399 with self.test_session(): 400 # df <= 1 ==> variance not defined 401 student = student_t.StudentT( 402 df=0.5, loc=0., scale=1., allow_nan_stats=False) 403 with self.assertRaisesOpError("x < y"): 404 student.variance().eval() 405 406 def testStd(self): 407 with self.test_session(): 408 # Defined for all batch members. 409 df = [3.5, 5., 3., 5., 7.] 410 mu = [-2.2] 411 sigma = [5., 4., 3., 2., 1.] 412 student = student_t.StudentT(df=df, loc=mu, scale=sigma) 413 # Test broadcast of mu across shape of df/sigma 414 stddev = student.stddev().eval() 415 mu *= len(df) 416 417 if not stats: 418 return 419 expected_stddev = [ 420 stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) 421 ] 422 self.assertAllClose(expected_stddev, stddev) 423 424 def testMode(self): 425 with self.test_session(): 426 df = [0.5, 1., 3] 427 mu = [-1, 0., 1] 428 sigma = [5., 4., 3.] 429 student = student_t.StudentT(df=df, loc=mu, scale=sigma) 430 # Test broadcast of mu across shape of df/sigma 431 mode = student.mode().eval() 432 self.assertAllClose([-1., 0, 1], mode) 433 434 def testPdfOfSample(self): 435 with self.test_session() as sess: 436 student = student_t.StudentT(df=3., loc=np.pi, scale=1.) 437 num = 20000 438 samples = student.sample(num, seed=123456) 439 pdfs = student.prob(samples) 440 mean = student.mean() 441 mean_pdf = student.prob(student.mean()) 442 sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run( 443 [samples, pdfs, student.mean(), mean_pdf]) 444 self.assertEqual(samples.get_shape(), (num,)) 445 self.assertEqual(pdfs.get_shape(), (num,)) 446 self.assertEqual(mean.get_shape(), ()) 447 self.assertNear(np.pi, np.mean(sample_vals), err=0.02) 448 self.assertNear(np.pi, mean_val, err=1e-6) 449 # Verify integral over sample*pdf ~= 1. 450 self._assertIntegral(sample_vals, pdf_vals, err=2e-3) 451 if not stats: 452 return 453 self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6) 454 455 def testPdfOfSampleMultiDims(self): 456 with self.test_session() as sess: 457 student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.) 458 self.assertAllEqual([], student.event_shape) 459 self.assertAllEqual([], student.event_shape_tensor().eval()) 460 self.assertAllEqual([2, 2], student.batch_shape) 461 self.assertAllEqual([2, 2], student.batch_shape_tensor().eval()) 462 num = 50000 463 samples = student.sample(num, seed=123456) 464 pdfs = student.prob(samples) 465 sample_vals, pdf_vals = sess.run([samples, pdfs]) 466 self.assertEqual(samples.get_shape(), (num, 2, 2)) 467 self.assertEqual(pdfs.get_shape(), (num, 2, 2)) 468 self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03) 469 self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03) 470 self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) 471 self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) 472 self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) 473 self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) 474 if not stats: 475 return 476 self.assertNear( 477 stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var 478 np.var(sample_vals[:, :, 0]), 479 err=.4) 480 self.assertNear( 481 stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var 482 np.var(sample_vals[:, :, 1]), 483 err=.4) 484 485 def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3): 486 s_p = zip(sample_vals, pdf_vals) 487 prev = (sample_vals.min() - 1000, 0) 488 total = 0 489 for k in sorted(s_p, key=lambda x: x[0]): 490 pair_pdf = (k[1] + prev[1]) / 2 491 total += (k[0] - prev[0]) * pair_pdf 492 prev = k 493 self.assertNear(1., total, err=err) 494 495 def testNegativeDofFails(self): 496 with self.test_session(): 497 student = student_t.StudentT(df=[2, -5.], loc=0., scale=1., 498 validate_args=True, name="S") 499 with self.assertRaisesOpError(r"Condition x > 0 did not hold"): 500 student.mean().eval() 501 502 def testStudentTWithAbsDfSoftplusScale(self): 503 with self.test_session(): 504 df = constant_op.constant([-3.2, -4.6]) 505 mu = constant_op.constant([-4.2, 3.4]) 506 sigma = constant_op.constant([-6.4, -8.8]) 507 student = student_t.StudentTWithAbsDfSoftplusScale( 508 df=df, loc=mu, scale=sigma) 509 self.assertAllClose( 510 math_ops.floor(math_ops.abs(df)).eval(), student.df.eval()) 511 self.assertAllClose(mu.eval(), student.loc.eval()) 512 self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval()) 513 514 515 if __name__ == "__main__": 516 test.main() 517