# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ConditionalDistribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import distributions
from tensorflow.contrib.distributions.python.kernel_tests import distribution_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test


class ConditionalDistributionTest(distribution_test.DistributionTest):
  """Tests for `distributions.ConditionalDistribution`."""

  def _GetFakeDistribution(self):
    """Returns a fake `ConditionalDistribution` subclass (the class itself).

    The returned class is defined locally so each call yields a fresh type.
    """

    class _FakeDistribution(distributions.ConditionalDistribution):
      """Fake conditional distribution whose private methods raise.

      Every overridden private method raises `ValueError(arg1, arg2)`. This lets
      a test check that the conditioning kwargs given to a public method
      (e.g. `prob(x, arg1=..., arg2=...)`) reach the private implementation.
      """

      def __init__(self, batch_shape=None, event_shape=None):
        # Static shapes are stored before calling the base constructor so the
        # shape accessors below can return them.
        self._static_batch_shape = tensor_shape.TensorShape(batch_shape)
        self._static_event_shape = tensor_shape.TensorShape(event_shape)
        super(_FakeDistribution, self).__init__(
            dtype=dtypes.float32,
            reparameterization_type=distributions.NOT_REPARAMETERIZED,
            validate_args=True,
            allow_nan_stats=True,
            name="DummyDistribution")

      def _batch_shape(self):
        return self._static_batch_shape

      def _event_shape(self):
        return self._static_event_shape

      def _sample_n(self, unused_shape, unused_seed, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _log_prob(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _prob(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _cdf(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _log_cdf(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _log_survival_function(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

      def _survival_function(self, _, arg1, arg2):
        raise ValueError(arg1, arg2)

    return _FakeDistribution

  def testNotImplemented(self):
    """Each public method forwards `arg1`/`arg2` kwargs to its private method.

    The fake raises `ValueError(arg1, arg2)`. Seeing both "b1" and "b2" in the
    error message therefore shows that the kwargs arrived intact.
    """
    d = self._GetFakeDistribution()(batch_shape=[], event_shape=[])
    for name in ["sample", "log_prob", "prob", "log_cdf", "cdf",
                 "log_survival_function", "survival_function"]:
      method = getattr(d, name)
      with self.assertRaisesRegexp(ValueError, "b1.*b2"):
        # `sample` takes a sample shape as its first argument; the other
        # methods take a value to evaluate.
        method([] if name == "sample" else 1.0, arg1="b1", arg2="b2")


if __name__ == "__main__":
  test.main()