# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convolutional Bayesian layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib
from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.platform import test


class Counter(object):
  """Helper class that returns an incrementing `int` each time it is called."""

  def __init__(self):
    self._value = -1

  @property
  def value(self):
    return self._value

  def __call__(self):
    self._value += 1
    return self._value

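# Illustrative usage sketch (not itself part of the tests): a single
# `Counter` hands out distinct, reproducible seeds, which is how
# `_testConvSetUp` below seeds each of its random tensors:
#
#   seed = Counter()
#   random_ops.random_uniform([2, 3], seed=seed())  # seed=0
#   random_ops.random_uniform([2, 3], seed=seed())  # seed=1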

class MockDistribution(independent_lib.Independent):
  """Monitors layer calls to the underlying distribution."""

  def __init__(self, result_sample, result_log_prob, loc=None, scale=None):
    self.result_sample = result_sample
    self.result_log_prob = result_log_prob
    self.result_loc = loc
    self.result_scale = scale
    self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0)
    if loc is not None and scale is not None:
      self.result_distribution = normal_lib.Normal(loc=self.result_loc,
                                                   scale=self.result_scale)
    self.called_log_prob = Counter()
    self.called_sample = Counter()
    self.called_loc = Counter()
    self.called_scale = Counter()

  def log_prob(self, *args, **kwargs):
    self.called_log_prob()
    return self.result_log_prob

  def sample(self, *args, **kwargs):
    self.called_sample()
    return self.result_sample

  @property
  def distribution(self):  # so layers see an Independent(Normal)-like object
    return self.result_distribution

  @property
  def loc(self):
    self.called_loc()
    return self.result_loc

  @property
  def scale(self):
    self.called_scale()
    return self.result_scale


class MockKLDivergence(object):
  """Monitors layer calls to the divergence implementation."""

  def __init__(self, result):
    self.result = result
    self.args = []
    self.called = Counter()

  def __call__(self, *args, **kwargs):
    self.called()
    self.args.append(args)
    return self.result

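# Together these mocks make a layer fully deterministic: `MockDistribution`
# returns pre-built tensors in place of real samples and log-probs, and
# `MockKLDivergence` records each argument tuple it receives, so a test can
# assert both on the layer's outputs and on exactly how the layer invoked
# the divergence. A minimal wiring sketch (shapes illustrative; see
# `_testConvSetUp` below for the full wiring):
#
#   posterior = MockDistribution(
#       result_log_prob=random_ops.random_uniform([2, 2, 3, 5]),
#       result_sample=random_ops.random_uniform([2, 2, 3, 5]))
#   divergence = MockKLDivergence(result=random_ops.random_uniform([]))
#   layer = prob_layers_lib.Conv2DReparameterization(
#       filters=5, kernel_size=(2, 2),
#       kernel_posterior_fn=lambda *args: posterior,
#       kernel_divergence_fn=divergence)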

class ConvVariational(test.TestCase):

  def _testKLPenaltyKernel(self, layer_class):
    with self.test_session():
      layer = layer_class(filters=2, kernel_size=3)
      if layer_class in (prob_layers_lib.Conv1DReparameterization,
                         prob_layers_lib.Conv1DFlipout):
        inputs = random_ops.random_uniform([2, 3, 1], seed=1)
      elif layer_class in (prob_layers_lib.Conv2DReparameterization,
                           prob_layers_lib.Conv2DFlipout):
        inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
      elif layer_class in (prob_layers_lib.Conv3DReparameterization,
                           prob_layers_lib.Conv3DFlipout):
        inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)

      # Before the layer is called, no KL penalty has been registered.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 0)
      self.assertListEqual(layer.losses, losses)

      _ = layer(inputs)

      # After the call, the kernel's KL penalty is in the collection.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 1)
      self.assertListEqual(layer.losses, losses)

  def _testKLPenaltyBoth(self, layer_class):
    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
      return normal_lib.Normal(
          loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.))
    with self.test_session():
      layer = layer_class(
          filters=2,
          kernel_size=3,
          bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
          bias_prior_fn=_make_normal)
      if layer_class in (prob_layers_lib.Conv1DReparameterization,
                         prob_layers_lib.Conv1DFlipout):
        inputs = random_ops.random_uniform([2, 3, 1], seed=1)
      elif layer_class in (prob_layers_lib.Conv2DReparameterization,
                           prob_layers_lib.Conv2DFlipout):
        inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
      elif layer_class in (prob_layers_lib.Conv3DReparameterization,
                           prob_layers_lib.Conv3DFlipout):
        inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)

      # Before the layer is called, no KL penalty has been registered.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 0)
      self.assertListEqual(layer.losses, losses)

      _ = layer(inputs)

      # After the call, both the kernel and bias KL penalties are registered.
      losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
      self.assertEqual(len(losses), 2)
      self.assertListEqual(layer.losses, losses)

  def _testConvSetUp(self, layer_class, batch_size, depth=None,
                     height=None, width=None, channels=None, filters=None,
                     **kwargs):
    """Builds a `layer_class` layer wired to mock posteriors and priors."""
    seed = Counter()
    if layer_class in (prob_layers_lib.Conv1DReparameterization,
                       prob_layers_lib.Conv1DFlipout):
      inputs = random_ops.random_uniform(
          [batch_size, width, channels], seed=seed())
      kernel_size = (2,)
    elif layer_class in (prob_layers_lib.Conv2DReparameterization,
                         prob_layers_lib.Conv2DFlipout):
      inputs = random_ops.random_uniform(
          [batch_size, height, width, channels], seed=seed())
      kernel_size = (2, 2)
    elif layer_class in (prob_layers_lib.Conv3DReparameterization,
                         prob_layers_lib.Conv3DFlipout):
      inputs = random_ops.random_uniform(
          [batch_size, depth, height, width, channels], seed=seed())
      kernel_size = (2, 2, 2)

    kernel_shape = kernel_size + (channels, filters)
    kernel_posterior = MockDistribution(
        loc=random_ops.random_uniform(kernel_shape, seed=seed()),
        scale=random_ops.random_uniform(kernel_shape, seed=seed()),
        result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
        result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
    kernel_prior = MockDistribution(
        result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()),
        result_sample=random_ops.random_uniform(kernel_shape, seed=seed()))
    kernel_divergence = MockKLDivergence(
        result=random_ops.random_uniform(kernel_shape, seed=seed()))

    bias_size = (filters,)
    bias_posterior = MockDistribution(
        result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
        result_sample=random_ops.random_uniform(bias_size, seed=seed()))
    bias_prior = MockDistribution(
        result_log_prob=random_ops.random_uniform(bias_size, seed=seed()),
        result_sample=random_ops.random_uniform(bias_size, seed=seed()))
    bias_divergence = MockKLDivergence(
        result=random_ops.random_uniform(bias_size, seed=seed()))

    layer = layer_class(
        filters=filters,
        kernel_size=kernel_size,
        padding="SAME",
        kernel_posterior_fn=lambda *args: kernel_posterior,
        kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
        kernel_prior_fn=lambda *args: kernel_prior,
        kernel_divergence_fn=kernel_divergence,
        bias_posterior_fn=lambda *args: bias_posterior,
        bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
        bias_prior_fn=lambda *args: bias_prior,
        bias_divergence_fn=bias_divergence,
        **kwargs)

    outputs = layer(inputs)

    kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
    return (kernel_posterior, kernel_prior, kernel_divergence,
            bias_posterior, bias_prior, bias_divergence,
            layer, inputs, outputs, kl_penalty, kernel_shape)

  def _testConvReparameterization(self, layer_class):
    batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
    with self.test_session() as sess:
      (kernel_posterior, kernel_prior, kernel_divergence,
       bias_posterior, bias_prior, bias_divergence, layer, inputs,
       outputs, kl_penalty, kernel_shape) = self._testConvSetUp(
           layer_class, batch_size,
           depth=depth, height=height, width=width, channels=channels,
           filters=filters)

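      # Under the reparameterization estimator the layer should use the
      # kernel and bias samples directly:
      #   outputs = conv(x, kernel_sample) + bias_sample
      # Rebuild that computation by hand and compare against the layer.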
      convolution_op = nn_ops.Convolution(
          tensor_shape.TensorShape(inputs.shape),
          filter_shape=tensor_shape.TensorShape(kernel_shape),
          padding="SAME")
      expected_outputs = convolution_op(inputs, kernel_posterior.result_sample)
      expected_outputs = nn.bias_add(expected_outputs,
                                     bias_posterior.result_sample,
                                     data_format="NHWC")

      [
          expected_outputs_, actual_outputs_,
          expected_kernel_, actual_kernel_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_posterior.result_sample, layer.kernel_posterior_tensor,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias_posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_kernel_, actual_kernel_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

      self.assertAllEqual(
          [[kernel_posterior.distribution,
            kernel_prior.distribution,
            kernel_posterior.result_sample]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior.distribution,
            bias_prior.distribution,
            bias_posterior.result_sample]],
          bias_divergence.args)

  def _testConvFlipout(self, layer_class):
    batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
    with self.test_session() as sess:
      (kernel_posterior, kernel_prior, kernel_divergence,
       bias_posterior, bias_prior, bias_divergence, layer, inputs,
       outputs, kl_penalty, kernel_shape) = self._testConvSetUp(
           layer_class, batch_size,
           depth=depth, height=height, width=width, channels=channels,
           filters=filters, seed=44)

      convolution_op = nn_ops.Convolution(
          tensor_shape.TensorShape(inputs.shape),
          filter_shape=tensor_shape.TensorShape(kernel_shape),
          padding="SAME")

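      # Flipout (Wen et al., 2018) splits the sampled kernel into its mean
      # and a zero-mean perturbation, then decorrelates the perturbation
      # across examples with per-example random sign flips:
      #
      #   outputs = conv(x, loc)
      #             + sign_output * conv(x * sign_input, perturbation)
      #
      # The block below reconstructs that estimate by hand so it can be
      # compared against the layer's actual output.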
      expected_kernel_posterior_affine = normal_lib.Normal(
          loc=array_ops.zeros_like(kernel_posterior.result_loc),
          scale=kernel_posterior.result_scale)
      expected_kernel_posterior_affine_tensor = (
          expected_kernel_posterior_affine.sample(seed=42))

      expected_outputs = convolution_op(
          inputs, kernel_posterior.distribution.loc)

      input_shape = array_ops.shape(inputs)
      output_shape = array_ops.shape(expected_outputs)
      batch_shape = array_ops.expand_dims(input_shape[0], 0)
      channels = input_shape[-1]
      rank = len(inputs.get_shape()) - 2

      sign_input = random_ops.random_uniform(
          array_ops.concat([batch_shape,
                            array_ops.expand_dims(channels, 0)], 0),
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=layer.seed)
      sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype)
      sign_output = random_ops.random_uniform(
          array_ops.concat([batch_shape,
                            array_ops.expand_dims(filters, 0)], 0),
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=distribution_util.gen_new_seed(
              layer.seed, salt="conv_flipout"))
      sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype)
      for _ in range(rank):
        sign_input = array_ops.expand_dims(sign_input, 1)  # 2D ex: (B, 1, 1, C)
        sign_output = array_ops.expand_dims(sign_output, 1)

      sign_input = array_ops.tile(  # tile for element-wise op broadcasting
          sign_input,
          [1] + [input_shape[i + 1] for i in range(rank)] + [1])
      sign_output = array_ops.tile(
          sign_output,
          [1] + [output_shape[i + 1] for i in range(rank)] + [1])

      perturbed_inputs = convolution_op(
          inputs * sign_input, expected_kernel_posterior_affine_tensor)
      perturbed_inputs *= sign_output

      expected_outputs += perturbed_inputs
      expected_outputs = nn.bias_add(expected_outputs,
                                     bias_posterior.result_sample,
                                     data_format="NHWC")

      [
          expected_outputs_, actual_outputs_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias_posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

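      # Flipout never materializes a full kernel sample, so the layer passes
      # `None` as the posterior-sample argument to the kernel divergence.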
      self.assertAllEqual(
          [[kernel_posterior.distribution, kernel_prior.distribution, None]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior.distribution,
            bias_prior.distribution,
            bias_posterior.result_sample]],
          bias_divergence.args)

  def _testRandomConvFlipout(self, layer_class):
    batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
    with self.test_session() as sess:
      seed = Counter()
      if layer_class in (prob_layers_lib.Conv1DReparameterization,
                         prob_layers_lib.Conv1DFlipout):
        inputs = random_ops.random_uniform(
            [batch_size, width, channels], seed=seed())
        kernel_size = (2,)
      elif layer_class in (prob_layers_lib.Conv2DReparameterization,
                           prob_layers_lib.Conv2DFlipout):
        inputs = random_ops.random_uniform(
            [batch_size, height, width, channels], seed=seed())
        kernel_size = (2, 2)
      elif layer_class in (prob_layers_lib.Conv3DReparameterization,
                           prob_layers_lib.Conv3DFlipout):
        inputs = random_ops.random_uniform(
            [batch_size, depth, height, width, channels], seed=seed())
        kernel_size = (2, 2, 2)

      kernel_shape = kernel_size + (channels, filters)
      bias_size = (filters,)

      kernel_posterior = MockDistribution(
          loc=random_ops.random_uniform(
              kernel_shape, seed=seed()),
          scale=random_ops.random_uniform(
              kernel_shape, seed=seed()),
          result_log_prob=random_ops.random_uniform(
              kernel_shape, seed=seed()),
          result_sample=random_ops.random_uniform(
              kernel_shape, seed=seed()))
      bias_posterior = MockDistribution(
          loc=random_ops.random_uniform(
              bias_size, seed=seed()),
          scale=random_ops.random_uniform(
              bias_size, seed=seed()),
          result_log_prob=random_ops.random_uniform(
              bias_size, seed=seed()),
          result_sample=random_ops.random_uniform(
              bias_size, seed=seed()))
      layer_one = layer_class(
          filters=filters,
          kernel_size=kernel_size,
          padding="SAME",
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          seed=44)
      layer_two = layer_class(
          filters=filters,
          kernel_size=kernel_size,
          padding="SAME",
          kernel_posterior_fn=lambda *args: kernel_posterior,
          kernel_posterior_tensor_fn=lambda d: d.sample(seed=42),
          bias_posterior_fn=lambda *args: bias_posterior,
          bias_posterior_tensor_fn=lambda d: d.sample(seed=43),
          seed=45)

      outputs_one = layer_one(inputs)
      outputs_two = layer_two(inputs)

      outputs_one_, outputs_two_ = sess.run([
          outputs_one, outputs_two])

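      # The two layers share posteriors but use different Flipout seeds
      # (44 vs. 45), so their sign masks, and hence their outputs, should
      # differ in at least one element.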
      self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)),
                      np.prod(outputs_one_.shape))

  def testKLPenaltyKernelConv1DReparameterization(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization)

  def testKLPenaltyKernelConv2DReparameterization(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization)

  def testKLPenaltyKernelConv3DReparameterization(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization)

  def testKLPenaltyKernelConv1DFlipout(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout)

  def testKLPenaltyKernelConv2DFlipout(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout)

  def testKLPenaltyKernelConv3DFlipout(self):
    self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout)

  def testKLPenaltyBothConv1DReparameterization(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization)

  def testKLPenaltyBothConv2DReparameterization(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization)

  def testKLPenaltyBothConv3DReparameterization(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization)

  def testKLPenaltyBothConv1DFlipout(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout)

  def testKLPenaltyBothConv2DFlipout(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout)

  def testKLPenaltyBothConv3DFlipout(self):
    self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout)

  def testConv1DReparameterization(self):
    self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization)

  def testConv2DReparameterization(self):
    self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization)

  def testConv3DReparameterization(self):
    self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization)

  def testConv1DFlipout(self):
    self._testConvFlipout(prob_layers_lib.Conv1DFlipout)

  def testConv2DFlipout(self):
    self._testConvFlipout(prob_layers_lib.Conv2DFlipout)

  def testConv3DFlipout(self):
    self._testConvFlipout(prob_layers_lib.Conv3DFlipout)

  def testRandomConv1DFlipout(self):
    self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout)


if __name__ == "__main__":
  test.main()