1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for convolutional Bayesian layers.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib 24 from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util 25 from tensorflow.contrib.distributions.python.ops import independent as independent_lib 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import ops 28 from tensorflow.python.framework import tensor_shape 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.ops import nn 32 from tensorflow.python.ops import nn_ops 33 from tensorflow.python.ops import random_ops 34 from tensorflow.python.ops.distributions import normal as normal_lib 35 from tensorflow.python.ops.distributions import util as distribution_util 36 from tensorflow.python.platform import test 37 38 39 class Counter(object): 40 """Helper class to manage incrementing a counting `int`.""" 41 42 def __init__(self): 43 self._value = -1 44 45 @property 46 def value(self): 47 return self._value 48 49 def __call__(self): 50 self._value += 1 51 return self._value 52 53 54 class MockDistribution(independent_lib.Independent): 55 """Monitors layer calls to the underlying distribution.""" 56 57 def __init__(self, result_sample, result_log_prob, loc=None, scale=None): 58 self.result_sample = result_sample 59 self.result_log_prob = result_log_prob 60 self.result_loc = loc 61 self.result_scale = scale 62 self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) 63 if loc is not None and scale is not None: 64 self.result_distribution = normal_lib.Normal(loc=self.result_loc, 65 scale=self.result_scale) 66 self.called_log_prob = Counter() 67 self.called_sample = Counter() 68 self.called_loc = Counter() 69 self.called_scale = Counter() 70 71 def log_prob(self, *args, **kwargs): 72 self.called_log_prob() 73 return self.result_log_prob 74 75 def sample(self, *args, **kwargs): 76 self.called_sample() 77 return self.result_sample 78 79 @property 80 def distribution(self): # for dummy check on Independent(Normal) 81 return self.result_distribution 82 83 @property 84 def loc(self): 85 self.called_loc() 86 return self.result_loc 87 88 @property 89 def scale(self): 90 self.called_scale() 91 return self.result_scale 92 93 94 class MockKLDivergence(object): 95 """Monitors layer calls to the divergence implementation.""" 96 97 def __init__(self, result): 98 self.result = result 99 self.args = [] 100 self.called = Counter() 101 102 def __call__(self, *args, **kwargs): 103 self.called() 104 self.args.append(args) 105 return self.result 106 107 108 class ConvVariational(test.TestCase): 109 110 def _testKLPenaltyKernel(self, layer_class): 111 with self.test_session(): 112 layer = layer_class(filters=2, kernel_size=3) 113 if layer_class in (prob_layers_lib.Conv1DReparameterization, 114 prob_layers_lib.Conv1DFlipout): 115 inputs = random_ops.random_uniform([2, 3, 1], seed=1) 116 elif layer_class in (prob_layers_lib.Conv2DReparameterization, 117 prob_layers_lib.Conv2DFlipout): 118 inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) 119 elif layer_class in (prob_layers_lib.Conv3DReparameterization, 120 prob_layers_lib.Conv3DFlipout): 121 inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) 122 123 # No keys. 124 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 125 self.assertEqual(len(losses), 0) 126 self.assertListEqual(layer.losses, losses) 127 128 _ = layer(inputs) 129 130 # Yes keys. 131 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 132 self.assertEqual(len(losses), 1) 133 self.assertListEqual(layer.losses, losses) 134 135 def _testKLPenaltyBoth(self, layer_class): 136 def _make_normal(dtype, *args): # pylint: disable=unused-argument 137 return normal_lib.Normal( 138 loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) 139 with self.test_session(): 140 layer = layer_class( 141 filters=2, 142 kernel_size=3, 143 bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), 144 bias_prior_fn=_make_normal) 145 if layer_class in (prob_layers_lib.Conv1DReparameterization, 146 prob_layers_lib.Conv1DFlipout): 147 inputs = random_ops.random_uniform([2, 3, 1], seed=1) 148 elif layer_class in (prob_layers_lib.Conv2DReparameterization, 149 prob_layers_lib.Conv2DFlipout): 150 inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) 151 elif layer_class in (prob_layers_lib.Conv3DReparameterization, 152 prob_layers_lib.Conv3DFlipout): 153 inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) 154 155 # No keys. 156 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 157 self.assertEqual(len(losses), 0) 158 self.assertListEqual(layer.losses, losses) 159 160 _ = layer(inputs) 161 162 # Yes keys. 163 losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 164 self.assertEqual(len(losses), 2) 165 self.assertListEqual(layer.losses, losses) 166 167 def _testConvSetUp(self, layer_class, batch_size, depth=None, 168 height=None, width=None, channels=None, filters=None, 169 **kwargs): 170 seed = Counter() 171 if layer_class in (prob_layers_lib.Conv1DReparameterization, 172 prob_layers_lib.Conv1DFlipout): 173 inputs = random_ops.random_uniform( 174 [batch_size, width, channels], seed=seed()) 175 kernel_size = (2,) 176 elif layer_class in (prob_layers_lib.Conv2DReparameterization, 177 prob_layers_lib.Conv2DFlipout): 178 inputs = random_ops.random_uniform( 179 [batch_size, height, width, channels], seed=seed()) 180 kernel_size = (2, 2) 181 elif layer_class in (prob_layers_lib.Conv3DReparameterization, 182 prob_layers_lib.Conv3DFlipout): 183 inputs = random_ops.random_uniform( 184 [batch_size, depth, height, width, channels], seed=seed()) 185 kernel_size = (2, 2, 2) 186 187 kernel_shape = kernel_size + (channels, filters) 188 kernel_posterior = MockDistribution( 189 loc=random_ops.random_uniform(kernel_shape, seed=seed()), 190 scale=random_ops.random_uniform(kernel_shape, seed=seed()), 191 result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), 192 result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) 193 kernel_prior = MockDistribution( 194 result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), 195 result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) 196 kernel_divergence = MockKLDivergence( 197 result=random_ops.random_uniform(kernel_shape, seed=seed())) 198 199 bias_size = (filters,) 200 bias_posterior = MockDistribution( 201 result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), 202 result_sample=random_ops.random_uniform(bias_size, seed=seed())) 203 bias_prior = MockDistribution( 204 result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), 205 result_sample=random_ops.random_uniform(bias_size, seed=seed())) 206 bias_divergence = MockKLDivergence( 207 result=random_ops.random_uniform(bias_size, seed=seed())) 208 209 layer = layer_class( 210 filters=filters, 211 kernel_size=kernel_size, 212 padding="SAME", 213 kernel_posterior_fn=lambda *args: kernel_posterior, 214 kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), 215 kernel_prior_fn=lambda *args: kernel_prior, 216 kernel_divergence_fn=kernel_divergence, 217 bias_posterior_fn=lambda *args: bias_posterior, 218 bias_posterior_tensor_fn=lambda d: d.sample(seed=43), 219 bias_prior_fn=lambda *args: bias_prior, 220 bias_divergence_fn=bias_divergence, 221 **kwargs) 222 223 outputs = layer(inputs) 224 225 kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) 226 return (kernel_posterior, kernel_prior, kernel_divergence, 227 bias_posterior, bias_prior, bias_divergence, 228 layer, inputs, outputs, kl_penalty, kernel_shape) 229 230 def _testConvReparameterization(self, layer_class): 231 batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 232 with self.test_session() as sess: 233 (kernel_posterior, kernel_prior, kernel_divergence, 234 bias_posterior, bias_prior, bias_divergence, layer, inputs, 235 outputs, kl_penalty, kernel_shape) = self._testConvSetUp( 236 layer_class, batch_size, 237 depth=depth, height=height, width=width, channels=channels, 238 filters=filters) 239 240 convolution_op = nn_ops.Convolution( 241 tensor_shape.TensorShape(inputs.shape), 242 filter_shape=tensor_shape.TensorShape(kernel_shape), 243 padding="SAME") 244 expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) 245 expected_outputs = nn.bias_add(expected_outputs, 246 bias_posterior.result_sample, 247 data_format="NHWC") 248 249 [ 250 expected_outputs_, actual_outputs_, 251 expected_kernel_, actual_kernel_, 252 expected_kernel_divergence_, actual_kernel_divergence_, 253 expected_bias_, actual_bias_, 254 expected_bias_divergence_, actual_bias_divergence_, 255 ] = sess.run([ 256 expected_outputs, outputs, 257 kernel_posterior.result_sample, layer.kernel_posterior_tensor, 258 kernel_divergence.result, kl_penalty[0], 259 bias_posterior.result_sample, layer.bias_posterior_tensor, 260 bias_divergence.result, kl_penalty[1], 261 ]) 262 263 self.assertAllClose( 264 expected_kernel_, actual_kernel_, 265 rtol=1e-6, atol=0.) 266 self.assertAllClose( 267 expected_bias_, actual_bias_, 268 rtol=1e-6, atol=0.) 269 self.assertAllClose( 270 expected_outputs_, actual_outputs_, 271 rtol=1e-6, atol=0.) 272 self.assertAllClose( 273 expected_kernel_divergence_, actual_kernel_divergence_, 274 rtol=1e-6, atol=0.) 275 self.assertAllClose( 276 expected_bias_divergence_, actual_bias_divergence_, 277 rtol=1e-6, atol=0.) 278 279 self.assertAllEqual( 280 [[kernel_posterior.distribution, 281 kernel_prior.distribution, 282 kernel_posterior.result_sample]], 283 kernel_divergence.args) 284 285 self.assertAllEqual( 286 [[bias_posterior.distribution, 287 bias_prior.distribution, 288 bias_posterior.result_sample]], 289 bias_divergence.args) 290 291 def _testConvFlipout(self, layer_class): 292 batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 293 with self.test_session() as sess: 294 (kernel_posterior, kernel_prior, kernel_divergence, 295 bias_posterior, bias_prior, bias_divergence, layer, inputs, 296 outputs, kl_penalty, kernel_shape) = self._testConvSetUp( 297 layer_class, batch_size, 298 depth=depth, height=height, width=width, channels=channels, 299 filters=filters, seed=44) 300 301 convolution_op = nn_ops.Convolution( 302 tensor_shape.TensorShape(inputs.shape), 303 filter_shape=tensor_shape.TensorShape(kernel_shape), 304 padding="SAME") 305 306 expected_kernel_posterior_affine = normal_lib.Normal( 307 loc=array_ops.zeros_like(kernel_posterior.result_loc), 308 scale=kernel_posterior.result_scale) 309 expected_kernel_posterior_affine_tensor = ( 310 expected_kernel_posterior_affine.sample(seed=42)) 311 312 expected_outputs = convolution_op( 313 inputs, kernel_posterior.distribution.loc) 314 315 input_shape = array_ops.shape(inputs) 316 output_shape = array_ops.shape(expected_outputs) 317 batch_shape = array_ops.expand_dims(input_shape[0], 0) 318 channels = input_shape[-1] 319 rank = len(inputs.get_shape()) - 2 320 321 sign_input = random_ops.random_uniform( 322 array_ops.concat([batch_shape, 323 array_ops.expand_dims(channels, 0)], 0), 324 minval=0, 325 maxval=2, 326 dtype=dtypes.int32, 327 seed=layer.seed) 328 sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) 329 sign_output = random_ops.random_uniform( 330 array_ops.concat([batch_shape, 331 array_ops.expand_dims(filters, 0)], 0), 332 minval=0, 333 maxval=2, 334 dtype=dtypes.int32, 335 seed=distribution_util.gen_new_seed( 336 layer.seed, salt="conv_flipout")) 337 sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) 338 for _ in range(rank): 339 sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) 340 sign_output = array_ops.expand_dims(sign_output, 1) 341 342 sign_input = array_ops.tile( # tile for element-wise op broadcasting 343 sign_input, 344 [1] + [input_shape[i + 1] for i in range(rank)] + [1]) 345 sign_output = array_ops.tile( 346 sign_output, 347 [1] + [output_shape[i + 1] for i in range(rank)] + [1]) 348 349 perturbed_inputs = convolution_op( 350 inputs * sign_input, expected_kernel_posterior_affine_tensor) 351 perturbed_inputs *= sign_output 352 353 expected_outputs += perturbed_inputs 354 expected_outputs = nn.bias_add(expected_outputs, 355 bias_posterior.result_sample, 356 data_format="NHWC") 357 358 [ 359 expected_outputs_, actual_outputs_, 360 expected_kernel_divergence_, actual_kernel_divergence_, 361 expected_bias_, actual_bias_, 362 expected_bias_divergence_, actual_bias_divergence_, 363 ] = sess.run([ 364 expected_outputs, outputs, 365 kernel_divergence.result, kl_penalty[0], 366 bias_posterior.result_sample, layer.bias_posterior_tensor, 367 bias_divergence.result, kl_penalty[1], 368 ]) 369 370 self.assertAllClose( 371 expected_bias_, actual_bias_, 372 rtol=1e-6, atol=0.) 373 self.assertAllClose( 374 expected_outputs_, actual_outputs_, 375 rtol=1e-6, atol=0.) 376 self.assertAllClose( 377 expected_kernel_divergence_, actual_kernel_divergence_, 378 rtol=1e-6, atol=0.) 379 self.assertAllClose( 380 expected_bias_divergence_, actual_bias_divergence_, 381 rtol=1e-6, atol=0.) 382 383 self.assertAllEqual( 384 [[kernel_posterior.distribution, kernel_prior.distribution, None]], 385 kernel_divergence.args) 386 387 self.assertAllEqual( 388 [[bias_posterior.distribution, 389 bias_prior.distribution, 390 bias_posterior.result_sample]], 391 bias_divergence.args) 392 393 def _testRandomConvFlipout(self, layer_class): 394 batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 395 with self.test_session() as sess: 396 seed = Counter() 397 if layer_class in (prob_layers_lib.Conv1DReparameterization, 398 prob_layers_lib.Conv1DFlipout): 399 inputs = random_ops.random_uniform( 400 [batch_size, width, channels], seed=seed()) 401 kernel_size = (2,) 402 elif layer_class in (prob_layers_lib.Conv2DReparameterization, 403 prob_layers_lib.Conv2DFlipout): 404 inputs = random_ops.random_uniform( 405 [batch_size, height, width, channels], seed=seed()) 406 kernel_size = (2, 2) 407 elif layer_class in (prob_layers_lib.Conv3DReparameterization, 408 prob_layers_lib.Conv3DFlipout): 409 inputs = random_ops.random_uniform( 410 [batch_size, depth, height, width, channels], seed=seed()) 411 kernel_size = (2, 2, 2) 412 413 kernel_shape = kernel_size + (channels, filters) 414 bias_size = (filters,) 415 416 kernel_posterior = MockDistribution( 417 loc=random_ops.random_uniform( 418 kernel_shape, seed=seed()), 419 scale=random_ops.random_uniform( 420 kernel_shape, seed=seed()), 421 result_log_prob=random_ops.random_uniform( 422 kernel_shape, seed=seed()), 423 result_sample=random_ops.random_uniform( 424 kernel_shape, seed=seed())) 425 bias_posterior = MockDistribution( 426 loc=random_ops.random_uniform( 427 bias_size, seed=seed()), 428 scale=random_ops.random_uniform( 429 bias_size, seed=seed()), 430 result_log_prob=random_ops.random_uniform( 431 bias_size, seed=seed()), 432 result_sample=random_ops.random_uniform( 433 bias_size, seed=seed())) 434 layer_one = layer_class( 435 filters=filters, 436 kernel_size=kernel_size, 437 padding="SAME", 438 kernel_posterior_fn=lambda *args: kernel_posterior, 439 kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), 440 bias_posterior_fn=lambda *args: bias_posterior, 441 bias_posterior_tensor_fn=lambda d: d.sample(seed=43), 442 seed=44) 443 layer_two = layer_class( 444 filters=filters, 445 kernel_size=kernel_size, 446 padding="SAME", 447 kernel_posterior_fn=lambda *args: kernel_posterior, 448 kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), 449 bias_posterior_fn=lambda *args: bias_posterior, 450 bias_posterior_tensor_fn=lambda d: d.sample(seed=43), 451 seed=45) 452 453 outputs_one = layer_one(inputs) 454 outputs_two = layer_two(inputs) 455 456 outputs_one_, outputs_two_ = sess.run([ 457 outputs_one, outputs_two]) 458 459 self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), 460 np.prod(outputs_one_.shape)) 461 462 def testKLPenaltyKernelConv1DReparameterization(self): 463 self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) 464 465 def testKLPenaltyKernelConv2DReparameterization(self): 466 self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) 467 468 def testKLPenaltyKernelConv3DReparameterization(self): 469 self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) 470 471 def testKLPenaltyKernelConv1DFlipout(self): 472 self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) 473 474 def testKLPenaltyKernelConv2DFlipout(self): 475 self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) 476 477 def testKLPenaltyKernelConv3DFlipout(self): 478 self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) 479 480 def testKLPenaltyBothConv1DReparameterization(self): 481 self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) 482 483 def testKLPenaltyBothConv2DReparameterization(self): 484 self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) 485 486 def testKLPenaltyBothConv3DReparameterization(self): 487 self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) 488 489 def testKLPenaltyBothConv1DFlipout(self): 490 self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) 491 492 def testKLPenaltyBothConv2DFlipout(self): 493 self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) 494 495 def testKLPenaltyBothConv3DFlipout(self): 496 self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) 497 498 def testConv1DReparameterization(self): 499 self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) 500 501 def testConv2DReparameterization(self): 502 self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) 503 504 def testConv3DReparameterization(self): 505 self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) 506 507 def testConv1DFlipout(self): 508 self._testConvFlipout(prob_layers_lib.Conv1DFlipout) 509 510 def testConv2DFlipout(self): 511 self._testConvFlipout(prob_layers_lib.Conv2DFlipout) 512 513 def testConv3DFlipout(self): 514 self._testConvFlipout(prob_layers_lib.Conv3DFlipout) 515 516 def testRandomConv1DFlipout(self): 517 self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) 518 519 520 if __name__ == "__main__": 521 test.main() 522