Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for metropolis_hastings.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops import random_ops
     27 from tensorflow.python.ops import variable_scope
     28 from tensorflow.python.ops import variables
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class McmcStepTest(test.TestCase):
     33 
     34   def test_density_increasing_step_accepted(self):
     35     """Tests that if a transition increases density, it is always accepted."""
     36     target_log_density = lambda x: - x * x
     37     state = variable_scope.get_variable('state', initializer=10.)
     38     state_log_density = variable_scope.get_variable(
     39         'state_log_density',
     40         initializer=target_log_density(state.initialized_value()))
     41     log_accept_ratio = variable_scope.get_variable(
     42         'log_accept_ratio', initializer=0.)
     43 
     44     get_next_proposal = lambda x: (x - 1., None)
     45     step = mh.evolve(state, state_log_density, log_accept_ratio,
     46                      target_log_density, get_next_proposal, seed=1234)
     47     init = variables.initialize_all_variables()
     48     with self.test_session() as sess:
     49       sess.run(init)
     50       for j in range(9):
     51         sess.run(step)
     52         sample = sess.run(state)
     53         sample_log_density = sess.run(state_log_density)
     54         self.assertAlmostEqual(sample, 9 - j)
     55         self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j))
     56 
     57   def test_sample_properties(self):
     58     """Tests that the samples converge to the target distribution."""
     59 
     60     def target_log_density(x):
     61       """Log-density corresponding to a normal distribution with mean = 4."""
     62       return - (x - 2.0) * (x - 2.0) * 0.5
     63 
     64     # Use the uniform random walker to generate proposals.
     65     proposal_fn = mh.uniform_random_proposal(
     66         step_size=1.0, seed=1234)
     67 
     68     state = variable_scope.get_variable('state', initializer=0.0)
     69     state_log_density = variable_scope.get_variable(
     70         'state_log_density',
     71         initializer=target_log_density(state.initialized_value()))
     72 
     73     log_accept_ratio = variable_scope.get_variable(
     74         'log_accept_ratio', initializer=0.)
     75     # Random walk MCMC converges slowly so need to put in enough iterations.
     76     num_iterations = 5000
     77     step = mh.evolve(state, state_log_density, log_accept_ratio,
     78                      target_log_density, proposal_fn, seed=4321)
     79 
     80     init = variables.global_variables_initializer()
     81 
     82     sample_sum, sample_sq_sum = 0.0, 0.0
     83     with self.test_session() as sess:
     84       sess.run(init)
     85       for _ in np.arange(num_iterations):
     86         # Allow for the mixing of the chain and discard these samples.
     87         sess.run(step)
     88       for _ in np.arange(num_iterations):
     89         sess.run(step)
     90         sample = sess.run(state)
     91         sample_sum += sample
     92         sample_sq_sum += sample * sample
     93 
     94     sample_mean = sample_sum / num_iterations
     95     sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean
     96     # The samples have large autocorrelation which reduces the effective sample
     97     # size.
     98     self.assertAlmostEqual(sample_mean, 2.0, delta=0.1)
     99     self.assertAlmostEqual(sample_variance, 1.0, delta=0.1)
    100 
    101   def test_normal_proposals(self):
    102     """Tests that the normal proposals are correctly distributed."""
    103 
    104     initial_points = array_ops.ones([10000], dtype=dtypes.float32)
    105     proposal_fn = mh.normal_random_proposal(
    106         scale=2.0, seed=1234)
    107     proposal_points, _ = proposal_fn(initial_points)
    108 
    109     with self.test_session() as sess:
    110       sample = sess.run(proposal_points)
    111 
    112     # It is expected that the elements in proposal_points have the same mean as
    113     # initial_points and have the standard deviation that was supplied to the
    114     # proposal scheme.
    115     self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1)
    116     self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1)
    117 
    118   def test_docstring_example(self):
    119     """Tests the simplified docstring example with multiple chains."""
    120 
    121     n = 2  # dimension of the problem
    122 
    123     # Generate 300 initial values randomly. Each of these would be an
    124     # independent starting point for a Markov chain.
    125     state = variable_scope.get_variable(
    126         'state', initializer=random_ops.random_normal(
    127             [300, n], mean=3.0, dtype=dtypes.float32, seed=42))
    128 
    129     # Computes the log(p(x)) for the unit normal density and ignores the
    130     # normalization constant.
    131     def log_density(x):
    132       return  - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0
    133 
    134     # Initial log-density value
    135     state_log_density = variable_scope.get_variable(
    136         'state_log_density',
    137         initializer=log_density(state.initialized_value()))
    138 
    139     # A variable to store the log_acceptance_ratio:
    140     log_acceptance_ratio = variable_scope.get_variable(
    141         'log_acceptance_ratio',
    142         initializer=array_ops.zeros([300], dtype=dtypes.float32))
    143 
    144     # Generates random proposals by moving each coordinate uniformly and
    145     # independently in a box of size 2 centered around the current value.
    146     # Returns the new point and also the log of the Hastings ratio (the
    147     # ratio of the probability of going from the proposal to origin and the
    148     # probability of the reverse transition). When this ratio is 1, the value
    149     # may be omitted and replaced by None.
    150     def random_proposal(x):
    151       return (x + random_ops.random_uniform(
    152           array_ops.shape(x), minval=-1, maxval=1,
    153           dtype=x.dtype, seed=12)), None
    154 
    155     #  Create the op to propagate the chain for 100 steps.
    156     stepper = mh.evolve(
    157         state, state_log_density, log_acceptance_ratio,
    158         log_density, random_proposal, n_steps=100, seed=123)
    159     init = variables.initialize_all_variables()
    160     with self.test_session() as sess:
    161       sess.run(init)
    162       # Run the chains for a total of 1000 steps.
    163       for _ in range(10):
    164         sess.run(stepper)
    165       samples = sess.run(state)
    166       covariance = np.eye(n)
    167       # Verify that the estimated mean and covariance are close to the true
    168       # values.
    169       self.assertAlmostEqual(
    170           np.max(np.abs(np.mean(samples, 0)
    171                         - np.zeros(n))), 0,
    172           delta=0.1)
    173       self.assertAlmostEqual(
    174           np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2])
    175                         - np.reshape(covariance, [n**2]))), 0,
    176           delta=0.2)
    177 
    178 if __name__ == '__main__':
    179   test.main()
    180