1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Bijector.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import abc 22 23 import six 24 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.ops import math_ops 27 from tensorflow.python.ops.distributions import bijector 28 from tensorflow.python.platform import test 29 30 31 class BaseBijectorTest(test.TestCase): 32 """Tests properties of the Bijector base-class.""" 33 34 def testIsAbstract(self): 35 with self.test_session(): 36 with self.assertRaisesRegexp(TypeError, 37 ("Can't instantiate abstract class Bijector " 38 "with abstract methods __init__")): 39 bijector.Bijector() # pylint: disable=abstract-class-instantiated 40 41 def testDefaults(self): 42 class _BareBonesBijector(bijector.Bijector): 43 """Minimal specification of a `Bijector`.""" 44 45 def __init__(self): 46 super(_BareBonesBijector, self).__init__() 47 48 with self.test_session() as sess: 49 bij = _BareBonesBijector() 50 self.assertEqual(None, bij.event_ndims) 51 self.assertEqual([], bij.graph_parents) 52 self.assertEqual(False, bij.is_constant_jacobian) 53 self.assertEqual(False, bij.validate_args) 54 self.assertEqual(None, bij.dtype) 55 self.assertEqual("bare_bones_bijector", bij.name) 56 57 for shape in [[], [1, 2], [1, 2, 3]]: 58 [ 59 forward_event_shape_, 60 inverse_event_shape_, 61 ] = sess.run([ 62 bij.inverse_event_shape_tensor(shape), 63 bij.forward_event_shape_tensor(shape), 64 ]) 65 self.assertAllEqual(shape, forward_event_shape_) 66 self.assertAllEqual(shape, bij.forward_event_shape(shape)) 67 self.assertAllEqual(shape, inverse_event_shape_) 68 self.assertAllEqual(shape, bij.inverse_event_shape(shape)) 69 70 for fn in ["forward", 71 "inverse", 72 "inverse_log_det_jacobian", 73 "forward_log_det_jacobian"]: 74 with self.assertRaisesRegexp( 75 NotImplementedError, fn + " not implemented"): 76 getattr(bij, fn)(0) 77 78 79 class IntentionallyMissingError(Exception): 80 pass 81 82 83 class BrokenBijector(bijector.Bijector): 84 """Forward and inverse are not inverses of each other.""" 85 86 def __init__(self, forward_missing=False, inverse_missing=False): 87 super(BrokenBijector, self).__init__( 88 event_ndims=0, validate_args=False, name="broken") 89 self._forward_missing = forward_missing 90 self._inverse_missing = inverse_missing 91 92 def _forward(self, x): 93 if self._forward_missing: 94 raise IntentionallyMissingError 95 return 2 * x 96 97 def _inverse(self, y): 98 if self._inverse_missing: 99 raise IntentionallyMissingError 100 return y / 2. 101 102 def _inverse_log_det_jacobian(self, y): # pylint:disable=unused-argument 103 if self._inverse_missing: 104 raise IntentionallyMissingError 105 return -math_ops.log(2.) 106 107 def _forward_log_det_jacobian(self, x): # pylint:disable=unused-argument 108 if self._forward_missing: 109 raise IntentionallyMissingError 110 return math_ops.log(2.) 111 112 113 @six.add_metaclass(abc.ABCMeta) 114 class BijectorCachingTestBase(object): 115 116 @abc.abstractproperty 117 def broken_bijector_cls(self): 118 # return a BrokenBijector type Bijector, since this will test the caching. 119 raise IntentionallyMissingError("Not implemented") 120 121 def testCachingOfForwardResults(self): 122 broken_bijector = self.broken_bijector_cls(inverse_missing=True) 123 with self.test_session(): 124 x = constant_op.constant(1.1) 125 126 # Call forward and forward_log_det_jacobian one-by-one (not together). 127 y = broken_bijector.forward(x) 128 _ = broken_bijector.forward_log_det_jacobian(x) 129 130 # Now, everything should be cached if the argument is y. 131 try: 132 broken_bijector.inverse(y) 133 broken_bijector.inverse_log_det_jacobian(y) 134 except IntentionallyMissingError: 135 raise AssertionError("Tests failed! Cached values not used.") 136 137 def testCachingOfInverseResults(self): 138 broken_bijector = self.broken_bijector_cls(forward_missing=True) 139 with self.test_session(): 140 y = constant_op.constant(1.1) 141 142 # Call inverse and inverse_log_det_jacobian one-by-one (not together). 143 x = broken_bijector.inverse(y) 144 _ = broken_bijector.inverse_log_det_jacobian(y) 145 146 # Now, everything should be cached if the argument is x. 147 try: 148 broken_bijector.forward(x) 149 broken_bijector.forward_log_det_jacobian(x) 150 except IntentionallyMissingError: 151 raise AssertionError("Tests failed! Cached values not used.") 152 153 154 class BijectorCachingTest(BijectorCachingTestBase, test.TestCase): 155 """Test caching with BrokenBijector.""" 156 157 @property 158 def broken_bijector_cls(self): 159 return BrokenBijector 160 161 162 if __name__ == "__main__": 163 test.main() 164