# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for distributions KL mechanism."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test

# pylint: disable=protected-access
_DIVERGENCES = kullback_leibler._DIVERGENCES
_registered_kl = kullback_leibler._registered_kl

# pylint: enable=protected-access


class KLTest(test.TestCase):
  """Tests KL registration and dispatch in `kullback_leibler`.

  Covers `RegisterKL`, `kl_divergence`, `cross_entropy`, and the
  `Distribution.kl_divergence` / `Distribution.cross_entropy` methods.
  """

  def testRegistration(self):
    """A registered KL function is dispatched to and receives `name`."""

    class MyDist(normal.Normal):
      pass

    # Register a KL function that simply echoes back its `name` argument, so
    # the assertion below can observe that `kl_divergence` dispatched to it.
    @kullback_leibler.RegisterKL(MyDist, MyDist)
    def _kl(a, b, name=None):  # pylint: disable=unused-argument,unused-variable
      return name

    a = MyDist(loc=0.0, scale=1.0)
    self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))

  @test_util.run_deprecated_v1
  def testDomainErrorExceptions(self):
    """NaN KL results raise an op error only when NaN stats are disallowed."""

    class MyDistException(normal.Normal):
      pass

    # Register a KL function that always returns NaN, to exercise the NaN
    # check performed when `allow_nan_stats=False`.
    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(MyDistException, MyDistException)
    def _kl(a, b, name=None):
      return array_ops.identity([float("nan")])

    # pylint: enable=unused-argument,unused-variable

    with self.cached_session():
      # NaN stats disallowed: both the functional form (explicit argument)
      # and the method form (which, per this test's expectations, follows the
      # distribution's own `allow_nan_stats=False`) must raise.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
      kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(kl)
      with self.assertRaisesOpError(
          "KL calculation between .* and .* returned NaN values"):
        self.evaluate(a.kl_divergence(a))

      # NaN stats allowed: the NaN is passed through unchanged, including
      # through cross-entropy.
      a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=True)
      kl_ok = kullback_leibler.kl_divergence(a, a)
      self.assertAllEqual([float("nan")], self.evaluate(kl_ok))
      self_kl_ok = a.kl_divergence(a)
      self.assertAllEqual([float("nan")], self.evaluate(self_kl_ok))
      cross_ok = a.cross_entropy(a)
      self.assertAllEqual([float("nan")], self.evaluate(cross_ok))

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class MyDist(normal.Normal):
      pass

    with self.assertRaisesRegexp(TypeError, "must be callable"):
      kullback_leibler.RegisterKL(MyDist, MyDist)("blah")

    # First registration is OK
    kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

    # Second registration fails
    with self.assertRaisesRegexp(ValueError, "has already been registered"):
      kullback_leibler.RegisterKL(MyDist, MyDist)(lambda a, b: None)

  def testExactRegistrationsAllMatch(self):
    """Every exact registration is what the lookup returns for its key."""
    for (k, v) in _DIVERGENCES.items():
      self.assertEqual(v, _registered_kl(*k))

  def _testIndirectRegistration(self, fn):
    """Checks that `fn(p, q)` resolves registrations through subclasses.

    Args:
      fn: Callable taking two distributions and returning the value produced
        by the registered KL function for their types. It may also add the
        entropy of `p`.
    """

    # `entropy` returns "" so that a cross-entropy built from entropy plus KL
    # (presumably `entropy + kl`; confirm in `kullback_leibler.cross_entropy`)
    # yields the same label string as the KL alone. The cross-entropy
    # variants of this test rely on that.
    class Sub1(normal.Normal):

      def entropy(self):
        return ""

    class Sub2(normal.Normal):

      def entropy(self):
        return ""

    class Sub11(Sub1):

      def entropy(self):
        return ""

    # pylint: disable=unused-argument,unused-variable
    @kullback_leibler.RegisterKL(Sub1, Sub1)
    def _kl11(a, b, name=None):
      return "sub1-1"

    @kullback_leibler.RegisterKL(Sub1, Sub2)
    def _kl12(a, b, name=None):
      return "sub1-2"

    @kullback_leibler.RegisterKL(Sub2, Sub1)
    def _kl21(a, b, name=None):
      return "sub2-1"

    # pylint: enable=unused-argument,unused-variable

    sub1 = Sub1(loc=0.0, scale=1.0)
    sub2 = Sub2(loc=0.0, scale=1.0)
    sub11 = Sub11(loc=0.0, scale=1.0)

    # Exact matches.
    self.assertEqual("sub1-1", fn(sub1, sub1))
    self.assertEqual("sub1-2", fn(sub1, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub1))
    # Sub11 has no registration of its own and must resolve via Sub1, in
    # either argument position.
    self.assertEqual("sub1-1", fn(sub11, sub11))
    self.assertEqual("sub1-1", fn(sub11, sub1))
    self.assertEqual("sub1-2", fn(sub11, sub2))
    self.assertEqual("sub2-1", fn(sub2, sub11))
    self.assertEqual("sub1-1", fn(sub1, sub11))

  def testIndirectRegistrationKLFun(self):
    self._testIndirectRegistration(kullback_leibler.kl_divergence)

  def testIndirectRegistrationKLSelf(self):
    self._testIndirectRegistration(
        lambda p, q: p.kl_divergence(q))

  def testIndirectRegistrationCrossEntropy(self):
    self._testIndirectRegistration(
        lambda p, q: p.cross_entropy(q))

  def testFunctionCrossEntropy(self):
    self._testIndirectRegistration(kullback_leibler.cross_entropy)


if __name__ == "__main__":
  test.main()