Home | History | Annotate | Download | only in distributions
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Student t distribution."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import importlib
     22 import math
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import random_seed
     28 from tensorflow.python.ops import math_ops
     29 from tensorflow.python.ops import nn_ops
     30 from tensorflow.python.ops.distributions import student_t
     31 from tensorflow.python.platform import test
     32 from tensorflow.python.platform import tf_logging
     33 
     34 
     35 def try_import(name):  # pylint: disable=invalid-name
     36   module = None
     37   try:
     38     module = importlib.import_module(name)
     39   except ImportError as e:
     40     tf_logging.warning("Could not import %s: %s" % (name, str(e)))
     41   return module
     42 
     43 
     44 stats = try_import("scipy.stats")
     45 
     46 
     47 class StudentTTest(test.TestCase):
     48 
     49   def testStudentPDFAndLogPDF(self):
     50     with self.test_session():
     51       batch_size = 6
     52       df = constant_op.constant([3.] * batch_size)
     53       mu = constant_op.constant([7.] * batch_size)
     54       sigma = constant_op.constant([8.] * batch_size)
     55       df_v = 3.
     56       mu_v = 7.
     57       sigma_v = 8.
     58       t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
     59       student = student_t.StudentT(df, loc=mu, scale=-sigma)
     60 
     61       log_pdf = student.log_prob(t)
     62       self.assertEquals(log_pdf.get_shape(), (6,))
     63       log_pdf_values = log_pdf.eval()
     64       pdf = student.prob(t)
     65       self.assertEquals(pdf.get_shape(), (6,))
     66       pdf_values = pdf.eval()
     67 
     68       if not stats:
     69         return
     70 
     71       expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
     72       expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
     73       self.assertAllClose(expected_log_pdf, log_pdf_values)
     74       self.assertAllClose(np.log(expected_pdf), log_pdf_values)
     75       self.assertAllClose(expected_pdf, pdf_values)
     76       self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
     77 
     78   def testStudentLogPDFMultidimensional(self):
     79     with self.test_session():
     80       batch_size = 6
     81       df = constant_op.constant([[1.5, 7.2]] * batch_size)
     82       mu = constant_op.constant([[3., -3.]] * batch_size)
     83       sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
     84                                    batch_size)
     85       df_v = np.array([1.5, 7.2])
     86       mu_v = np.array([3., -3.])
     87       sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
     88       t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
     89       student = student_t.StudentT(df, loc=mu, scale=sigma)
     90       log_pdf = student.log_prob(t)
     91       log_pdf_values = log_pdf.eval()
     92       self.assertEqual(log_pdf.get_shape(), (6, 2))
     93       pdf = student.prob(t)
     94       pdf_values = pdf.eval()
     95       self.assertEqual(pdf.get_shape(), (6, 2))
     96 
     97       if not stats:
     98         return
     99       expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
    100       expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
    101       self.assertAllClose(expected_log_pdf, log_pdf_values)
    102       self.assertAllClose(np.log(expected_pdf), log_pdf_values)
    103       self.assertAllClose(expected_pdf, pdf_values)
    104       self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
    105 
    106   def testStudentCDFAndLogCDF(self):
    107     with self.test_session():
    108       batch_size = 6
    109       df = constant_op.constant([3.] * batch_size)
    110       mu = constant_op.constant([7.] * batch_size)
    111       sigma = constant_op.constant([-8.] * batch_size)
    112       df_v = 3.
    113       mu_v = 7.
    114       sigma_v = 8.
    115       t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
    116       student = student_t.StudentT(df, loc=mu, scale=sigma)
    117 
    118       log_cdf = student.log_cdf(t)
    119       self.assertEquals(log_cdf.get_shape(), (6,))
    120       log_cdf_values = log_cdf.eval()
    121       cdf = student.cdf(t)
    122       self.assertEquals(cdf.get_shape(), (6,))
    123       cdf_values = cdf.eval()
    124 
    125       if not stats:
    126         return
    127       expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
    128       expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
    129       self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
    130       self.assertAllClose(
    131           np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
    132       self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
    133       self.assertAllClose(
    134           np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
    135 
    136   def testStudentEntropy(self):
    137     df_v = np.array([[2., 3., 7.]])  # 1x3
    138     mu_v = np.array([[1., -1, 0]])  # 1x3
    139     sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
    140     with self.test_session():
    141       student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
    142       ent = student.entropy()
    143       ent_values = ent.eval()
    144 
    145     # Help scipy broadcast to 3x3
    146     ones = np.array([[1, 1, 1]])
    147     sigma_bc = np.abs(sigma_v) * ones
    148     mu_bc = ones.T * mu_v
    149     df_bc = ones.T * df_v
    150     if not stats:
    151       return
    152     expected_entropy = stats.t.entropy(
    153         np.reshape(df_bc, [-1]),
    154         loc=np.reshape(mu_bc, [-1]),
    155         scale=np.reshape(sigma_bc, [-1]))
    156     expected_entropy = np.reshape(expected_entropy, df_bc.shape)
    157     self.assertAllClose(expected_entropy, ent_values)
    158 
    159   def testStudentSample(self):
    160     with self.test_session():
    161       df = constant_op.constant(4.)
    162       mu = constant_op.constant(3.)
    163       sigma = constant_op.constant(-math.sqrt(10.))
    164       df_v = 4.
    165       mu_v = 3.
    166       sigma_v = np.sqrt(10.)
    167       n = constant_op.constant(200000)
    168       student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    169       samples = student.sample(n, seed=123456)
    170       sample_values = samples.eval()
    171       n_val = 200000
    172       self.assertEqual(sample_values.shape, (n_val,))
    173       self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0)
    174       self.assertAllClose(
    175           sample_values.var(),
    176           sigma_v**2 * df_v / (df_v - 2),
    177           rtol=1e-2,
    178           atol=0)
    179       self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
    180 
    181   # Test that sampling with the same seed twice gives the same results.
    182   def testStudentSampleMultipleTimes(self):
    183     with self.test_session():
    184       df = constant_op.constant(4.)
    185       mu = constant_op.constant(3.)
    186       sigma = constant_op.constant(math.sqrt(10.))
    187       n = constant_op.constant(100)
    188 
    189       random_seed.set_random_seed(654321)
    190       student = student_t.StudentT(
    191           df=df, loc=mu, scale=sigma, name="student_t1")
    192       samples1 = student.sample(n, seed=123456).eval()
    193 
    194       random_seed.set_random_seed(654321)
    195       student2 = student_t.StudentT(
    196           df=df, loc=mu, scale=sigma, name="student_t2")
    197       samples2 = student2.sample(n, seed=123456).eval()
    198 
    199       self.assertAllClose(samples1, samples2)
    200 
    201   def testStudentSampleSmallDfNoNan(self):
    202     with self.test_session():
    203       df_v = [1e-1, 1e-5, 1e-10, 1e-20]
    204       df = constant_op.constant(df_v)
    205       n = constant_op.constant(200000)
    206       student = student_t.StudentT(df=df, loc=1., scale=1.)
    207       samples = student.sample(n, seed=123456)
    208       sample_values = samples.eval()
    209       n_val = 200000
    210       self.assertEqual(sample_values.shape, (n_val, 4))
    211       self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
    212 
    213   def testStudentSampleMultiDimensional(self):
    214     with self.test_session():
    215       batch_size = 7
    216       df = constant_op.constant([[3., 7.]] * batch_size)
    217       mu = constant_op.constant([[3., -3.]] * batch_size)
    218       sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
    219                                    batch_size)
    220       df_v = [3., 7.]
    221       mu_v = [3., -3.]
    222       sigma_v = [np.sqrt(10.), np.sqrt(15.)]
    223       n = constant_op.constant(200000)
    224       student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    225       samples = student.sample(n, seed=123456)
    226       sample_values = samples.eval()
    227       self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    228       self.assertAllClose(
    229           sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0)
    230       self.assertAllClose(
    231           sample_values[:, 0, 0].var(),
    232           sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
    233           rtol=1e-1,
    234           atol=0)
    235       self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
    236       self.assertAllClose(
    237           sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0)
    238       self.assertAllClose(
    239           sample_values[:, 0, 1].var(),
    240           sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
    241           rtol=1e-1,
    242           atol=0)
    243       self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1])
    244 
    245   def _checkKLApprox(self, df, mu, sigma, samples):
    246     n = samples.size
    247     np.random.seed(137)
    248     if not stats:
    249       return
    250     sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
    251     covg = 0.99
    252     r = stats.t.interval(covg, df, loc=mu, scale=sigma)
    253     bins = 100
    254     hist, _ = np.histogram(samples, bins=bins, range=r)
    255     hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
    256     self.assertGreater(hist.sum(), n * (covg - .01))
    257     self.assertGreater(hist_scipy.sum(), n * (covg - .01))
    258     hist_min1 = hist + 1.  # put at least one item in each bucket
    259     hist_norm = hist_min1 / hist_min1.sum()
    260     hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
    261     hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
    262     kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
    263     self.assertLess(kl_appx, 1)
    264 
    265   def testBroadcastingParams(self):
    266 
    267     def _check(student):
    268       self.assertEqual(student.mean().get_shape(), (3,))
    269       self.assertEqual(student.variance().get_shape(), (3,))
    270       self.assertEqual(student.entropy().get_shape(), (3,))
    271       self.assertEqual(student.log_prob(2.).get_shape(), (3,))
    272       self.assertEqual(student.prob(2.).get_shape(), (3,))
    273       self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
    274 
    275     _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    276     _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    277     _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
    278 
    279   def testBroadcastingPdfArgs(self):
    280 
    281     def _assert_shape(student, arg, shape):
    282       self.assertEqual(student.log_prob(arg).get_shape(), shape)
    283       self.assertEqual(student.prob(arg).get_shape(), shape)
    284 
    285     def _check(student):
    286       _assert_shape(student, 2., (3,))
    287       xs = np.array([2., 3., 4.], dtype=np.float32)
    288       _assert_shape(student, xs, (3,))
    289       xs = np.array([xs])
    290       _assert_shape(student, xs, (1, 3))
    291       xs = xs.T
    292       _assert_shape(student, xs, (3, 3))
    293 
    294     _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    295     _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    296     _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))
    297 
    298     def _check2d(student):
    299       _assert_shape(student, 2., (1, 3))
    300       xs = np.array([2., 3., 4.], dtype=np.float32)
    301       _assert_shape(student, xs, (1, 3))
    302       xs = np.array([xs])
    303       _assert_shape(student, xs, (1, 3))
    304       xs = xs.T
    305       _assert_shape(student, xs, (3, 3))
    306 
    307     _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    308     _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    309     _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))
    310 
    311     def _check2d_rows(student):
    312       _assert_shape(student, 2., (3, 1))
    313       xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
    314       _assert_shape(student, xs, (3, 3))
    315       xs = np.array([xs])  # (1,3)
    316       _assert_shape(student, xs, (3, 3))
    317       xs = xs.T  # (3,1)
    318       _assert_shape(student, xs, (3, 1))
    319 
    320     _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    321     _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    322     _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
    323 
    324   def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
    325     with self.test_session():
    326       mu = [1., 3.3, 4.4]
    327       student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
    328       mean = student.mean().eval()
    329       self.assertAllClose([1., 3.3, 4.4], mean)
    330 
    331   def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
    332     with self.test_session():
    333       mu = [1., 3.3, 4.4]
    334       student = student_t.StudentT(
    335           df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
    336           allow_nan_stats=False)
    337       with self.assertRaisesOpError("x < y"):
    338         student.mean().eval()
    339 
    340   def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
    341     with self.test_session():
    342       mu = [-2, 0., 1., 3.3, 4.4]
    343       sigma = [5., 4., 3., 2., 1.]
    344       student = student_t.StudentT(
    345           df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
    346           allow_nan_stats=True)
    347       mean = student.mean().eval()
    348       self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
    349 
    350   def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
    351     with self.test_session():
    352       # df = 0.5 ==> undefined mean ==> undefined variance.
    353       # df = 1.5 ==> infinite variance.
    354       df = [0.5, 1.5, 3., 5., 7.]
    355       mu = [-2, 0., 1., 3.3, 4.4]
    356       sigma = [5., 4., 3., 2., 1.]
    357       student = student_t.StudentT(
    358           df=df, loc=mu, scale=sigma, allow_nan_stats=True)
    359       var = student.variance().eval()
    360       ## scipy uses inf for variance when the mean is undefined.  When mean is
    361       # undefined we say variance is undefined as well.  So test the first
    362       # member of var, making sure it is NaN, then replace with inf and compare
    363       # to scipy.
    364       self.assertTrue(np.isnan(var[0]))
    365       var[0] = np.inf
    366 
    367       if not stats:
    368         return
    369       expected_var = [
    370           stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    371       ]
    372       self.assertAllClose(expected_var, var)
    373 
    374   def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
    375       self):
    376     with self.test_session():
    377       # df = 1.5 ==> infinite variance.
    378       df = [1.5, 3., 5., 7.]
    379       mu = [0., 1., 3.3, 4.4]
    380       sigma = [4., 3., 2., 1.]
    381       student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    382       var = student.variance().eval()
    383 
    384       if not stats:
    385         return
    386       expected_var = [
    387           stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    388       ]
    389       self.assertAllClose(expected_var, var)
    390 
    391   def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
    392     with self.test_session():
    393       # df <= 1 ==> variance not defined
    394       student = student_t.StudentT(
    395           df=1., loc=0., scale=1., allow_nan_stats=False)
    396       with self.assertRaisesOpError("x < y"):
    397         student.variance().eval()
    398 
    399     with self.test_session():
    400       # df <= 1 ==> variance not defined
    401       student = student_t.StudentT(
    402           df=0.5, loc=0., scale=1., allow_nan_stats=False)
    403       with self.assertRaisesOpError("x < y"):
    404         student.variance().eval()
    405 
    406   def testStd(self):
    407     with self.test_session():
    408       # Defined for all batch members.
    409       df = [3.5, 5., 3., 5., 7.]
    410       mu = [-2.2]
    411       sigma = [5., 4., 3., 2., 1.]
    412       student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    413       # Test broadcast of mu across shape of df/sigma
    414       stddev = student.stddev().eval()
    415       mu *= len(df)
    416 
    417       if not stats:
    418         return
    419       expected_stddev = [
    420           stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
    421       ]
    422       self.assertAllClose(expected_stddev, stddev)
    423 
    424   def testMode(self):
    425     with self.test_session():
    426       df = [0.5, 1., 3]
    427       mu = [-1, 0., 1]
    428       sigma = [5., 4., 3.]
    429       student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    430       # Test broadcast of mu across shape of df/sigma
    431       mode = student.mode().eval()
    432       self.assertAllClose([-1., 0, 1], mode)
    433 
    434   def testPdfOfSample(self):
    435     with self.test_session() as sess:
    436       student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    437       num = 20000
    438       samples = student.sample(num, seed=123456)
    439       pdfs = student.prob(samples)
    440       mean = student.mean()
    441       mean_pdf = student.prob(student.mean())
    442       sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
    443           [samples, pdfs, student.mean(), mean_pdf])
    444       self.assertEqual(samples.get_shape(), (num,))
    445       self.assertEqual(pdfs.get_shape(), (num,))
    446       self.assertEqual(mean.get_shape(), ())
    447       self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
    448       self.assertNear(np.pi, mean_val, err=1e-6)
    449       # Verify integral over sample*pdf ~= 1.
    450       self._assertIntegral(sample_vals, pdf_vals, err=2e-3)
    451       if not stats:
    452         return
    453       self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
    454 
    455   def testPdfOfSampleMultiDims(self):
    456     with self.test_session() as sess:
    457       student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    458       self.assertAllEqual([], student.event_shape)
    459       self.assertAllEqual([], student.event_shape_tensor().eval())
    460       self.assertAllEqual([2, 2], student.batch_shape)
    461       self.assertAllEqual([2, 2], student.batch_shape_tensor().eval())
    462       num = 50000
    463       samples = student.sample(num, seed=123456)
    464       pdfs = student.prob(samples)
    465       sample_vals, pdf_vals = sess.run([samples, pdfs])
    466       self.assertEqual(samples.get_shape(), (num, 2, 2))
    467       self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    468       self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
    469       self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
    470       self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
    471       self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
    472       self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
    473       self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
    474       if not stats:
    475         return
    476       self.assertNear(
    477           stats.t.var(7., loc=0., scale=3.),  # loc d.n. effect var
    478           np.var(sample_vals[:, :, 0]),
    479           err=.4)
    480       self.assertNear(
    481           stats.t.var(11., loc=0., scale=3.),  # loc d.n. effect var
    482           np.var(sample_vals[:, :, 1]),
    483           err=.4)
    484 
    485   def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
    486     s_p = zip(sample_vals, pdf_vals)
    487     prev = (sample_vals.min() - 1000, 0)
    488     total = 0
    489     for k in sorted(s_p, key=lambda x: x[0]):
    490       pair_pdf = (k[1] + prev[1]) / 2
    491       total += (k[0] - prev[0]) * pair_pdf
    492       prev = k
    493     self.assertNear(1., total, err=err)
    494 
    495   def testNegativeDofFails(self):
    496     with self.test_session():
    497       student = student_t.StudentT(df=[2, -5.], loc=0., scale=1.,
    498                                    validate_args=True, name="S")
    499       with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
    500         student.mean().eval()
    501 
    502   def testStudentTWithAbsDfSoftplusScale(self):
    503     with self.test_session():
    504       df = constant_op.constant([-3.2, -4.6])
    505       mu = constant_op.constant([-4.2, 3.4])
    506       sigma = constant_op.constant([-6.4, -8.8])
    507       student = student_t.StudentTWithAbsDfSoftplusScale(
    508           df=df, loc=mu, scale=sigma)
    509       self.assertAllClose(
    510           math_ops.floor(math_ops.abs(df)).eval(), student.df.eval())
    511       self.assertAllClose(mu.eval(), student.loc.eval())
    512       self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval())
    513 
    514 
if __name__ == "__main__":
  # Run this file's test cases through TensorFlow's `test.main()` entry point.
  test.main()
    517