1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Hamiltonian Monte Carlo.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 23 import numpy as np 24 from scipy import stats 25 26 from tensorflow.contrib.bayesflow.python.ops import hmc 27 from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change 28 from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator 29 30 from tensorflow.contrib.distributions.python.ops import independent as independent_lib 31 from tensorflow.python.framework import ops 32 from tensorflow.python.framework import random_seed 33 from tensorflow.python.ops import array_ops 34 from tensorflow.python.ops import gen_linalg_ops 35 from tensorflow.python.ops import gradients_impl as gradients_ops 36 from tensorflow.python.ops import math_ops 37 from tensorflow.python.ops import random_ops 38 from tensorflow.python.ops.distributions import gamma as gamma_lib 39 from tensorflow.python.ops.distributions import normal as normal_lib 40 from tensorflow.python.platform import test 41 from tensorflow.python.platform import tf_logging as logging_ops 42 43 44 def _reduce_variance(x, axis=None, keepdims=False): 45 sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) 46 return math_ops.reduce_mean( 47 math_ops.squared_difference(x, sample_mean), axis, keepdims) 48 49 50 class HMCTest(test.TestCase): 51 52 def setUp(self): 53 self._shape_param = 5. 54 self._rate_param = 10. 55 56 random_seed.set_random_seed(10003) 57 np.random.seed(10003) 58 59 def assertAllFinite(self, x): 60 self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) 61 62 def _log_gamma_log_prob(self, x, event_dims=()): 63 """Computes log-pdf of a log-gamma random variable. 64 65 Args: 66 x: Value of the random variable. 67 event_dims: Dimensions not to treat as independent. 68 69 Returns: 70 log_prob: The log-pdf up to a normalizing constant. 71 """ 72 return math_ops.reduce_sum(self._shape_param * x - 73 self._rate_param * math_ops.exp(x), 74 event_dims) 75 76 def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, 77 feed_dict=None): 78 step_size = array_ops.placeholder(np.float32, [], name="step_size") 79 hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") 80 81 if feed_dict is None: 82 feed_dict = {} 83 feed_dict[hmc_lf_steps] = 1000 84 85 event_dims = math_ops.range(independent_chain_ndims, 86 array_ops.rank(x)) 87 88 m = random_ops.random_normal(array_ops.shape(x)) 89 log_prob_0 = self._log_gamma_log_prob(x, event_dims) 90 grad_0 = gradients_ops.gradients(log_prob_0, x) 91 old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) 92 93 new_m, _, log_prob_1, _ = _leapfrog_integrator( 94 current_momentums=[m], 95 target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), 96 current_state_parts=[x], 97 step_sizes=[step_size], 98 num_leapfrog_steps=hmc_lf_steps, 99 current_target_log_prob=log_prob_0, 100 current_grads_target_log_prob=grad_0) 101 new_m = new_m[0] 102 103 new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, 104 event_dims) 105 106 x_shape = sess.run(x, feed_dict).shape 107 event_size = np.prod(x_shape[independent_chain_ndims:]) 108 feed_dict[step_size] = 0.1 / event_size 109 old_energy_, new_energy_ = sess.run([old_energy, new_energy], 110 feed_dict) 111 logging_ops.vlog(1, "average energy relative change: {}".format( 112 (1. - new_energy_ / old_energy_).mean())) 113 self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) 114 115 def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): 116 """Tests the long-term energy conservation of the leapfrog integrator. 117 118 The leapfrog integrator is symplectic, so for sufficiently small step 119 sizes it should be possible to run it more or less indefinitely without 120 the energy of the system blowing up or collapsing. 121 122 Args: 123 independent_chain_ndims: Python `int` scalar representing the number of 124 dims associated with independent chains. 125 """ 126 with self.test_session(graph=ops.Graph()) as sess: 127 x_ph = array_ops.placeholder(np.float32, name="x_ph") 128 feed_dict = {x_ph: np.random.rand(50, 10, 2)} 129 self._integrator_conserves_energy(x_ph, independent_chain_ndims, 130 sess, feed_dict) 131 132 def testIntegratorEnergyConservationNullShape(self): 133 self._integrator_conserves_energy_wrapper(0) 134 135 def testIntegratorEnergyConservation1(self): 136 self._integrator_conserves_energy_wrapper(1) 137 138 def testIntegratorEnergyConservation2(self): 139 self._integrator_conserves_energy_wrapper(2) 140 141 def testIntegratorEnergyConservation3(self): 142 self._integrator_conserves_energy_wrapper(3) 143 144 def testSampleChainSeedReproducibleWorksCorrectly(self): 145 with self.test_session(graph=ops.Graph()) as sess: 146 num_results = 10 147 independent_chain_ndims = 1 148 149 def log_gamma_log_prob(x): 150 event_dims = math_ops.range(independent_chain_ndims, 151 array_ops.rank(x)) 152 return self._log_gamma_log_prob(x, event_dims) 153 154 kwargs = dict( 155 target_log_prob_fn=log_gamma_log_prob, 156 current_state=np.random.rand(4, 3, 2), 157 step_size=0.1, 158 num_leapfrog_steps=2, 159 num_burnin_steps=150, 160 seed=52, 161 ) 162 163 samples0, kernel_results0 = hmc.sample_chain( 164 **dict(list(kwargs.items()) + list(dict( 165 num_results=2 * num_results, 166 num_steps_between_results=0).items()))) 167 168 samples1, kernel_results1 = hmc.sample_chain( 169 **dict(list(kwargs.items()) + list(dict( 170 num_results=num_results, 171 num_steps_between_results=1).items()))) 172 173 [ 174 samples0_, 175 samples1_, 176 target_log_prob0_, 177 target_log_prob1_, 178 ] = sess.run([ 179 samples0, 180 samples1, 181 kernel_results0.current_target_log_prob, 182 kernel_results1.current_target_log_prob, 183 ]) 184 self.assertAllClose(samples0_[::2], samples1_, 185 atol=1e-5, rtol=1e-5) 186 self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, 187 atol=1e-5, rtol=1e-5) 188 189 def _chain_gets_correct_expectations(self, x, independent_chain_ndims, 190 sess, feed_dict=None): 191 counter = collections.Counter() 192 def log_gamma_log_prob(x): 193 counter["target_calls"] += 1 194 event_dims = math_ops.range(independent_chain_ndims, 195 array_ops.rank(x)) 196 return self._log_gamma_log_prob(x, event_dims) 197 198 num_results = array_ops.placeholder( 199 np.int32, [], name="num_results") 200 step_size = array_ops.placeholder( 201 np.float32, [], name="step_size") 202 num_leapfrog_steps = array_ops.placeholder( 203 np.int32, [], name="num_leapfrog_steps") 204 205 if feed_dict is None: 206 feed_dict = {} 207 feed_dict.update({num_results: 150, 208 step_size: 0.05, 209 num_leapfrog_steps: 2}) 210 211 samples, kernel_results = hmc.sample_chain( 212 num_results=num_results, 213 target_log_prob_fn=log_gamma_log_prob, 214 current_state=x, 215 step_size=step_size, 216 num_leapfrog_steps=num_leapfrog_steps, 217 num_burnin_steps=150, 218 seed=42) 219 220 self.assertAllEqual(dict(target_calls=2), counter) 221 222 expected_x = (math_ops.digamma(self._shape_param) 223 - np.log(self._rate_param)) 224 225 expected_exp_x = self._shape_param / self._rate_param 226 227 acceptance_probs_, samples_, expected_x_ = sess.run( 228 [kernel_results.acceptance_probs, samples, expected_x], 229 feed_dict) 230 231 actual_x = samples_.mean() 232 actual_exp_x = np.exp(samples_).mean() 233 234 logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( 235 expected_x_, expected_exp_x)) 236 logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( 237 actual_x, actual_exp_x)) 238 self.assertNear(actual_x, expected_x_, 2e-2) 239 self.assertNear(actual_exp_x, expected_exp_x, 2e-2) 240 self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), 241 acceptance_probs_ > 0.5) 242 self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), 243 acceptance_probs_ <= 1.) 244 245 def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): 246 with self.test_session(graph=ops.Graph()) as sess: 247 x_ph = array_ops.placeholder(np.float32, name="x_ph") 248 feed_dict = {x_ph: np.random.rand(50, 10, 2)} 249 self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, 250 sess, feed_dict) 251 252 def testHMCChainExpectationsNullShape(self): 253 self._chain_gets_correct_expectations_wrapper(0) 254 255 def testHMCChainExpectations1(self): 256 self._chain_gets_correct_expectations_wrapper(1) 257 258 def testHMCChainExpectations2(self): 259 self._chain_gets_correct_expectations_wrapper(2) 260 261 def testKernelResultsUsingTruncatedDistribution(self): 262 def log_prob(x): 263 return array_ops.where( 264 x >= 0., 265 -x - x**2, # Non-constant gradient. 266 array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) 267 # This log_prob has the property that it is likely to attract 268 # the HMC flow toward, and below, zero...but for x <=0, 269 # log_prob(x) = -inf, which should result in rejection, as well 270 # as a non-finite log_prob. Thus, this distribution gives us an opportunity 271 # to test out the kernel results ability to correctly capture rejections due 272 # to finite AND non-finite reasons. 273 # Why use a non-constant gradient? This ensures the leapfrog integrator 274 # will not be exact. 275 276 num_results = 1000 277 # Large step size, will give rejections due to integration error in addition 278 # to rejection due to going into a region of log_prob = -inf. 279 step_size = 0.1 280 num_leapfrog_steps = 5 281 num_chains = 2 282 283 with self.test_session(graph=ops.Graph()) as sess: 284 285 # Start multiple independent chains. 286 initial_state = ops.convert_to_tensor([0.1] * num_chains) 287 288 states, kernel_results = hmc.sample_chain( 289 num_results=num_results, 290 target_log_prob_fn=log_prob, 291 current_state=initial_state, 292 step_size=step_size, 293 num_leapfrog_steps=num_leapfrog_steps, 294 seed=42) 295 296 states_, kernel_results_ = sess.run([states, kernel_results]) 297 pstates_ = kernel_results_.proposed_state 298 299 neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) 300 301 # First: Test that the mathematical properties of the above log prob 302 # function in conjunction with HMC show up as expected in kernel_results_. 303 304 # We better have log_prob = -inf some of the time. 305 self.assertLess(0, neg_inf_mask.sum()) 306 # We better have some rejections due to something other than -inf. 307 self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) 308 # We better have been accepted a decent amount, even near the end of the 309 # chain, or else this HMC run just got stuck at some point. 310 self.assertLess( 311 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) 312 # We better not have any NaNs in proposed state or log_prob. 313 # We may have some NaN in grads, which involve multiplication/addition due 314 # to gradient rules. This is the known "NaN grad issue with tf.where." 315 self.assertAllEqual(np.zeros_like(states_), 316 np.isnan(kernel_results_.proposed_target_log_prob)) 317 self.assertAllEqual(np.zeros_like(states_), 318 np.isnan(states_)) 319 # We better not have any +inf in states, grads, or log_prob. 320 self.assertAllEqual(np.zeros_like(states_), 321 np.isposinf(kernel_results_.proposed_target_log_prob)) 322 self.assertAllEqual( 323 np.zeros_like(states_), 324 np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) 325 self.assertAllEqual(np.zeros_like(states_), 326 np.isposinf(states_)) 327 328 # Second: Test that kernel_results is congruent with itself and 329 # acceptance/rejection of states. 330 331 # Proposed state is negative iff proposed target log prob is -inf. 332 np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) 333 np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) 334 335 # Acceptance probs are zero whenever proposed state is negative. 336 self.assertAllEqual( 337 np.zeros_like(pstates_[neg_inf_mask]), 338 kernel_results_.acceptance_probs[neg_inf_mask]) 339 340 # The move is accepted ==> state = proposed state. 341 self.assertAllEqual( 342 states_[kernel_results_.is_accepted], 343 pstates_[kernel_results_.is_accepted], 344 ) 345 # The move was rejected <==> state[t] == state[t - 1]. 346 for t in range(1, num_results): 347 for i in range(num_chains): 348 if kernel_results_.is_accepted[t, i]: 349 self.assertNotEqual(states_[t, i], states_[t - 1, i]) 350 else: 351 self.assertEqual(states_[t, i], states_[t - 1, i]) 352 353 def _kernel_leaves_target_invariant(self, initial_draws, 354 independent_chain_ndims, 355 sess, feed_dict=None): 356 def log_gamma_log_prob(x): 357 event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) 358 return self._log_gamma_log_prob(x, event_dims) 359 360 def fake_log_prob(x): 361 """Cooled version of the target distribution.""" 362 return 1.1 * log_gamma_log_prob(x) 363 364 step_size = array_ops.placeholder(np.float32, [], name="step_size") 365 366 if feed_dict is None: 367 feed_dict = {} 368 369 feed_dict[step_size] = 0.4 370 371 sample, kernel_results = hmc.kernel( 372 target_log_prob_fn=log_gamma_log_prob, 373 current_state=initial_draws, 374 step_size=step_size, 375 num_leapfrog_steps=5, 376 seed=43) 377 378 bad_sample, bad_kernel_results = hmc.kernel( 379 target_log_prob_fn=fake_log_prob, 380 current_state=initial_draws, 381 step_size=step_size, 382 num_leapfrog_steps=5, 383 seed=44) 384 385 [ 386 acceptance_probs_, 387 bad_acceptance_probs_, 388 initial_draws_, 389 updated_draws_, 390 fake_draws_, 391 ] = sess.run([ 392 kernel_results.acceptance_probs, 393 bad_kernel_results.acceptance_probs, 394 initial_draws, 395 sample, 396 bad_sample, 397 ], feed_dict) 398 399 # Confirm step size is small enough that we usually accept. 400 self.assertGreater(acceptance_probs_.mean(), 0.5) 401 self.assertGreater(bad_acceptance_probs_.mean(), 0.5) 402 403 # Confirm step size is large enough that we sometimes reject. 404 self.assertLess(acceptance_probs_.mean(), 0.99) 405 self.assertLess(bad_acceptance_probs_.mean(), 0.99) 406 407 _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), 408 updated_draws_.flatten()) 409 _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), 410 fake_draws_.flatten()) 411 412 logging_ops.vlog(1, "acceptance rate for true target: {}".format( 413 acceptance_probs_.mean())) 414 logging_ops.vlog(1, "acceptance rate for fake target: {}".format( 415 bad_acceptance_probs_.mean())) 416 logging_ops.vlog(1, "K-S p-value for true target: {}".format( 417 ks_p_value_true)) 418 logging_ops.vlog(1, "K-S p-value for fake target: {}".format( 419 ks_p_value_fake)) 420 # Make sure that the MCMC update hasn't changed the empirical CDF much. 421 self.assertGreater(ks_p_value_true, 1e-3) 422 # Confirm that targeting the wrong distribution does 423 # significantly change the empirical CDF. 424 self.assertLess(ks_p_value_fake, 1e-6) 425 426 def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): 427 """Tests that the kernel leaves the target distribution invariant. 428 429 Draws some independent samples from the target distribution, 430 applies an iteration of the MCMC kernel, then runs a 431 Kolmogorov-Smirnov test to determine if the distribution of the 432 MCMC-updated samples has changed. 433 434 We also confirm that running the kernel with a different log-pdf 435 does change the target distribution. (And that we can detect that.) 436 437 Args: 438 independent_chain_ndims: Python `int` scalar representing the number of 439 dims associated with independent chains. 440 """ 441 with self.test_session(graph=ops.Graph()) as sess: 442 initial_draws = np.log(np.random.gamma(self._shape_param, 443 size=[50000, 2, 2])) 444 initial_draws -= np.log(self._rate_param) 445 x_ph = array_ops.placeholder(np.float32, name="x_ph") 446 447 feed_dict = {x_ph: initial_draws} 448 449 self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, 450 sess, feed_dict) 451 452 def testKernelLeavesTargetInvariant1(self): 453 self._kernel_leaves_target_invariant_wrapper(1) 454 455 def testKernelLeavesTargetInvariant2(self): 456 self._kernel_leaves_target_invariant_wrapper(2) 457 458 def testKernelLeavesTargetInvariant3(self): 459 self._kernel_leaves_target_invariant_wrapper(3) 460 461 def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, 462 sess, feed_dict=None): 463 counter = collections.Counter() 464 465 def proposal_log_prob(x): 466 counter["proposal_calls"] += 1 467 event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) 468 return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), 469 axis=event_dims) 470 471 def target_log_prob(x): 472 counter["target_calls"] += 1 473 event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) 474 return self._log_gamma_log_prob(x, event_dims) 475 476 if feed_dict is None: 477 feed_dict = {} 478 479 num_steps = 200 480 481 _, ais_weights, _ = hmc.sample_annealed_importance_chain( 482 proposal_log_prob_fn=proposal_log_prob, 483 num_steps=num_steps, 484 target_log_prob_fn=target_log_prob, 485 step_size=0.5, 486 current_state=init, 487 num_leapfrog_steps=2, 488 seed=45) 489 490 # We have three calls because the calculation of `ais_weights` entails 491 # another call to the `convex_combined_log_prob_fn`. We could refactor 492 # things to avoid this, if needed (eg, b/72994218). 493 self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) 494 495 event_shape = array_ops.shape(init)[independent_chain_ndims:] 496 event_size = math_ops.reduce_prod(event_shape) 497 498 log_true_normalizer = ( 499 -self._shape_param * math_ops.log(self._rate_param) 500 + math_ops.lgamma(self._shape_param)) 501 log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) 502 503 log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) 504 - np.log(num_steps)) 505 506 ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) 507 ais_weights_size = array_ops.size(ais_weights) 508 standard_error = math_ops.sqrt( 509 _reduce_variance(ratio_estimate_true) 510 / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) 511 512 [ 513 ratio_estimate_true_, 514 log_true_normalizer_, 515 log_estimated_normalizer_, 516 standard_error_, 517 ais_weights_size_, 518 event_size_, 519 ] = sess.run([ 520 ratio_estimate_true, 521 log_true_normalizer, 522 log_estimated_normalizer, 523 standard_error, 524 ais_weights_size, 525 event_size, 526 ], feed_dict) 527 528 logging_ops.vlog(1, " log_true_normalizer: {}\n" 529 " log_estimated_normalizer: {}\n" 530 " ais_weights_size: {}\n" 531 " event_size: {}\n".format( 532 log_true_normalizer_, 533 log_estimated_normalizer_, 534 ais_weights_size_, 535 event_size_)) 536 self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) 537 538 def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): 539 """Tests that AIS yields reasonable estimates of normalizers.""" 540 with self.test_session(graph=ops.Graph()) as sess: 541 x_ph = array_ops.placeholder(np.float32, name="x_ph") 542 initial_draws = np.random.normal(size=[30, 2, 1]) 543 self._ais_gets_correct_log_normalizer( 544 x_ph, 545 independent_chain_ndims, 546 sess, 547 feed_dict={x_ph: initial_draws}) 548 549 def testAIS1(self): 550 self._ais_gets_correct_log_normalizer_wrapper(1) 551 552 def testAIS2(self): 553 self._ais_gets_correct_log_normalizer_wrapper(2) 554 555 def testAIS3(self): 556 self._ais_gets_correct_log_normalizer_wrapper(3) 557 558 def testSampleAIChainSeedReproducibleWorksCorrectly(self): 559 with self.test_session(graph=ops.Graph()) as sess: 560 independent_chain_ndims = 1 561 x = np.random.rand(4, 3, 2) 562 563 def proposal_log_prob(x): 564 event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) 565 return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), 566 axis=event_dims) 567 568 def target_log_prob(x): 569 event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) 570 return self._log_gamma_log_prob(x, event_dims) 571 572 ais_kwargs = dict( 573 proposal_log_prob_fn=proposal_log_prob, 574 num_steps=200, 575 target_log_prob_fn=target_log_prob, 576 step_size=0.5, 577 current_state=x, 578 num_leapfrog_steps=2, 579 seed=53) 580 581 _, ais_weights0, _ = hmc.sample_annealed_importance_chain( 582 **ais_kwargs) 583 584 _, ais_weights1, _ = hmc.sample_annealed_importance_chain( 585 **ais_kwargs) 586 587 [ais_weights0_, ais_weights1_] = sess.run([ 588 ais_weights0, ais_weights1]) 589 590 self.assertAllClose(ais_weights0_, ais_weights1_, 591 atol=1e-5, rtol=1e-5) 592 593 def testNanRejection(self): 594 """Tests that an update that yields NaN potentials gets rejected. 595 596 We run HMC with a target distribution that returns NaN 597 log-likelihoods if any element of x < 0, and unit-scale 598 exponential log-likelihoods otherwise. The exponential potential 599 pushes x towards 0, ensuring that any reasonably large update will 600 push us over the edge into NaN territory. 601 """ 602 def _unbounded_exponential_log_prob(x): 603 """An exponential distribution with log-likelihood NaN for x < 0.""" 604 per_element_potentials = array_ops.where( 605 x < 0., 606 array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), 607 -x) 608 return math_ops.reduce_sum(per_element_potentials) 609 610 with self.test_session(graph=ops.Graph()) as sess: 611 initial_x = math_ops.linspace(0.01, 5, 10) 612 updated_x, kernel_results = hmc.kernel( 613 target_log_prob_fn=_unbounded_exponential_log_prob, 614 current_state=initial_x, 615 step_size=2., 616 num_leapfrog_steps=5, 617 seed=46) 618 initial_x_, updated_x_, acceptance_probs_ = sess.run( 619 [initial_x, updated_x, kernel_results.acceptance_probs]) 620 621 logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) 622 logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) 623 logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) 624 625 self.assertAllEqual(initial_x_, updated_x_) 626 self.assertEqual(acceptance_probs_, 0.) 627 628 def testNanFromGradsDontPropagate(self): 629 """Test that update with NaN gradients does not cause NaN in results.""" 630 def _nan_log_prob_with_nan_gradient(x): 631 return np.nan * math_ops.reduce_sum(x) 632 633 with self.test_session(graph=ops.Graph()) as sess: 634 initial_x = math_ops.linspace(0.01, 5, 10) 635 updated_x, kernel_results = hmc.kernel( 636 target_log_prob_fn=_nan_log_prob_with_nan_gradient, 637 current_state=initial_x, 638 step_size=2., 639 num_leapfrog_steps=5, 640 seed=47) 641 initial_x_, updated_x_, acceptance_probs_ = sess.run( 642 [initial_x, updated_x, kernel_results.acceptance_probs]) 643 644 logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) 645 logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) 646 logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) 647 648 self.assertAllEqual(initial_x_, updated_x_) 649 self.assertEqual(acceptance_probs_, 0.) 650 651 self.assertAllFinite( 652 gradients_ops.gradients(updated_x, initial_x)[0].eval()) 653 self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( 654 kernel_results.proposed_grads_target_log_prob, initial_x)]) 655 self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( 656 kernel_results.proposed_grads_target_log_prob, 657 kernel_results.proposed_state)]) 658 659 # Gradients of the acceptance probs and new log prob are not finite. 660 # self.assertAllFinite( 661 # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) 662 # self.assertAllFinite( 663 # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) 664 665 def _testChainWorksDtype(self, dtype): 666 with self.test_session(graph=ops.Graph()) as sess: 667 states, kernel_results = hmc.sample_chain( 668 num_results=10, 669 target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), 670 current_state=np.zeros(5).astype(dtype), 671 step_size=0.01, 672 num_leapfrog_steps=10, 673 seed=48) 674 states_, acceptance_probs_ = sess.run( 675 [states, kernel_results.acceptance_probs]) 676 self.assertEqual(dtype, states_.dtype) 677 self.assertEqual(dtype, acceptance_probs_.dtype) 678 679 def testChainWorksIn64Bit(self): 680 self._testChainWorksDtype(np.float64) 681 682 def testChainWorksIn16Bit(self): 683 self._testChainWorksDtype(np.float16) 684 685 def testChainWorksCorrelatedMultivariate(self): 686 dtype = np.float32 687 true_mean = dtype([0, 0]) 688 true_cov = dtype([[1, 0.5], 689 [0.5, 1]]) 690 num_results = 2000 691 counter = collections.Counter() 692 with self.test_session(graph=ops.Graph()) as sess: 693 def target_log_prob(x, y): 694 counter["target_calls"] += 1 695 # Corresponds to unnormalized MVN. 696 # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) 697 z = array_ops.stack([x, y], axis=-1) - true_mean 698 z = array_ops.squeeze( 699 gen_linalg_ops.matrix_triangular_solve( 700 np.linalg.cholesky(true_cov), 701 z[..., array_ops.newaxis]), 702 axis=-1) 703 return -0.5 * math_ops.reduce_sum(z**2., axis=-1) 704 states, _ = hmc.sample_chain( 705 num_results=num_results, 706 target_log_prob_fn=target_log_prob, 707 current_state=[dtype(-2), dtype(2)], 708 step_size=[0.5, 0.5], 709 num_leapfrog_steps=2, 710 num_burnin_steps=200, 711 num_steps_between_results=1, 712 seed=54) 713 self.assertAllEqual(dict(target_calls=2), counter) 714 states = array_ops.stack(states, axis=-1) 715 self.assertEqual(num_results, states.shape[0].value) 716 sample_mean = math_ops.reduce_mean(states, axis=0) 717 x = states - sample_mean 718 sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) 719 [sample_mean_, sample_cov_] = sess.run([ 720 sample_mean, sample_cov]) 721 self.assertAllClose(true_mean, sample_mean_, 722 atol=0.05, rtol=0.) 723 self.assertAllClose(true_cov, sample_cov_, 724 atol=0., rtol=0.1) 725 726 727 class _EnergyComputationTest(object): 728 729 def testHandlesNanFromPotential(self): 730 with self.test_session(graph=ops.Graph()) as sess: 731 x = [1, np.inf, -np.inf, np.nan] 732 target_log_prob, proposed_target_log_prob = [ 733 self.dtype(x.flatten()) for x in np.meshgrid(x, x)] 734 num_chains = len(target_log_prob) 735 dummy_momentums = [-1, 1] 736 momentums = [self.dtype([dummy_momentums] * num_chains)] 737 proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] 738 739 target_log_prob = ops.convert_to_tensor(target_log_prob) 740 momentums = [ops.convert_to_tensor(momentums[0])] 741 proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) 742 proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] 743 744 energy = _compute_energy_change( 745 target_log_prob, 746 momentums, 747 proposed_target_log_prob, 748 proposed_momentums, 749 independent_chain_ndims=1) 750 grads = gradients_ops.gradients(energy, momentums) 751 752 [actual_energy, grads_] = sess.run([energy, grads]) 753 754 # Ensure energy is `inf` (note: that's positive inf) in weird cases and 755 # finite otherwise. 756 expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) 757 self.assertAllEqual(expected_energy, actual_energy) 758 759 # Ensure gradient is finite. 760 self.assertAllEqual(np.ones_like(grads_).astype(np.bool), 761 np.isfinite(grads_)) 762 763 def testHandlesNanFromKinetic(self): 764 with self.test_session(graph=ops.Graph()) as sess: 765 x = [1, np.inf, -np.inf, np.nan] 766 momentums, proposed_momentums = [ 767 [np.reshape(self.dtype(x), [-1, 1])] 768 for x in np.meshgrid(x, x)] 769 num_chains = len(momentums[0]) 770 target_log_prob = np.ones(num_chains, self.dtype) 771 proposed_target_log_prob = np.ones(num_chains, self.dtype) 772 773 target_log_prob = ops.convert_to_tensor(target_log_prob) 774 momentums = [ops.convert_to_tensor(momentums[0])] 775 proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) 776 proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] 777 778 energy = _compute_energy_change( 779 target_log_prob, 780 momentums, 781 proposed_target_log_prob, 782 proposed_momentums, 783 independent_chain_ndims=1) 784 grads = gradients_ops.gradients(energy, momentums) 785 786 [actual_energy, grads_] = sess.run([energy, grads]) 787 788 # Ensure energy is `inf` (note: that's positive inf) in weird cases and 789 # finite otherwise. 790 expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) 791 self.assertAllEqual(expected_energy, actual_energy) 792 793 # Ensure gradient is finite. 794 g = grads_[0].reshape([len(x), len(x)])[:, 0] 795 self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) 796 797 # The remaining gradients are nan because the momentum was itself nan or 798 # inf. 799 g = grads_[0].reshape([len(x), len(x)])[:, 1:] 800 self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) 801 802 803 class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): 804 dtype = np.float16 805 806 807 class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): 808 dtype = np.float32 809 810 811 class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): 812 dtype = np.float64 813 814 815 class _HMCHandlesLists(object): 816 817 def testStateParts(self): 818 with self.test_session(graph=ops.Graph()) as sess: 819 dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) 820 dist_y = independent_lib.Independent( 821 gamma_lib.Gamma(concentration=self.dtype([1, 2]), 822 rate=self.dtype([0.5, 0.75])), 823 reinterpreted_batch_ndims=1) 824 def target_log_prob(x, y): 825 return dist_x.log_prob(x) + dist_y.log_prob(y) 826 x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] 827 samples, _ = hmc.sample_chain( 828 num_results=int(2e3), 829 target_log_prob_fn=target_log_prob, 830 current_state=x0, 831 step_size=0.85, 832 num_leapfrog_steps=3, 833 num_burnin_steps=int(250), 834 seed=49) 835 actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] 836 actual_vars = [_reduce_variance(s, axis=0) for s in samples] 837 expected_means = [dist_x.mean(), dist_y.mean()] 838 expected_vars = [dist_x.variance(), dist_y.variance()] 839 [ 840 actual_means_, 841 actual_vars_, 842 expected_means_, 843 expected_vars_, 844 ] = sess.run([ 845 actual_means, 846 actual_vars, 847 expected_means, 848 expected_vars, 849 ]) 850 self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) 851 self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) 852 853 854 class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): 855 dtype = np.float32 856 857 858 class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): 859 dtype = np.float64 860 861 862 if __name__ == "__main__": 863 test.main() 864