Home | History | Annotate | Download | only in distributions
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for distributions KL mechanism."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import test_util
     22 from tensorflow.python.ops import array_ops
     23 from tensorflow.python.ops.distributions import kullback_leibler
     24 from tensorflow.python.ops.distributions import normal
     25 from tensorflow.python.platform import test
     26 
     27 # pylint: disable=protected-access
     28 _DIVERGENCES = kullback_leibler._DIVERGENCES
     29 _registered_kl = kullback_leibler._registered_kl
     30 
     31 # pylint: enable=protected-access
     32 
     33 
     34 class KLTest(test.TestCase):
     35 
     36   def testRegistration(self):
     37 
     38     class MyDist(normal.Normal):
     39       pass
     40 
     41     # Register KL to a lambda that spits out the name parameter
     42     @kullback_leibler.RegisterKL(MyDist, MyDist)
     43     def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
     44       return name
     45 
     46     a = MyDist(loc=0.0, scale=1.0)
     47     self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))
     48 
     49   @test_util.run_deprecated_v1
     50   def testDomainErrorExceptions(self):
     51 
     52     class MyDistException(normal.Normal):
     53       pass
     54 
     55     # Register KL to a lambda that spits out the name parameter
     56     @kullback_leibler.RegisterKL(MyDistException, MyDistException)
     57     # pylint: disable=unused-argument,unused-variable
     58     def _kl(a, b, name=None):
     59       return array_ops.identity([float("nan")])
     60 
     61     # pylint: disable=unused-argument,unused-variable
     62 
     63     with self.cached_session():
     64       a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
     65       kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
     66       with self.assertRaisesOpError(
     67           "KL calculation between .* and .* returned NaN values"):
     68         self.evaluate(kl)
     69       with self.assertRaisesOpError(
     70           "KL calculation between .* and .* returned NaN values"):
     71         a.kl_divergence(a).eval()
     72       a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
     73       kl_ok = kullback_leibler.kl_divergence(a, a)
     74       self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
     75       self_kl_ok = a.kl_divergence(a)
     76       self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
     77       cross_ok = a.cross_entropy(a)
     78       self.assertAllEqual([float("nan")], self.evaluate(cross_ok))
     79 
     80   def testRegistrationFailures(self):
     81 
     82     class MyDist(normal.Normal):
     83       pass
     84 
     85     with self.assertRaisesRegexp(TypeError, "must be callable"):
     86       kullback_leibler.RegisterKL(MyDist, MyDist)("blah")
     87 
     88     # First registration is OK
     89     kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)
     90 
     91     # Second registration fails
     92     with self.assertRaisesRegexp(ValueError, "has already been registered"):
     93       kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)
     94 
     95   def testExactRegistrationsAllMatch(self):
     96     for (k, v) in _DIVERGENCES.items():
     97       self.assertEqual(v, _registered_kl(*k))
     98 
     99   def _testIndirectRegistration(self, fn):
    100 
    101     class Sub1(normal.Normal):
    102 
    103       def entropy(self):
    104         return ""
    105 
    106     class Sub2(normal.Normal):
    107 
    108       def entropy(self):
    109         return ""
    110 
    111     class Sub11(Sub1):
    112 
    113       def entropy(self):
    114         return ""
    115 
    116     # pylint: disable=unused-argument,unused-variable
    117     @kullback_leibler.RegisterKL(Sub1, Sub1)
    118     def _kl11(a, b, name=None):
    119       return "sub1-1"
    120 
    121     @kullback_leibler.RegisterKL(Sub1, Sub2)
    122     def _kl12(a, b, name=None):
    123       return "sub1-2"
    124 
    125     @kullback_leibler.RegisterKL(Sub2, Sub1)
    126     def _kl21(a, b, name=None):
    127       return "sub2-1"
    128 
    129     # pylint: enable=unused-argument,unused_variable
    130 
    131     sub1 = Sub1(loc=0.0, scale=1.0)
    132     sub2 = Sub2(loc=0.0, scale=1.0)
    133     sub11 = Sub11(loc=0.0, scale=1.0)
    134 
    135     self.assertEqual("sub1-1", fn(sub1, sub1))
    136     self.assertEqual("sub1-2", fn(sub1, sub2))
    137     self.assertEqual("sub2-1", fn(sub2, sub1))
    138     self.assertEqual("sub1-1", fn(sub11, sub11))
    139     self.assertEqual("sub1-1", fn(sub11, sub1))
    140     self.assertEqual("sub1-2", fn(sub11, sub2))
    141     self.assertEqual("sub1-1", fn(sub11, sub1))
    142     self.assertEqual("sub1-2", fn(sub11, sub2))
    143     self.assertEqual("sub2-1", fn(sub2, sub11))
    144     self.assertEqual("sub1-1", fn(sub1, sub11))
    145 
    146   def testIndirectRegistrationKLFun(self):
    147     self._testIndirectRegistration(kullback_leibler.kl_divergence)
    148 
    149   def testIndirectRegistrationKLSelf(self):
    150     self._testIndirectRegistration(
    151         lambda p, q: p.kl_divergence(q))
    152 
    153   def testIndirectRegistrationCrossEntropy(self):
    154     self._testIndirectRegistration(
    155         lambda p, q: p.cross_entropy(q))
    156 
    157   def testFunctionCrossEntropy(self):
    158     self._testIndirectRegistration(kullback_leibler.cross_entropy)
    159 
    160 
    161 if __name__ == "__main__":
    162   test.main()
    163