# Located in: kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for MultivariateNormal."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from scipy import stats
     23 from tensorflow.contrib import distributions
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import math_ops
     27 from tensorflow.python.ops import nn_ops
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.platform import tf_logging as logging
     30 
     31 
     32 ds = distributions
     33 
     34 
     35 class MultivariateNormalTriLTest(test.TestCase):
     36 
     37   def setUp(self):
     38     self._rng = np.random.RandomState(42)
     39 
     40   def _random_chol(self, *shape):
     41     mat = self._rng.rand(*shape)
     42     chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
     43     chol = array_ops.matrix_band_part(chol, -1, 0)
     44     sigma = math_ops.matmul(chol, chol, adjoint_b=True)
     45     return chol.eval(), sigma.eval()
     46 
     47   def testLogPDFScalarBatch(self):
     48     with self.test_session():
     49       mu = self._rng.rand(2)
     50       chol, sigma = self._random_chol(2, 2)
     51       chol[1, 1] = -chol[1, 1]
     52       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
     53       x = self._rng.rand(2)
     54 
     55       log_pdf = mvn.log_prob(x)
     56       pdf = mvn.prob(x)
     57 
     58       scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
     59 
     60       expected_log_pdf = scipy_mvn.logpdf(x)
     61       expected_pdf = scipy_mvn.pdf(x)
     62       self.assertEqual((), log_pdf.get_shape())
     63       self.assertEqual((), pdf.get_shape())
     64       self.assertAllClose(expected_log_pdf, log_pdf.eval())
     65       self.assertAllClose(expected_pdf, pdf.eval())
     66 
     67   def testLogPDFXIsHigherRank(self):
     68     with self.test_session():
     69       mu = self._rng.rand(2)
     70       chol, sigma = self._random_chol(2, 2)
     71       chol[0, 0] = -chol[0, 0]
     72       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
     73       x = self._rng.rand(3, 2)
     74 
     75       log_pdf = mvn.log_prob(x)
     76       pdf = mvn.prob(x)
     77 
     78       scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
     79 
     80       expected_log_pdf = scipy_mvn.logpdf(x)
     81       expected_pdf = scipy_mvn.pdf(x)
     82       self.assertEqual((3,), log_pdf.get_shape())
     83       self.assertEqual((3,), pdf.get_shape())
     84       self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02)
     85       self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03)
     86 
     87   def testLogPDFXLowerDimension(self):
     88     with self.test_session():
     89       mu = self._rng.rand(3, 2)
     90       chol, sigma = self._random_chol(3, 2, 2)
     91       chol[0, 0, 0] = -chol[0, 0, 0]
     92       chol[2, 1, 1] = -chol[2, 1, 1]
     93       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
     94       x = self._rng.rand(2)
     95 
     96       log_pdf = mvn.log_prob(x)
     97       pdf = mvn.prob(x)
     98 
     99       self.assertEqual((3,), log_pdf.get_shape())
    100       self.assertEqual((3,), pdf.get_shape())
    101 
    102       # scipy can't do batches, so just test one of them.
    103       scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
    104       expected_log_pdf = scipy_mvn.logpdf(x)
    105       expected_pdf = scipy_mvn.pdf(x)
    106 
    107       self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
    108       self.assertAllClose(expected_pdf, pdf.eval()[1])
    109 
    110   def testEntropy(self):
    111     with self.test_session():
    112       mu = self._rng.rand(2)
    113       chol, sigma = self._random_chol(2, 2)
    114       chol[0, 0] = -chol[0, 0]
    115       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    116       entropy = mvn.entropy()
    117 
    118       scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
    119       expected_entropy = scipy_mvn.entropy()
    120       self.assertEqual(entropy.get_shape(), ())
    121       self.assertAllClose(expected_entropy, entropy.eval())
    122 
    123   def testEntropyMultidimensional(self):
    124     with self.test_session():
    125       mu = self._rng.rand(3, 5, 2)
    126       chol, sigma = self._random_chol(3, 5, 2, 2)
    127       chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
    128       chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
    129       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    130       entropy = mvn.entropy()
    131 
    132       # Scipy doesn't do batches, so test one of them.
    133       expected_entropy = stats.multivariate_normal(
    134           mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
    135       self.assertEqual(entropy.get_shape(), (3, 5))
    136       self.assertAllClose(expected_entropy, entropy.eval()[1, 1])
    137 
    138   def testSample(self):
    139     with self.test_session():
    140       mu = self._rng.rand(2)
    141       chol, sigma = self._random_chol(2, 2)
    142       chol[0, 0] = -chol[0, 0]
    143       sigma[0, 1] = -sigma[0, 1]
    144       sigma[1, 0] = -sigma[1, 0]
    145 
    146       n = constant_op.constant(100000)
    147       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    148       samples = mvn.sample(n, seed=137)
    149       sample_values = samples.eval()
    150       self.assertEqual(samples.get_shape(), [int(100e3), 2])
    151       self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
    152       self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
    153 
    154   def testSingularScaleRaises(self):
    155     with self.test_session():
    156       mu = None
    157       chol = [[1., 0.], [0., 0.]]
    158       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    159       with self.assertRaisesOpError("Singular operator"):
    160         mvn.sample().eval()
    161 
    162   def testSampleWithSampleShape(self):
    163     with self.test_session():
    164       mu = self._rng.rand(3, 5, 2)
    165       chol, sigma = self._random_chol(3, 5, 2, 2)
    166       chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
    167       chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
    168 
    169       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    170       samples_val = mvn.sample((10, 11, 12), seed=137).eval()
    171 
    172       # Check sample shape
    173       self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)
    174 
    175       # Check sample means
    176       x = samples_val[:, :, :, 1, 1, :]
    177       self.assertAllClose(
    178           x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05)
    179 
    180       # Check that log_prob(samples) works
    181       log_prob_val = mvn.log_prob(samples_val).eval()
    182       x_log_pdf = log_prob_val[:, :, :, 1, 1]
    183       expected_log_pdf = stats.multivariate_normal(
    184           mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
    185       self.assertAllClose(expected_log_pdf, x_log_pdf)
    186 
    187   def testSampleMultiDimensional(self):
    188     with self.test_session():
    189       mu = self._rng.rand(3, 5, 2)
    190       chol, sigma = self._random_chol(3, 5, 2, 2)
    191       chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
    192       chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
    193 
    194       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    195       n = constant_op.constant(100000)
    196       samples = mvn.sample(n, seed=137)
    197       sample_values = samples.eval()
    198 
    199       self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
    200       self.assertAllClose(
    201           sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05)
    202       self.assertAllClose(
    203           np.cov(sample_values[:, 1, 1, :], rowvar=0),
    204           sigma[1, 1, :, :],
    205           atol=1e-1)
    206 
    207   def testShapes(self):
    208     with self.test_session():
    209       mu = self._rng.rand(3, 5, 2)
    210       chol, _ = self._random_chol(3, 5, 2, 2)
    211       chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
    212       chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
    213 
    214       mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
    215 
    216       # Shapes known at graph construction time.
    217       self.assertEqual((2,), tuple(mvn.event_shape.as_list()))
    218       self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list()))
    219 
    220       # Shapes known at runtime.
    221       self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval()))
    222       self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval()))
    223 
    224   def _random_mu_and_sigma(self, batch_shape, event_shape):
    225     # This ensures sigma is positive def.
    226     mat_shape = batch_shape + event_shape + event_shape
    227     mat = self._rng.randn(*mat_shape)
    228     perm = np.arange(mat.ndim)
    229     perm[-2:] = [perm[-1], perm[-2]]
    230     sigma = np.matmul(mat, np.transpose(mat, perm))
    231 
    232     mu_shape = batch_shape + event_shape
    233     mu = self._rng.randn(*mu_shape)
    234 
    235     return mu, sigma
    236 
    237   def testKLNonBatch(self):
    238     batch_shape = ()
    239     event_shape = (2,)
    240     with self.test_session():
    241       mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    242       mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
    243       mvn_a = ds.MultivariateNormalTriL(
    244           loc=mu_a,
    245           scale_tril=np.linalg.cholesky(sigma_a),
    246           validate_args=True)
    247       mvn_b = ds.MultivariateNormalTriL(
    248           loc=mu_b,
    249           scale_tril=np.linalg.cholesky(sigma_b),
    250           validate_args=True)
    251 
    252       kl = ds.kl_divergence(mvn_a, mvn_b)
    253       self.assertEqual(batch_shape, kl.get_shape())
    254 
    255       kl_v = kl.eval()
    256       expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
    257       self.assertAllClose(expected_kl, kl_v)
    258 
    259   def testKLBatch(self):
    260     batch_shape = (2,)
    261     event_shape = (3,)
    262     with self.test_session():
    263       mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    264       mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
    265       mvn_a = ds.MultivariateNormalTriL(
    266           loc=mu_a,
    267           scale_tril=np.linalg.cholesky(sigma_a),
    268           validate_args=True)
    269       mvn_b = ds.MultivariateNormalTriL(
    270           loc=mu_b,
    271           scale_tril=np.linalg.cholesky(sigma_b),
    272           validate_args=True)
    273 
    274       kl = ds.kl_divergence(mvn_a, mvn_b)
    275       self.assertEqual(batch_shape, kl.get_shape())
    276 
    277       kl_v = kl.eval()
    278       expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
    279                                             mu_b[0, :], sigma_b[0, :])
    280       expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
    281                                             mu_b[1, :], sigma_b[1, :])
    282       self.assertAllClose(expected_kl_0, kl_v[0])
    283       self.assertAllClose(expected_kl_1, kl_v[1])
    284 
    285   def testKLTwoIdenticalDistributionsIsZero(self):
    286     batch_shape = (2,)
    287     event_shape = (3,)
    288     with self.test_session():
    289       mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
    290       mvn_a = ds.MultivariateNormalTriL(
    291           loc=mu_a,
    292           scale_tril=np.linalg.cholesky(sigma_a),
    293           validate_args=True)
    294 
    295       # Should be zero since KL(p || p) = =.
    296       kl = ds.kl_divergence(mvn_a, mvn_a)
    297       self.assertEqual(batch_shape, kl.get_shape())
    298 
    299       kl_v = kl.eval()
    300       self.assertAllClose(np.zeros(*batch_shape), kl_v)
    301 
    302   def testSampleLarge(self):
    303     mu = np.array([-1., 1], dtype=np.float32)
    304     scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3.
    305 
    306     true_mean = mu
    307     true_scale = scale_tril
    308     true_covariance = np.matmul(true_scale, true_scale.T)
    309     true_variance = np.diag(true_covariance)
    310     true_stddev = np.sqrt(true_variance)
    311 
    312     with self.test_session() as sess:
    313       dist = ds.MultivariateNormalTriL(
    314           loc=mu,
    315           scale_tril=scale_tril,
    316           validate_args=True)
    317 
    318       # The following distributions will test the KL divergence calculation.
    319       mvn_chol = ds.MultivariateNormalTriL(
    320           loc=np.array([0.5, 1.2], dtype=np.float32),
    321           scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32),
    322           validate_args=True)
    323 
    324       n = int(10e3)
    325       samps = dist.sample(n, seed=0)
    326       sample_mean = math_ops.reduce_mean(samps, 0)
    327       x = samps - sample_mean
    328       sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n
    329 
    330       sample_kl_chol = math_ops.reduce_mean(
    331           dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
    332       analytical_kl_chol = ds.kl_divergence(dist, mvn_chol)
    333 
    334       scale = dist.scale.to_dense()
    335 
    336       [
    337           sample_mean_,
    338           analytical_mean_,
    339           sample_covariance_,
    340           analytical_covariance_,
    341           analytical_variance_,
    342           analytical_stddev_,
    343           sample_kl_chol_, analytical_kl_chol_,
    344           scale_,
    345       ] = sess.run([
    346           sample_mean,
    347           dist.mean(),
    348           sample_covariance,
    349           dist.covariance(),
    350           dist.variance(),
    351           dist.stddev(),
    352           sample_kl_chol, analytical_kl_chol,
    353           scale,
    354       ])
    355 
    356       sample_variance_ = np.diag(sample_covariance_)
    357       sample_stddev_ = np.sqrt(sample_variance_)
    358 
    359       logging.vlog(2, "true_mean:\n{}  ".format(true_mean))
    360       logging.vlog(2, "sample_mean:\n{}".format(sample_mean_))
    361       logging.vlog(2, "analytical_mean:\n{}".format(analytical_mean_))
    362 
    363       logging.vlog(2, "true_covariance:\n{}".format(true_covariance))
    364       logging.vlog(2, "sample_covariance:\n{}".format(sample_covariance_))
    365       logging.vlog(
    366           2, "analytical_covariance:\n{}".format(analytical_covariance_))
    367 
    368       logging.vlog(2, "true_variance:\n{}".format(true_variance))
    369       logging.vlog(2, "sample_variance:\n{}".format(sample_variance_))
    370       logging.vlog(2, "analytical_variance:\n{}".format(analytical_variance_))
    371 
    372       logging.vlog(2, "true_stddev:\n{}".format(true_stddev))
    373       logging.vlog(2, "sample_stddev:\n{}".format(sample_stddev_))
    374       logging.vlog(2, "analytical_stddev:\n{}".format(analytical_stddev_))
    375 
    376       logging.vlog(2, "true_scale:\n{}".format(true_scale))
    377       logging.vlog(2, "scale:\n{}".format(scale_))
    378 
    379       logging.vlog(2, "kl_chol:      analytical:{}  sample:{}".format(
    380           analytical_kl_chol_, sample_kl_chol_))
    381 
    382       self.assertAllClose(true_mean, sample_mean_,
    383                           atol=0., rtol=0.03)
    384       self.assertAllClose(true_mean, analytical_mean_,
    385                           atol=0., rtol=1e-6)
    386 
    387       self.assertAllClose(true_covariance, sample_covariance_,
    388                           atol=0., rtol=0.03)
    389       self.assertAllClose(true_covariance, analytical_covariance_,
    390                           atol=0., rtol=1e-6)
    391 
    392       self.assertAllClose(true_variance, sample_variance_,
    393                           atol=0., rtol=0.02)
    394       self.assertAllClose(true_variance, analytical_variance_,
    395                           atol=0., rtol=1e-6)
    396 
    397       self.assertAllClose(true_stddev, sample_stddev_,
    398                           atol=0., rtol=0.01)
    399       self.assertAllClose(true_stddev, analytical_stddev_,
    400                           atol=0., rtol=1e-6)
    401 
    402       self.assertAllClose(true_scale, scale_,
    403                           atol=0., rtol=1e-6)
    404 
    405       self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
    406                           atol=0., rtol=0.02)
    407 
    408 
    409 def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
    410   """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b)."""
    411   # Check using numpy operations
    412   # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
    413   # So it is important to also check that KL(mvn, mvn) = 0.
    414   sigma_b_inv = np.linalg.inv(sigma_b)
    415 
    416   t = np.trace(sigma_b_inv.dot(sigma_a))
    417   q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
    418   k = mu_a.shape[0]
    419   l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))
    420 
    421   return 0.5 * (t + q - k + l)
    422 
    423 
# Run all test cases in this module under the TensorFlow test runner.
if __name__ == "__main__":
  test.main()
    426