# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MultivariateNormal.

Checks `ds.MultivariateNormalTriL` (a multivariate normal parameterized by a
lower-triangular `scale_tril` factor) against `scipy.stats.multivariate_normal`
for pdf/log-pdf/entropy, against sample statistics for mean/covariance, and
against a NumPy re-implementation for KL divergence.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
from tensorflow.contrib import distributions
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


# Short alias used throughout the tests.
ds = distributions


class MultivariateNormalTriLTest(test.TestCase):
  """Tests for the TriL (lower-triangular scale) multivariate normal."""

  def setUp(self):
    # Fixed seed so every test is deterministic.
    self._rng = np.random.RandomState(42)

  def _random_chol(self, *shape):
    """Returns a random (chol, sigma) pair of the given (batched) shape.

    `chol` is lower-triangular with a strictly positive diagonal (softplus
    applied to the diagonal via `matrix_diag_transform`), so
    `sigma = chol @ chol^T` is symmetric positive definite.  Both are
    returned as evaluated numpy arrays.
    """
    mat = self._rng.rand(*shape)
    chol = ds.matrix_diag_transform(mat, transform=nn_ops.softplus)
    # Zero out everything above the diagonal.
    chol = array_ops.matrix_band_part(chol, -1, 0)
    sigma = math_ops.matmul(chol, chol, adjoint_b=True)
    return chol.eval(), sigma.eval()

  def testLogPDFScalarBatch(self):
    """log_prob/prob of a single event match scipy for a scalar batch."""
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      # Flipping the sign of the last diagonal entry of a 2x2 lower-triangular
      # factor leaves chol @ chol^T unchanged; this checks that scale_tril is
      # not required to have a positive diagonal.
      chol[1, 1] = -chol[1, 1]
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      x = self._rng.rand(2)

      log_pdf = mvn.log_prob(x)
      pdf = mvn.prob(x)

      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)

      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual((), log_pdf.get_shape())
      self.assertEqual((), pdf.get_shape())
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(expected_pdf, pdf.eval())

  def testLogPDFXIsHigherRank(self):
    """log_prob/prob broadcast over a batch of events x with shape (3, 2)."""
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      # NOTE(review): negating chol[0, 0] *after* sigma was computed negates
      # the off-diagonal of the implied covariance chol @ chol^T, so scipy's
      # sigma-based answer differs slightly; the loose rtol below appears to
      # absorb that — TODO confirm intent.
      chol[0, 0] = -chol[0, 0]
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      x = self._rng.rand(3, 2)

      log_pdf = mvn.log_prob(x)
      pdf = mvn.prob(x)

      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)

      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)
      self.assertEqual((3,), log_pdf.get_shape())
      self.assertEqual((3,), pdf.get_shape())
      self.assertAllClose(expected_log_pdf, log_pdf.eval(), atol=0., rtol=0.02)
      self.assertAllClose(expected_pdf, pdf.eval(), atol=0., rtol=0.03)

  def testLogPDFXLowerDimension(self):
    """A single event x broadcasts against a (3,)-batch distribution."""
    with self.test_session():
      mu = self._rng.rand(3, 2)
      chol, sigma = self._random_chol(3, 2, 2)
      # Flip signs of some diagonal entries; scale_tril need only be
      # non-singular, not a canonical Cholesky factor.
      chol[0, 0, 0] = -chol[0, 0, 0]
      chol[2, 1, 1] = -chol[2, 1, 1]
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      x = self._rng.rand(2)

      log_pdf = mvn.log_prob(x)
      pdf = mvn.prob(x)

      self.assertEqual((3,), log_pdf.get_shape())
      self.assertEqual((3,), pdf.get_shape())

      # scipy can't do batches, so just test one of them.  Batch member 1 is
      # one whose chol was left untouched, so sigma[1] is exact.
      scipy_mvn = stats.multivariate_normal(mean=mu[1, :], cov=sigma[1, :, :])
      expected_log_pdf = scipy_mvn.logpdf(x)
      expected_pdf = scipy_mvn.pdf(x)

      self.assertAllClose(expected_log_pdf, log_pdf.eval()[1])
      self.assertAllClose(expected_pdf, pdf.eval()[1])

  def testEntropy(self):
    """Entropy matches scipy; sign of a diagonal entry is irrelevant."""
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      # Entropy depends only on |det(scale)|, so the sign flip is harmless.
      chol[0, 0] = -chol[0, 0]
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      entropy = mvn.entropy()

      scipy_mvn = stats.multivariate_normal(mean=mu, cov=sigma)
      expected_entropy = scipy_mvn.entropy()
      self.assertEqual(entropy.get_shape(), ())
      self.assertAllClose(expected_entropy, entropy.eval())

  def testEntropyMultidimensional(self):
    """Entropy of a (3, 5)-batch distribution, spot-checked against scipy."""
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, sigma = self._random_chol(3, 5, 2, 2)
      chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
      chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      entropy = mvn.entropy()

      # Scipy doesn't do batches, so test one of them.
      expected_entropy = stats.multivariate_normal(
          mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).entropy()
      self.assertEqual(entropy.get_shape(), (3, 5))
      self.assertAllClose(expected_entropy, entropy.eval()[1, 1])

  def testSample(self):
    """Sample mean and covariance converge to mu and sigma."""
    with self.test_session():
      mu = self._rng.rand(2)
      chol, sigma = self._random_chol(2, 2)
      chol[0, 0] = -chol[0, 0]
      # Negating chol[0, 0] negates the off-diagonal of chol @ chol^T, so
      # update sigma to keep it equal to the true sampling covariance.
      sigma[0, 1] = -sigma[0, 1]
      sigma[1, 0] = -sigma[1, 0]

      n = constant_op.constant(100000)
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()
      self.assertEqual(samples.get_shape(), [int(100e3), 2])
      self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
      self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)

  def testSingularScaleRaises(self):
    """A scale_tril with a zero diagonal entry raises at runtime."""
    with self.test_session():
      mu = None
      chol = [[1., 0.], [0., 0.]]  # Singular: zero on the diagonal.
      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      with self.assertRaisesOpError("Singular operator"):
        mvn.sample().eval()

  def testSampleWithSampleShape(self):
    """sample() with a multi-axis sample_shape prepends those dims."""
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, sigma = self._random_chol(3, 5, 2, 2)
      chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
      chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]

      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      samples_val = mvn.sample((10, 11, 12), seed=137).eval()

      # Check sample shape: sample_shape + batch_shape + event_shape.
      self.assertEqual((10, 11, 12, 3, 5, 2), samples_val.shape)

      # Check sample means for one batch member.
      x = samples_val[:, :, :, 1, 1, :]
      self.assertAllClose(
          x.reshape(10 * 11 * 12, 2).mean(axis=0), mu[1, 1], atol=0.05)

      # Check that log_prob(samples) works and agrees with scipy.
      log_prob_val = mvn.log_prob(samples_val).eval()
      x_log_pdf = log_prob_val[:, :, :, 1, 1]
      expected_log_pdf = stats.multivariate_normal(
          mean=mu[1, 1, :], cov=sigma[1, 1, :, :]).logpdf(x)
      self.assertAllClose(expected_log_pdf, x_log_pdf)

  def testSampleMultiDimensional(self):
    """Sample statistics of one member of a (3, 5) batch match mu/sigma."""
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, sigma = self._random_chol(3, 5, 2, 2)
      chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
      chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]

      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
      n = constant_op.constant(100000)
      samples = mvn.sample(n, seed=137)
      sample_values = samples.eval()

      self.assertEqual(samples.get_shape(), (100000, 3, 5, 2))
      # Spot-check batch member (1, 1), whose chol was left untouched.
      self.assertAllClose(
          sample_values[:, 1, 1, :].mean(axis=0), mu[1, 1, :], atol=0.05)
      self.assertAllClose(
          np.cov(sample_values[:, 1, 1, :], rowvar=0),
          sigma[1, 1, :, :],
          atol=1e-1)

  def testShapes(self):
    """Static and dynamic batch/event shapes agree with the parameters."""
    with self.test_session():
      mu = self._rng.rand(3, 5, 2)
      chol, _ = self._random_chol(3, 5, 2, 2)
      chol[1, 0, 0, 0] = -chol[1, 0, 0, 0]
      chol[2, 3, 1, 1] = -chol[2, 3, 1, 1]

      mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)

      # Shapes known at graph construction time.
      self.assertEqual((2,), tuple(mvn.event_shape.as_list()))
      self.assertEqual((3, 5), tuple(mvn.batch_shape.as_list()))

      # Shapes known at runtime.
      self.assertEqual((2,), tuple(mvn.event_shape_tensor().eval()))
      self.assertEqual((3, 5), tuple(mvn.batch_shape_tensor().eval()))

  def _random_mu_and_sigma(self, batch_shape, event_shape):
    """Returns random (mu, sigma) numpy arrays of the given tuple shapes.

    `sigma` has shape batch_shape + event_shape + event_shape and is built as
    mat @ mat^T, which ensures it is positive definite (almost surely, since
    mat is a random square Gaussian matrix).
    """
    # This ensures sigma is positive def.
    mat_shape = batch_shape + event_shape + event_shape
    mat = self._rng.randn(*mat_shape)
    # Build a permutation that swaps the last two axes (batched transpose).
    perm = np.arange(mat.ndim)
    perm[-2:] = [perm[-1], perm[-2]]
    sigma = np.matmul(mat, np.transpose(mat, perm))

    mu_shape = batch_shape + event_shape
    mu = self._rng.randn(*mu_shape)

    return mu, sigma

  def testKLNonBatch(self):
    """KL(a || b) for scalar-batch distributions matches the numpy formula."""
    batch_shape = ()
    event_shape = (2,)
    with self.test_session():
      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
      mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
      mvn_a = ds.MultivariateNormalTriL(
          loc=mu_a,
          scale_tril=np.linalg.cholesky(sigma_a),
          validate_args=True)
      mvn_b = ds.MultivariateNormalTriL(
          loc=mu_b,
          scale_tril=np.linalg.cholesky(sigma_b),
          validate_args=True)

      kl = ds.kl_divergence(mvn_a, mvn_b)
      self.assertEqual(batch_shape, kl.get_shape())

      kl_v = kl.eval()
      expected_kl = _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b)
      self.assertAllClose(expected_kl, kl_v)

  def testKLBatch(self):
    """Batched KL equals the per-member non-batch KL."""
    batch_shape = (2,)
    event_shape = (3,)
    with self.test_session():
      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
      mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape)
      mvn_a = ds.MultivariateNormalTriL(
          loc=mu_a,
          scale_tril=np.linalg.cholesky(sigma_a),
          validate_args=True)
      mvn_b = ds.MultivariateNormalTriL(
          loc=mu_b,
          scale_tril=np.linalg.cholesky(sigma_b),
          validate_args=True)

      kl = ds.kl_divergence(mvn_a, mvn_b)
      self.assertEqual(batch_shape, kl.get_shape())

      kl_v = kl.eval()
      # sigma_b[i, :] is the same 2-D slice as sigma_b[i, :, :].
      expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :],
                                            mu_b[0, :], sigma_b[0, :])
      expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :],
                                            mu_b[1, :], sigma_b[1, :])
      self.assertAllClose(expected_kl_0, kl_v[0])
      self.assertAllClose(expected_kl_1, kl_v[1])

  def testKLTwoIdenticalDistributionsIsZero(self):
    """KL of a distribution with itself is exactly zero."""
    batch_shape = (2,)
    event_shape = (3,)
    with self.test_session():
      mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape)
      mvn_a = ds.MultivariateNormalTriL(
          loc=mu_a,
          scale_tril=np.linalg.cholesky(sigma_a),
          validate_args=True)

      # Should be zero since KL(p || p) = 0.
      kl = ds.kl_divergence(mvn_a, mvn_a)
      self.assertEqual(batch_shape, kl.get_shape())

      kl_v = kl.eval()
      self.assertAllClose(np.zeros(*batch_shape), kl_v)

  def testSampleLarge(self):
    """Large-sample check of mean/covariance/variance/stddev/scale and KL."""
    mu = np.array([-1., 1], dtype=np.float32)
    scale_tril = np.array([[3., 0], [1, -2]], dtype=np.float32) / 3.

    true_mean = mu
    true_scale = scale_tril
    true_covariance = np.matmul(true_scale, true_scale.T)
    true_variance = np.diag(true_covariance)
    true_stddev = np.sqrt(true_variance)

    with self.test_session() as sess:
      dist = ds.MultivariateNormalTriL(
          loc=mu,
          scale_tril=scale_tril,
          validate_args=True)

      # The following distributions will test the KL divergence calculation.
      mvn_chol = ds.MultivariateNormalTriL(
          loc=np.array([0.5, 1.2], dtype=np.float32),
          scale_tril=np.array([[3., 0], [1, 2]], dtype=np.float32),
          validate_args=True)

      n = int(10e3)
      samps = dist.sample(n, seed=0)
      sample_mean = math_ops.reduce_mean(samps, 0)
      # Empirical covariance of the centered samples.
      x = samps - sample_mean
      sample_covariance = math_ops.matmul(x, x, transpose_a=True) / n

      # Monte Carlo estimate of KL(dist || mvn_chol), to compare with the
      # analytical value.
      sample_kl_chol = math_ops.reduce_mean(
          dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
      analytical_kl_chol = ds.kl_divergence(dist, mvn_chol)

      scale = dist.scale.to_dense()

      [
          sample_mean_,
          analytical_mean_,
          sample_covariance_,
          analytical_covariance_,
          analytical_variance_,
          analytical_stddev_,
          sample_kl_chol_, analytical_kl_chol_,
          scale_,
      ] = sess.run([
          sample_mean,
          dist.mean(),
          sample_covariance,
          dist.covariance(),
          dist.variance(),
          dist.stddev(),
          sample_kl_chol, analytical_kl_chol,
          scale,
      ])

      sample_variance_ = np.diag(sample_covariance_)
      sample_stddev_ = np.sqrt(sample_variance_)

      logging.vlog(2, "true_mean:\n{} ".format(true_mean))
      logging.vlog(2, "sample_mean:\n{}".format(sample_mean_))
      logging.vlog(2, "analytical_mean:\n{}".format(analytical_mean_))

      logging.vlog(2, "true_covariance:\n{}".format(true_covariance))
      logging.vlog(2, "sample_covariance:\n{}".format(sample_covariance_))
      logging.vlog(
          2, "analytical_covariance:\n{}".format(analytical_covariance_))

      logging.vlog(2, "true_variance:\n{}".format(true_variance))
      logging.vlog(2, "sample_variance:\n{}".format(sample_variance_))
      logging.vlog(2, "analytical_variance:\n{}".format(analytical_variance_))

      logging.vlog(2, "true_stddev:\n{}".format(true_stddev))
      logging.vlog(2, "sample_stddev:\n{}".format(sample_stddev_))
      logging.vlog(2, "analytical_stddev:\n{}".format(analytical_stddev_))

      logging.vlog(2, "true_scale:\n{}".format(true_scale))
      logging.vlog(2, "scale:\n{}".format(scale_))

      logging.vlog(2, "kl_chol: analytical:{} sample:{}".format(
          analytical_kl_chol_, sample_kl_chol_))

      self.assertAllClose(true_mean, sample_mean_,
                          atol=0., rtol=0.03)
      self.assertAllClose(true_mean, analytical_mean_,
                          atol=0., rtol=1e-6)

      self.assertAllClose(true_covariance, sample_covariance_,
                          atol=0., rtol=0.03)
      self.assertAllClose(true_covariance, analytical_covariance_,
                          atol=0., rtol=1e-6)

      self.assertAllClose(true_variance, sample_variance_,
                          atol=0., rtol=0.02)
      self.assertAllClose(true_variance, analytical_variance_,
                          atol=0., rtol=1e-6)

      self.assertAllClose(true_stddev, sample_stddev_,
                          atol=0., rtol=0.01)
      self.assertAllClose(true_stddev, analytical_stddev_,
                          atol=0., rtol=1e-6)

      self.assertAllClose(true_scale, scale_,
                          atol=0., rtol=1e-6)

      self.assertAllClose(sample_kl_chol_, analytical_kl_chol_,
                          atol=0., rtol=0.02)


def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b):
  """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).

  Implements the closed form
    KL = 0.5 * (tr(sigma_b^-1 sigma_a)
                + (mu_b - mu_a)^T sigma_b^-1 (mu_b - mu_a)
                - k + log(det(sigma_b) / det(sigma_a)))
  where k is the event dimension.
  """
  # Check using numpy operations
  # This mostly repeats the tensorflow code _kl_mvn_mvn(), but in numpy.
  # So it is important to also check that KL(mvn, mvn) = 0.
  sigma_b_inv = np.linalg.inv(sigma_b)

  t = np.trace(sigma_b_inv.dot(sigma_a))
  q = (mu_b - mu_a).dot(sigma_b_inv).dot(mu_b - mu_a)
  k = mu_a.shape[0]
  l = np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a))

  return 0.5 * (t + q - k + l)


if __name__ == "__main__":
  test.main()