Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Hamiltonian Monte Carlo."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 
     23 import numpy as np
     24 from scipy import stats
     25 
     26 from tensorflow.contrib.bayesflow.python.ops import hmc
     27 from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change
     28 from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator
     29 
     30 from tensorflow.contrib.distributions.python.ops import independent as independent_lib
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.framework import random_seed
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import gen_linalg_ops
     35 from tensorflow.python.ops import gradients_impl as gradients_ops
     36 from tensorflow.python.ops import math_ops
     37 from tensorflow.python.ops import random_ops
     38 from tensorflow.python.ops.distributions import gamma as gamma_lib
     39 from tensorflow.python.ops.distributions import normal as normal_lib
     40 from tensorflow.python.platform import test
     41 from tensorflow.python.platform import tf_logging as logging_ops
     42 
     43 
     44 def _reduce_variance(x, axis=None, keepdims=False):
     45   sample_mean = math_ops.reduce_mean(x, axis, keepdims=True)
     46   return math_ops.reduce_mean(
     47       math_ops.squared_difference(x, sample_mean), axis, keepdims)
     48 
     49 
     50 class HMCTest(test.TestCase):
     51 
     52   def setUp(self):
     53     self._shape_param = 5.
     54     self._rate_param = 10.
     55 
     56     random_seed.set_random_seed(10003)
     57     np.random.seed(10003)
     58 
     59   def assertAllFinite(self, x):
     60     self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x))
     61 
     62   def _log_gamma_log_prob(self, x, event_dims=()):
     63     """Computes log-pdf of a log-gamma random variable.
     64 
     65     Args:
     66       x: Value of the random variable.
     67       event_dims: Dimensions not to treat as independent.
     68 
     69     Returns:
     70       log_prob: The log-pdf up to a normalizing constant.
     71     """
     72     return math_ops.reduce_sum(self._shape_param * x -
     73                                self._rate_param * math_ops.exp(x),
     74                                event_dims)
     75 
     76   def _integrator_conserves_energy(self, x, independent_chain_ndims, sess,
     77                                    feed_dict=None):
     78     step_size = array_ops.placeholder(np.float32, [], name="step_size")
     79     hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps")
     80 
     81     if feed_dict is None:
     82       feed_dict = {}
     83     feed_dict[hmc_lf_steps] = 1000
     84 
     85     event_dims = math_ops.range(independent_chain_ndims,
     86                                 array_ops.rank(x))
     87 
     88     m = random_ops.random_normal(array_ops.shape(x))
     89     log_prob_0 = self._log_gamma_log_prob(x, event_dims)
     90     grad_0 = gradients_ops.gradients(log_prob_0, x)
     91     old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims)
     92 
     93     new_m, _, log_prob_1, _ = _leapfrog_integrator(
     94         current_momentums=[m],
     95         target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims),
     96         current_state_parts=[x],
     97         step_sizes=[step_size],
     98         num_leapfrog_steps=hmc_lf_steps,
     99         current_target_log_prob=log_prob_0,
    100         current_grads_target_log_prob=grad_0)
    101     new_m = new_m[0]
    102 
    103     new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m,
    104                                                          event_dims)
    105 
    106     x_shape = sess.run(x, feed_dict).shape
    107     event_size = np.prod(x_shape[independent_chain_ndims:])
    108     feed_dict[step_size] = 0.1 / event_size
    109     old_energy_, new_energy_ = sess.run([old_energy, new_energy],
    110                                         feed_dict)
    111     logging_ops.vlog(1, "average energy relative change: {}".format(
    112         (1. - new_energy_ / old_energy_).mean()))
    113     self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02)
    114 
    115   def _integrator_conserves_energy_wrapper(self, independent_chain_ndims):
    116     """Tests the long-term energy conservation of the leapfrog integrator.
    117 
    118     The leapfrog integrator is symplectic, so for sufficiently small step
    119     sizes it should be possible to run it more or less indefinitely without
    120     the energy of the system blowing up or collapsing.
    121 
    122     Args:
    123       independent_chain_ndims: Python `int` scalar representing the number of
    124         dims associated with independent chains.
    125     """
    126     with self.test_session(graph=ops.Graph()) as sess:
    127       x_ph = array_ops.placeholder(np.float32, name="x_ph")
    128       feed_dict = {x_ph: np.random.rand(50, 10, 2)}
    129       self._integrator_conserves_energy(x_ph, independent_chain_ndims,
    130                                         sess, feed_dict)
    131 
    132   def testIntegratorEnergyConservationNullShape(self):
    133     self._integrator_conserves_energy_wrapper(0)
    134 
    135   def testIntegratorEnergyConservation1(self):
    136     self._integrator_conserves_energy_wrapper(1)
    137 
    138   def testIntegratorEnergyConservation2(self):
    139     self._integrator_conserves_energy_wrapper(2)
    140 
    141   def testIntegratorEnergyConservation3(self):
    142     self._integrator_conserves_energy_wrapper(3)
    143 
    144   def testSampleChainSeedReproducibleWorksCorrectly(self):
    145     with self.test_session(graph=ops.Graph()) as sess:
    146       num_results = 10
    147       independent_chain_ndims = 1
    148 
    149       def log_gamma_log_prob(x):
    150         event_dims = math_ops.range(independent_chain_ndims,
    151                                     array_ops.rank(x))
    152         return self._log_gamma_log_prob(x, event_dims)
    153 
    154       kwargs = dict(
    155           target_log_prob_fn=log_gamma_log_prob,
    156           current_state=np.random.rand(4, 3, 2),
    157           step_size=0.1,
    158           num_leapfrog_steps=2,
    159           num_burnin_steps=150,
    160           seed=52,
    161       )
    162 
    163       samples0, kernel_results0 = hmc.sample_chain(
    164           **dict(list(kwargs.items()) + list(dict(
    165               num_results=2 * num_results,
    166               num_steps_between_results=0).items())))
    167 
    168       samples1, kernel_results1 = hmc.sample_chain(
    169           **dict(list(kwargs.items()) + list(dict(
    170               num_results=num_results,
    171               num_steps_between_results=1).items())))
    172 
    173       [
    174           samples0_,
    175           samples1_,
    176           target_log_prob0_,
    177           target_log_prob1_,
    178       ] = sess.run([
    179           samples0,
    180           samples1,
    181           kernel_results0.current_target_log_prob,
    182           kernel_results1.current_target_log_prob,
    183       ])
    184       self.assertAllClose(samples0_[::2], samples1_,
    185                           atol=1e-5, rtol=1e-5)
    186       self.assertAllClose(target_log_prob0_[::2], target_log_prob1_,
    187                           atol=1e-5, rtol=1e-5)
    188 
    189   def _chain_gets_correct_expectations(self, x, independent_chain_ndims,
    190                                        sess, feed_dict=None):
    191     counter = collections.Counter()
    192     def log_gamma_log_prob(x):
    193       counter["target_calls"] += 1
    194       event_dims = math_ops.range(independent_chain_ndims,
    195                                   array_ops.rank(x))
    196       return self._log_gamma_log_prob(x, event_dims)
    197 
    198     num_results = array_ops.placeholder(
    199         np.int32, [], name="num_results")
    200     step_size = array_ops.placeholder(
    201         np.float32, [], name="step_size")
    202     num_leapfrog_steps = array_ops.placeholder(
    203         np.int32, [], name="num_leapfrog_steps")
    204 
    205     if feed_dict is None:
    206       feed_dict = {}
    207     feed_dict.update({num_results: 150,
    208                       step_size: 0.05,
    209                       num_leapfrog_steps: 2})
    210 
    211     samples, kernel_results = hmc.sample_chain(
    212         num_results=num_results,
    213         target_log_prob_fn=log_gamma_log_prob,
    214         current_state=x,
    215         step_size=step_size,
    216         num_leapfrog_steps=num_leapfrog_steps,
    217         num_burnin_steps=150,
    218         seed=42)
    219 
    220     self.assertAllEqual(dict(target_calls=2), counter)
    221 
    222     expected_x = (math_ops.digamma(self._shape_param)
    223                   - np.log(self._rate_param))
    224 
    225     expected_exp_x = self._shape_param / self._rate_param
    226 
    227     acceptance_probs_, samples_, expected_x_ = sess.run(
    228         [kernel_results.acceptance_probs, samples, expected_x],
    229         feed_dict)
    230 
    231     actual_x = samples_.mean()
    232     actual_exp_x = np.exp(samples_).mean()
    233 
    234     logging_ops.vlog(1, "True      E[x, exp(x)]: {}\t{}".format(
    235         expected_x_, expected_exp_x))
    236     logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format(
    237         actual_x, actual_exp_x))
    238     self.assertNear(actual_x, expected_x_, 2e-2)
    239     self.assertNear(actual_exp_x, expected_exp_x, 2e-2)
    240     self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
    241                         acceptance_probs_ > 0.5)
    242     self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool),
    243                         acceptance_probs_ <= 1.)
    244 
    245   def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims):
    246     with self.test_session(graph=ops.Graph()) as sess:
    247       x_ph = array_ops.placeholder(np.float32, name="x_ph")
    248       feed_dict = {x_ph: np.random.rand(50, 10, 2)}
    249       self._chain_gets_correct_expectations(x_ph, independent_chain_ndims,
    250                                             sess, feed_dict)
    251 
    252   def testHMCChainExpectationsNullShape(self):
    253     self._chain_gets_correct_expectations_wrapper(0)
    254 
    255   def testHMCChainExpectations1(self):
    256     self._chain_gets_correct_expectations_wrapper(1)
    257 
    258   def testHMCChainExpectations2(self):
    259     self._chain_gets_correct_expectations_wrapper(2)
    260 
    261   def testKernelResultsUsingTruncatedDistribution(self):
    262     def log_prob(x):
    263       return array_ops.where(
    264           x >= 0.,
    265           -x - x**2,  # Non-constant gradient.
    266           array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype)))
    267     # This log_prob has the property that it is likely to attract
    268     # the HMC flow toward, and below, zero...but for x <=0,
    269     # log_prob(x) = -inf, which should result in rejection, as well
    270     # as a non-finite log_prob.  Thus, this distribution gives us an opportunity
    271     # to test out the kernel results ability to correctly capture rejections due
    272     # to finite AND non-finite reasons.
    273     # Why use a non-constant gradient?  This ensures the leapfrog integrator
    274     # will not be exact.
    275 
    276     num_results = 1000
    277     # Large step size, will give rejections due to integration error in addition
    278     # to rejection due to going into a region of log_prob = -inf.
    279     step_size = 0.1
    280     num_leapfrog_steps = 5
    281     num_chains = 2
    282 
    283     with self.test_session(graph=ops.Graph()) as sess:
    284 
    285       # Start multiple independent chains.
    286       initial_state = ops.convert_to_tensor([0.1] * num_chains)
    287 
    288       states, kernel_results = hmc.sample_chain(
    289           num_results=num_results,
    290           target_log_prob_fn=log_prob,
    291           current_state=initial_state,
    292           step_size=step_size,
    293           num_leapfrog_steps=num_leapfrog_steps,
    294           seed=42)
    295 
    296       states_, kernel_results_ = sess.run([states, kernel_results])
    297       pstates_ = kernel_results_.proposed_state
    298 
    299       neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob)
    300 
    301       # First:  Test that the mathematical properties of the above log prob
    302       # function in conjunction with HMC show up as expected in kernel_results_.
    303 
    304       # We better have log_prob = -inf some of the time.
    305       self.assertLess(0, neg_inf_mask.sum())
    306       # We better have some rejections due to something other than -inf.
    307       self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum())
    308       # We better have been accepted a decent amount, even near the end of the
    309       # chain, or else this HMC run just got stuck at some point.
    310       self.assertLess(
    311           0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean())
    312       # We better not have any NaNs in proposed state or log_prob.
    313       # We may have some NaN in grads, which involve multiplication/addition due
    314       # to gradient rules.  This is the known "NaN grad issue with tf.where."
    315       self.assertAllEqual(np.zeros_like(states_),
    316                           np.isnan(kernel_results_.proposed_target_log_prob))
    317       self.assertAllEqual(np.zeros_like(states_),
    318                           np.isnan(states_))
    319       # We better not have any +inf in states, grads, or log_prob.
    320       self.assertAllEqual(np.zeros_like(states_),
    321                           np.isposinf(kernel_results_.proposed_target_log_prob))
    322       self.assertAllEqual(
    323           np.zeros_like(states_),
    324           np.isposinf(kernel_results_.proposed_grads_target_log_prob[0]))
    325       self.assertAllEqual(np.zeros_like(states_),
    326                           np.isposinf(states_))
    327 
    328       # Second:  Test that kernel_results is congruent with itself and
    329       # acceptance/rejection of states.
    330 
    331       # Proposed state is negative iff proposed target log prob is -inf.
    332       np.testing.assert_array_less(pstates_[neg_inf_mask], 0.)
    333       np.testing.assert_array_less(0., pstates_[~neg_inf_mask])
    334 
    335       # Acceptance probs are zero whenever proposed state is negative.
    336       self.assertAllEqual(
    337           np.zeros_like(pstates_[neg_inf_mask]),
    338           kernel_results_.acceptance_probs[neg_inf_mask])
    339 
    340       # The move is accepted ==> state = proposed state.
    341       self.assertAllEqual(
    342           states_[kernel_results_.is_accepted],
    343           pstates_[kernel_results_.is_accepted],
    344       )
    345       # The move was rejected <==> state[t] == state[t - 1].
    346       for t in range(1, num_results):
    347         for i in range(num_chains):
    348           if kernel_results_.is_accepted[t, i]:
    349             self.assertNotEqual(states_[t, i], states_[t - 1, i])
    350           else:
    351             self.assertEqual(states_[t, i], states_[t - 1, i])
    352 
    353   def _kernel_leaves_target_invariant(self, initial_draws,
    354                                       independent_chain_ndims,
    355                                       sess, feed_dict=None):
    356     def log_gamma_log_prob(x):
    357       event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    358       return self._log_gamma_log_prob(x, event_dims)
    359 
    360     def fake_log_prob(x):
    361       """Cooled version of the target distribution."""
    362       return 1.1 * log_gamma_log_prob(x)
    363 
    364     step_size = array_ops.placeholder(np.float32, [], name="step_size")
    365 
    366     if feed_dict is None:
    367       feed_dict = {}
    368 
    369     feed_dict[step_size] = 0.4
    370 
    371     sample, kernel_results = hmc.kernel(
    372         target_log_prob_fn=log_gamma_log_prob,
    373         current_state=initial_draws,
    374         step_size=step_size,
    375         num_leapfrog_steps=5,
    376         seed=43)
    377 
    378     bad_sample, bad_kernel_results = hmc.kernel(
    379         target_log_prob_fn=fake_log_prob,
    380         current_state=initial_draws,
    381         step_size=step_size,
    382         num_leapfrog_steps=5,
    383         seed=44)
    384 
    385     [
    386         acceptance_probs_,
    387         bad_acceptance_probs_,
    388         initial_draws_,
    389         updated_draws_,
    390         fake_draws_,
    391     ] = sess.run([
    392         kernel_results.acceptance_probs,
    393         bad_kernel_results.acceptance_probs,
    394         initial_draws,
    395         sample,
    396         bad_sample,
    397     ], feed_dict)
    398 
    399     # Confirm step size is small enough that we usually accept.
    400     self.assertGreater(acceptance_probs_.mean(), 0.5)
    401     self.assertGreater(bad_acceptance_probs_.mean(), 0.5)
    402 
    403     # Confirm step size is large enough that we sometimes reject.
    404     self.assertLess(acceptance_probs_.mean(), 0.99)
    405     self.assertLess(bad_acceptance_probs_.mean(), 0.99)
    406 
    407     _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(),
    408                                         updated_draws_.flatten())
    409     _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(),
    410                                         fake_draws_.flatten())
    411 
    412     logging_ops.vlog(1, "acceptance rate for true target: {}".format(
    413         acceptance_probs_.mean()))
    414     logging_ops.vlog(1, "acceptance rate for fake target: {}".format(
    415         bad_acceptance_probs_.mean()))
    416     logging_ops.vlog(1, "K-S p-value for true target: {}".format(
    417         ks_p_value_true))
    418     logging_ops.vlog(1, "K-S p-value for fake target: {}".format(
    419         ks_p_value_fake))
    420     # Make sure that the MCMC update hasn't changed the empirical CDF much.
    421     self.assertGreater(ks_p_value_true, 1e-3)
    422     # Confirm that targeting the wrong distribution does
    423     # significantly change the empirical CDF.
    424     self.assertLess(ks_p_value_fake, 1e-6)
    425 
    426   def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims):
    427     """Tests that the kernel leaves the target distribution invariant.
    428 
    429     Draws some independent samples from the target distribution,
    430     applies an iteration of the MCMC kernel, then runs a
    431     Kolmogorov-Smirnov test to determine if the distribution of the
    432     MCMC-updated samples has changed.
    433 
    434     We also confirm that running the kernel with a different log-pdf
    435     does change the target distribution. (And that we can detect that.)
    436 
    437     Args:
    438       independent_chain_ndims: Python `int` scalar representing the number of
    439         dims associated with independent chains.
    440     """
    441     with self.test_session(graph=ops.Graph()) as sess:
    442       initial_draws = np.log(np.random.gamma(self._shape_param,
    443                                              size=[50000, 2, 2]))
    444       initial_draws -= np.log(self._rate_param)
    445       x_ph = array_ops.placeholder(np.float32, name="x_ph")
    446 
    447       feed_dict = {x_ph: initial_draws}
    448 
    449       self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims,
    450                                            sess, feed_dict)
    451 
    452   def testKernelLeavesTargetInvariant1(self):
    453     self._kernel_leaves_target_invariant_wrapper(1)
    454 
    455   def testKernelLeavesTargetInvariant2(self):
    456     self._kernel_leaves_target_invariant_wrapper(2)
    457 
    458   def testKernelLeavesTargetInvariant3(self):
    459     self._kernel_leaves_target_invariant_wrapper(3)
    460 
    461   def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
    462                                        sess, feed_dict=None):
    463     counter = collections.Counter()
    464 
    465     def proposal_log_prob(x):
    466       counter["proposal_calls"] += 1
    467       event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    468       return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
    469                                         axis=event_dims)
    470 
    471     def target_log_prob(x):
    472       counter["target_calls"] += 1
    473       event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    474       return self._log_gamma_log_prob(x, event_dims)
    475 
    476     if feed_dict is None:
    477       feed_dict = {}
    478 
    479     num_steps = 200
    480 
    481     _, ais_weights, _ = hmc.sample_annealed_importance_chain(
    482         proposal_log_prob_fn=proposal_log_prob,
    483         num_steps=num_steps,
    484         target_log_prob_fn=target_log_prob,
    485         step_size=0.5,
    486         current_state=init,
    487         num_leapfrog_steps=2,
    488         seed=45)
    489 
    490     # We have three calls because the calculation of `ais_weights` entails
    491     # another call to the `convex_combined_log_prob_fn`. We could refactor
    492     # things to avoid this, if needed (eg, b/72994218).
    493     self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter)
    494 
    495     event_shape = array_ops.shape(init)[independent_chain_ndims:]
    496     event_size = math_ops.reduce_prod(event_shape)
    497 
    498     log_true_normalizer = (
    499         -self._shape_param * math_ops.log(self._rate_param)
    500         + math_ops.lgamma(self._shape_param))
    501     log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
    502 
    503     log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
    504                                 - np.log(num_steps))
    505 
    506     ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
    507     ais_weights_size = array_ops.size(ais_weights)
    508     standard_error = math_ops.sqrt(
    509         _reduce_variance(ratio_estimate_true)
    510         / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
    511 
    512     [
    513         ratio_estimate_true_,
    514         log_true_normalizer_,
    515         log_estimated_normalizer_,
    516         standard_error_,
    517         ais_weights_size_,
    518         event_size_,
    519     ] = sess.run([
    520         ratio_estimate_true,
    521         log_true_normalizer,
    522         log_estimated_normalizer,
    523         standard_error,
    524         ais_weights_size,
    525         event_size,
    526     ], feed_dict)
    527 
    528     logging_ops.vlog(1, "        log_true_normalizer: {}\n"
    529                         "   log_estimated_normalizer: {}\n"
    530                         "           ais_weights_size: {}\n"
    531                         "                 event_size: {}\n".format(
    532                             log_true_normalizer_,
    533                             log_estimated_normalizer_,
    534                             ais_weights_size_,
    535                             event_size_))
    536     self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
    537 
    538   def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
    539     """Tests that AIS yields reasonable estimates of normalizers."""
    540     with self.test_session(graph=ops.Graph()) as sess:
    541       x_ph = array_ops.placeholder(np.float32, name="x_ph")
    542       initial_draws = np.random.normal(size=[30, 2, 1])
    543       self._ais_gets_correct_log_normalizer(
    544           x_ph,
    545           independent_chain_ndims,
    546           sess,
    547           feed_dict={x_ph: initial_draws})
    548 
    549   def testAIS1(self):
    550     self._ais_gets_correct_log_normalizer_wrapper(1)
    551 
    552   def testAIS2(self):
    553     self._ais_gets_correct_log_normalizer_wrapper(2)
    554 
    555   def testAIS3(self):
    556     self._ais_gets_correct_log_normalizer_wrapper(3)
    557 
    558   def testSampleAIChainSeedReproducibleWorksCorrectly(self):
    559     with self.test_session(graph=ops.Graph()) as sess:
    560       independent_chain_ndims = 1
    561       x = np.random.rand(4, 3, 2)
    562 
    563       def proposal_log_prob(x):
    564         event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    565         return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
    566                                           axis=event_dims)
    567 
    568       def target_log_prob(x):
    569         event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
    570         return self._log_gamma_log_prob(x, event_dims)
    571 
    572       ais_kwargs = dict(
    573           proposal_log_prob_fn=proposal_log_prob,
    574           num_steps=200,
    575           target_log_prob_fn=target_log_prob,
    576           step_size=0.5,
    577           current_state=x,
    578           num_leapfrog_steps=2,
    579           seed=53)
    580 
    581       _, ais_weights0, _ = hmc.sample_annealed_importance_chain(
    582           **ais_kwargs)
    583 
    584       _, ais_weights1, _ = hmc.sample_annealed_importance_chain(
    585           **ais_kwargs)
    586 
    587       [ais_weights0_, ais_weights1_] = sess.run([
    588           ais_weights0, ais_weights1])
    589 
    590       self.assertAllClose(ais_weights0_, ais_weights1_,
    591                           atol=1e-5, rtol=1e-5)
    592 
    593   def testNanRejection(self):
    594     """Tests that an update that yields NaN potentials gets rejected.
    595 
    596     We run HMC with a target distribution that returns NaN
    597     log-likelihoods if any element of x < 0, and unit-scale
    598     exponential log-likelihoods otherwise. The exponential potential
    599     pushes x towards 0, ensuring that any reasonably large update will
    600     push us over the edge into NaN territory.
    601     """
    602     def _unbounded_exponential_log_prob(x):
    603       """An exponential distribution with log-likelihood NaN for x < 0."""
    604       per_element_potentials = array_ops.where(
    605           x < 0.,
    606           array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)),
    607           -x)
    608       return math_ops.reduce_sum(per_element_potentials)
    609 
    610     with self.test_session(graph=ops.Graph()) as sess:
    611       initial_x = math_ops.linspace(0.01, 5, 10)
    612       updated_x, kernel_results = hmc.kernel(
    613           target_log_prob_fn=_unbounded_exponential_log_prob,
    614           current_state=initial_x,
    615           step_size=2.,
    616           num_leapfrog_steps=5,
    617           seed=46)
    618       initial_x_, updated_x_, acceptance_probs_ = sess.run(
    619           [initial_x, updated_x, kernel_results.acceptance_probs])
    620 
    621       logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
    622       logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
    623       logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
    624 
    625       self.assertAllEqual(initial_x_, updated_x_)
    626       self.assertEqual(acceptance_probs_, 0.)
    627 
    628   def testNanFromGradsDontPropagate(self):
    629     """Test that update with NaN gradients does not cause NaN in results."""
    630     def _nan_log_prob_with_nan_gradient(x):
    631       return np.nan * math_ops.reduce_sum(x)
    632 
    633     with self.test_session(graph=ops.Graph()) as sess:
    634       initial_x = math_ops.linspace(0.01, 5, 10)
    635       updated_x, kernel_results = hmc.kernel(
    636           target_log_prob_fn=_nan_log_prob_with_nan_gradient,
    637           current_state=initial_x,
    638           step_size=2.,
    639           num_leapfrog_steps=5,
    640           seed=47)
    641       initial_x_, updated_x_, acceptance_probs_ = sess.run(
    642           [initial_x, updated_x, kernel_results.acceptance_probs])
    643 
    644       logging_ops.vlog(1, "initial_x = {}".format(initial_x_))
    645       logging_ops.vlog(1, "updated_x = {}".format(updated_x_))
    646       logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_))
    647 
    648       self.assertAllEqual(initial_x_, updated_x_)
    649       self.assertEqual(acceptance_probs_, 0.)
    650 
    651       self.assertAllFinite(
    652           gradients_ops.gradients(updated_x, initial_x)[0].eval())
    653       self.assertAllEqual([True], [g is None for g in gradients_ops.gradients(
    654           kernel_results.proposed_grads_target_log_prob, initial_x)])
    655       self.assertAllEqual([False], [g is None for g in gradients_ops.gradients(
    656           kernel_results.proposed_grads_target_log_prob,
    657           kernel_results.proposed_state)])
    658 
    659       # Gradients of the acceptance probs and new log prob are not finite.
    660       # self.assertAllFinite(
    661       #     gradients_ops.gradients(acceptance_probs, initial_x)[0].eval())
    662       # self.assertAllFinite(
    663       #     gradients_ops.gradients(new_log_prob, initial_x)[0].eval())
    664 
    665   def _testChainWorksDtype(self, dtype):
    666     with self.test_session(graph=ops.Graph()) as sess:
    667       states, kernel_results = hmc.sample_chain(
    668           num_results=10,
    669           target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1),
    670           current_state=np.zeros(5).astype(dtype),
    671           step_size=0.01,
    672           num_leapfrog_steps=10,
    673           seed=48)
    674       states_, acceptance_probs_ = sess.run(
    675           [states, kernel_results.acceptance_probs])
    676       self.assertEqual(dtype, states_.dtype)
    677       self.assertEqual(dtype, acceptance_probs_.dtype)
    678 
    679   def testChainWorksIn64Bit(self):
    680     self._testChainWorksDtype(np.float64)
    681 
    682   def testChainWorksIn16Bit(self):
    683     self._testChainWorksDtype(np.float16)
    684 
    685   def testChainWorksCorrelatedMultivariate(self):
    686     dtype = np.float32
    687     true_mean = dtype([0, 0])
    688     true_cov = dtype([[1, 0.5],
    689                       [0.5, 1]])
    690     num_results = 2000
    691     counter = collections.Counter()
    692     with self.test_session(graph=ops.Graph()) as sess:
    693       def target_log_prob(x, y):
    694         counter["target_calls"] += 1
    695         # Corresponds to unnormalized MVN.
    696         # z = matmul(inv(chol(true_cov)), [x, y] - true_mean)
    697         z = array_ops.stack([x, y], axis=-1) - true_mean
    698         z = array_ops.squeeze(
    699             gen_linalg_ops.matrix_triangular_solve(
    700                 np.linalg.cholesky(true_cov),
    701                 z[..., array_ops.newaxis]),
    702             axis=-1)
    703         return -0.5 * math_ops.reduce_sum(z**2., axis=-1)
    704       states, _ = hmc.sample_chain(
    705           num_results=num_results,
    706           target_log_prob_fn=target_log_prob,
    707           current_state=[dtype(-2), dtype(2)],
    708           step_size=[0.5, 0.5],
    709           num_leapfrog_steps=2,
    710           num_burnin_steps=200,
    711           num_steps_between_results=1,
    712           seed=54)
    713       self.assertAllEqual(dict(target_calls=2), counter)
    714       states = array_ops.stack(states, axis=-1)
    715       self.assertEqual(num_results, states.shape[0].value)
    716       sample_mean = math_ops.reduce_mean(states, axis=0)
    717       x = states - sample_mean
    718       sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results)
    719       [sample_mean_, sample_cov_] = sess.run([
    720           sample_mean, sample_cov])
    721       self.assertAllClose(true_mean, sample_mean_,
    722                           atol=0.05, rtol=0.)
    723       self.assertAllClose(true_cov, sample_cov_,
    724                           atol=0., rtol=0.1)
    725 
    726 
    727 class _EnergyComputationTest(object):
    728 
    729   def testHandlesNanFromPotential(self):
    730     with self.test_session(graph=ops.Graph()) as sess:
    731       x = [1, np.inf, -np.inf, np.nan]
    732       target_log_prob, proposed_target_log_prob = [
    733           self.dtype(x.flatten()) for x in np.meshgrid(x, x)]
    734       num_chains = len(target_log_prob)
    735       dummy_momentums = [-1, 1]
    736       momentums = [self.dtype([dummy_momentums] * num_chains)]
    737       proposed_momentums = [self.dtype([dummy_momentums] * num_chains)]
    738 
    739       target_log_prob = ops.convert_to_tensor(target_log_prob)
    740       momentums = [ops.convert_to_tensor(momentums[0])]
    741       proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
    742       proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
    743 
    744       energy = _compute_energy_change(
    745           target_log_prob,
    746           momentums,
    747           proposed_target_log_prob,
    748           proposed_momentums,
    749           independent_chain_ndims=1)
    750       grads = gradients_ops.gradients(energy, momentums)
    751 
    752       [actual_energy, grads_] = sess.run([energy, grads])
    753 
    754       # Ensure energy is `inf` (note: that's positive inf) in weird cases and
    755       # finite otherwise.
    756       expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
    757       self.assertAllEqual(expected_energy, actual_energy)
    758 
    759       # Ensure gradient is finite.
    760       self.assertAllEqual(np.ones_like(grads_).astype(np.bool),
    761                           np.isfinite(grads_))
    762 
    763   def testHandlesNanFromKinetic(self):
    764     with self.test_session(graph=ops.Graph()) as sess:
    765       x = [1, np.inf, -np.inf, np.nan]
    766       momentums, proposed_momentums = [
    767           [np.reshape(self.dtype(x), [-1, 1])]
    768           for x in np.meshgrid(x, x)]
    769       num_chains = len(momentums[0])
    770       target_log_prob = np.ones(num_chains, self.dtype)
    771       proposed_target_log_prob = np.ones(num_chains, self.dtype)
    772 
    773       target_log_prob = ops.convert_to_tensor(target_log_prob)
    774       momentums = [ops.convert_to_tensor(momentums[0])]
    775       proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob)
    776       proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])]
    777 
    778       energy = _compute_energy_change(
    779           target_log_prob,
    780           momentums,
    781           proposed_target_log_prob,
    782           proposed_momentums,
    783           independent_chain_ndims=1)
    784       grads = gradients_ops.gradients(energy, momentums)
    785 
    786       [actual_energy, grads_] = sess.run([energy, grads])
    787 
    788       # Ensure energy is `inf` (note: that's positive inf) in weird cases and
    789       # finite otherwise.
    790       expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1))
    791       self.assertAllEqual(expected_energy, actual_energy)
    792 
    793       # Ensure gradient is finite.
    794       g = grads_[0].reshape([len(x), len(x)])[:, 0]
    795       self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g))
    796 
    797       # The remaining gradients are nan because the momentum was itself nan or
    798       # inf.
    799       g = grads_[0].reshape([len(x), len(x)])[:, 1:]
    800       self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g))
    801 
    802 
    803 class EnergyComputationTest16(test.TestCase, _EnergyComputationTest):
    804   dtype = np.float16
    805 
    806 
    807 class EnergyComputationTest32(test.TestCase, _EnergyComputationTest):
    808   dtype = np.float32
    809 
    810 
    811 class EnergyComputationTest64(test.TestCase, _EnergyComputationTest):
    812   dtype = np.float64
    813 
    814 
    815 class _HMCHandlesLists(object):
    816 
    817   def testStateParts(self):
    818     with self.test_session(graph=ops.Graph()) as sess:
    819       dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1))
    820       dist_y = independent_lib.Independent(
    821           gamma_lib.Gamma(concentration=self.dtype([1, 2]),
    822                           rate=self.dtype([0.5, 0.75])),
    823           reinterpreted_batch_ndims=1)
    824       def target_log_prob(x, y):
    825         return dist_x.log_prob(x) + dist_y.log_prob(y)
    826       x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)]
    827       samples, _ = hmc.sample_chain(
    828           num_results=int(2e3),
    829           target_log_prob_fn=target_log_prob,
    830           current_state=x0,
    831           step_size=0.85,
    832           num_leapfrog_steps=3,
    833           num_burnin_steps=int(250),
    834           seed=49)
    835       actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples]
    836       actual_vars = [_reduce_variance(s, axis=0) for s in samples]
    837       expected_means = [dist_x.mean(), dist_y.mean()]
    838       expected_vars = [dist_x.variance(), dist_y.variance()]
    839       [
    840           actual_means_,
    841           actual_vars_,
    842           expected_means_,
    843           expected_vars_,
    844       ] = sess.run([
    845           actual_means,
    846           actual_vars,
    847           expected_means,
    848           expected_vars,
    849       ])
    850       self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16)
    851       self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25)
    852 
    853 
    854 class HMCHandlesLists32(_HMCHandlesLists, test.TestCase):
    855   dtype = np.float32
    856 
    857 
    858 class HMCHandlesLists64(_HMCHandlesLists, test.TestCase):
    859   dtype = np.float64
    860 
    861 
    862 if __name__ == "__main__":
    863   test.main()
    864