# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.platform import test


     31 class BaseBijectorTest(test.TestCase):
     32   """Tests properties of the Bijector base-class."""
     33 
     34   def testIsAbstract(self):
     35     with self.test_session():
     36       with self.assertRaisesRegexp(TypeError,
     37                                    ("Can't instantiate abstract class Bijector "
     38                                     "with abstract methods __init__")):
     39         bijector.Bijector()  # pylint: disable=abstract-class-instantiated
     40 
     41   def testDefaults(self):
     42     class _BareBonesBijector(bijector.Bijector):
     43       """Minimal specification of a `Bijector`."""
     44 
     45       def __init__(self):
     46         super(_BareBonesBijector, self).__init__()
     47 
     48     with self.test_session() as sess:
     49       bij = _BareBonesBijector()
     50       self.assertEqual(None, bij.event_ndims)
     51       self.assertEqual([], bij.graph_parents)
     52       self.assertEqual(False, bij.is_constant_jacobian)
     53       self.assertEqual(False, bij.validate_args)
     54       self.assertEqual(None, bij.dtype)
     55       self.assertEqual("bare_bones_bijector", bij.name)
     56 
     57       for shape in [[], [1, 2], [1, 2, 3]]:
     58         [
     59             forward_event_shape_,
     60             inverse_event_shape_,
     61         ] = sess.run([
     62             bij.inverse_event_shape_tensor(shape),
     63             bij.forward_event_shape_tensor(shape),
     64         ])
     65         self.assertAllEqual(shape, forward_event_shape_)
     66         self.assertAllEqual(shape, bij.forward_event_shape(shape))
     67         self.assertAllEqual(shape, inverse_event_shape_)
     68         self.assertAllEqual(shape, bij.inverse_event_shape(shape))
     69 
     70       for fn in ["forward",
     71                  "inverse",
     72                  "inverse_log_det_jacobian",
     73                  "forward_log_det_jacobian"]:
     74         with self.assertRaisesRegexp(
     75             NotImplementedError, fn + " not implemented"):
     76           getattr(bij, fn)(0)
     77 
     78 
     79 class IntentionallyMissingError(Exception):
     80   pass
     81 
     82 
     83 class BrokenBijector(bijector.Bijector):
     84   """Forward and inverse are not inverses of each other."""
     85 
     86   def __init__(self, forward_missing=False, inverse_missing=False):
     87     super(BrokenBijector, self).__init__(
     88         event_ndims=0, validate_args=False, name="broken")
     89     self._forward_missing = forward_missing
     90     self._inverse_missing = inverse_missing
     91 
     92   def _forward(self, x):
     93     if self._forward_missing:
     94       raise IntentionallyMissingError
     95     return 2 * x
     96 
     97   def _inverse(self, y):
     98     if self._inverse_missing:
     99       raise IntentionallyMissingError
    100     return y / 2.
    101 
    102   def _inverse_log_det_jacobian(self, y):  # pylint:disable=unused-argument
    103     if self._inverse_missing:
    104       raise IntentionallyMissingError
    105     return -math_ops.log(2.)
    106 
    107   def _forward_log_det_jacobian(self, x):  # pylint:disable=unused-argument
    108     if self._forward_missing:
    109       raise IntentionallyMissingError
    110     return math_ops.log(2.)
    111 
    112 
    113 @six.add_metaclass(abc.ABCMeta)
    114 class BijectorCachingTestBase(object):
    115 
    116   @abc.abstractproperty
    117   def broken_bijector_cls(self):
    118     # return a BrokenBijector type Bijector, since this will test the caching.
    119     raise IntentionallyMissingError("Not implemented")
    120 
    121   def testCachingOfForwardResults(self):
    122     broken_bijector = self.broken_bijector_cls(inverse_missing=True)
    123     with self.test_session():
    124       x = constant_op.constant(1.1)
    125 
    126       # Call forward and forward_log_det_jacobian one-by-one (not together).
    127       y = broken_bijector.forward(x)
    128       _ = broken_bijector.forward_log_det_jacobian(x)
    129 
    130       # Now, everything should be cached if the argument is y.
    131       try:
    132         broken_bijector.inverse(y)
    133         broken_bijector.inverse_log_det_jacobian(y)
    134       except IntentionallyMissingError:
    135         raise AssertionError("Tests failed! Cached values not used.")
    136 
    137   def testCachingOfInverseResults(self):
    138     broken_bijector = self.broken_bijector_cls(forward_missing=True)
    139     with self.test_session():
    140       y = constant_op.constant(1.1)
    141 
    142       # Call inverse and inverse_log_det_jacobian one-by-one (not together).
    143       x = broken_bijector.inverse(y)
    144       _ = broken_bijector.inverse_log_det_jacobian(y)
    145 
    146       # Now, everything should be cached if the argument is x.
    147       try:
    148         broken_bijector.forward(x)
    149         broken_bijector.forward_log_det_jacobian(x)
    150       except IntentionallyMissingError:
    151         raise AssertionError("Tests failed! Cached values not used.")
    152 
    153 
    154 class BijectorCachingTest(BijectorCachingTestBase, test.TestCase):
    155   """Test caching with BrokenBijector."""
    156 
    157   @property
    158   def broken_bijector_cls(self):
    159     return BrokenBijector
    160 
    161 
    162 if __name__ == "__main__":
    163   test.main()
    164