Home | History | Annotate | Download | only in distributions
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Uniform distribution."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import importlib
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops.distributions import uniform as uniform_lib
     31 from tensorflow.python.platform import test
     32 from tensorflow.python.platform import tf_logging
     33 
     34 
     35 def try_import(name):  # pylint: disable=invalid-name
     36   module = None
     37   try:
     38     module = importlib.import_module(name)
     39   except ImportError as e:
     40     tf_logging.warning("Could not import %s: %s" % (name, str(e)))
     41   return module
     42 
     43 
     44 stats = try_import("scipy.stats")
     45 
     46 
     47 class UniformTest(test.TestCase):
     48 
     49   def testUniformRange(self):
     50     with self.test_session():
     51       a = 3.0
     52       b = 10.0
     53       uniform = uniform_lib.Uniform(low=a, high=b)
     54       self.assertAllClose(a, uniform.low.eval())
     55       self.assertAllClose(b, uniform.high.eval())
     56       self.assertAllClose(b - a, uniform.range().eval())
     57 
     58   def testUniformPDF(self):
     59     with self.test_session():
     60       a = constant_op.constant([-3.0] * 5 + [15.0])
     61       b = constant_op.constant([11.0] * 5 + [20.0])
     62       uniform = uniform_lib.Uniform(low=a, high=b)
     63 
     64       a_v = -3.0
     65       b_v = 11.0
     66       x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
     67 
     68       def _expected_pdf():
     69         pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
     70         pdf[x > b_v] = 0.0
     71         pdf[x < a_v] = 0.0
     72         pdf[5] = 1.0 / (20.0 - 15.0)
     73         return pdf
     74 
     75       expected_pdf = _expected_pdf()
     76 
     77       pdf = uniform.prob(x)
     78       self.assertAllClose(expected_pdf, pdf.eval())
     79 
     80       log_pdf = uniform.log_prob(x)
     81       self.assertAllClose(np.log(expected_pdf), log_pdf.eval())
     82 
     83   def testUniformShape(self):
     84     with self.test_session():
     85       a = constant_op.constant([-3.0] * 5)
     86       b = constant_op.constant(11.0)
     87       uniform = uniform_lib.Uniform(low=a, high=b)
     88 
     89       self.assertEqual(uniform.batch_shape_tensor().eval(), (5,))
     90       self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
     91       self.assertAllEqual(uniform.event_shape_tensor().eval(), [])
     92       self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
     93 
     94   def testUniformPDFWithScalarEndpoint(self):
     95     with self.test_session():
     96       a = constant_op.constant([0.0, 5.0])
     97       b = constant_op.constant(10.0)
     98       uniform = uniform_lib.Uniform(low=a, high=b)
     99 
    100       x = np.array([0.0, 8.0], dtype=np.float32)
    101       expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
    102 
    103       pdf = uniform.prob(x)
    104       self.assertAllClose(expected_pdf, pdf.eval())
    105 
    106   def testUniformCDF(self):
    107     with self.test_session():
    108       batch_size = 6
    109       a = constant_op.constant([1.0] * batch_size)
    110       b = constant_op.constant([11.0] * batch_size)
    111       a_v = 1.0
    112       b_v = 11.0
    113       x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
    114 
    115       uniform = uniform_lib.Uniform(low=a, high=b)
    116 
    117       def _expected_cdf():
    118         cdf = (x - a_v) / (b_v - a_v)
    119         cdf[x >= b_v] = 1
    120         cdf[x < a_v] = 0
    121         return cdf
    122 
    123       cdf = uniform.cdf(x)
    124       self.assertAllClose(_expected_cdf(), cdf.eval())
    125 
    126       log_cdf = uniform.log_cdf(x)
    127       self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval())
    128 
    129   def testUniformEntropy(self):
    130     with self.test_session():
    131       a_v = np.array([1.0, 1.0, 1.0])
    132       b_v = np.array([[1.5, 2.0, 3.0]])
    133       uniform = uniform_lib.Uniform(low=a_v, high=b_v)
    134 
    135       expected_entropy = np.log(b_v - a_v)
    136       self.assertAllClose(expected_entropy, uniform.entropy().eval())
    137 
    138   def testUniformAssertMaxGtMin(self):
    139     with self.test_session():
    140       a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    141       b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    142       uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
    143 
    144       with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
    145                                                "x < y"):
    146         uniform.low.eval()
    147 
    148   def testUniformSample(self):
    149     with self.test_session():
    150       a = constant_op.constant([3.0, 4.0])
    151       b = constant_op.constant(13.0)
    152       a1_v = 3.0
    153       a2_v = 4.0
    154       b_v = 13.0
    155       n = constant_op.constant(100000)
    156       uniform = uniform_lib.Uniform(low=a, high=b)
    157 
    158       samples = uniform.sample(n, seed=137)
    159       sample_values = samples.eval()
    160       self.assertEqual(sample_values.shape, (100000, 2))
    161       self.assertAllClose(
    162           sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2)
    163       self.assertAllClose(
    164           sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-2)
    165       self.assertFalse(
    166           np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
    167       self.assertFalse(
    168           np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
    169 
    170   def _testUniformSampleMultiDimensional(self):
    171     # DISABLED: Please enable this test once b/issues/30149644 is resolved.
    172     with self.test_session():
    173       batch_size = 2
    174       a_v = [3.0, 22.0]
    175       b_v = [13.0, 35.0]
    176       a = constant_op.constant([a_v] * batch_size)
    177       b = constant_op.constant([b_v] * batch_size)
    178 
    179       uniform = uniform_lib.Uniform(low=a, high=b)
    180 
    181       n_v = 100000
    182       n = constant_op.constant(n_v)
    183       samples = uniform.sample(n)
    184       self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
    185 
    186       sample_values = samples.eval()
    187 
    188       self.assertFalse(
    189           np.any(sample_values[:, 0, 0] < a_v[0]) or
    190           np.any(sample_values[:, 0, 0] >= b_v[0]))
    191       self.assertFalse(
    192           np.any(sample_values[:, 0, 1] < a_v[1]) or
    193           np.any(sample_values[:, 0, 1] >= b_v[1]))
    194 
    195       self.assertAllClose(
    196           sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
    197       self.assertAllClose(
    198           sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
    199 
    200   def testUniformMean(self):
    201     with self.test_session():
    202       a = 10.0
    203       b = 100.0
    204       uniform = uniform_lib.Uniform(low=a, high=b)
    205       if not stats:
    206         return
    207       s_uniform = stats.uniform(loc=a, scale=b - a)
    208       self.assertAllClose(uniform.mean().eval(), s_uniform.mean())
    209 
    210   def testUniformVariance(self):
    211     with self.test_session():
    212       a = 10.0
    213       b = 100.0
    214       uniform = uniform_lib.Uniform(low=a, high=b)
    215       if not stats:
    216         return
    217       s_uniform = stats.uniform(loc=a, scale=b - a)
    218       self.assertAllClose(uniform.variance().eval(), s_uniform.var())
    219 
    220   def testUniformStd(self):
    221     with self.test_session():
    222       a = 10.0
    223       b = 100.0
    224       uniform = uniform_lib.Uniform(low=a, high=b)
    225       if not stats:
    226         return
    227       s_uniform = stats.uniform(loc=a, scale=b - a)
    228       self.assertAllClose(uniform.stddev().eval(), s_uniform.std())
    229 
    230   def testUniformNans(self):
    231     with self.test_session():
    232       a = 10.0
    233       b = [11.0, 100.0]
    234       uniform = uniform_lib.Uniform(low=a, high=b)
    235 
    236       no_nans = constant_op.constant(1.0)
    237       nans = constant_op.constant(0.0) / constant_op.constant(0.0)
    238       self.assertTrue(math_ops.is_nan(nans).eval())
    239       with_nans = array_ops.stack([no_nans, nans])
    240 
    241       pdf = uniform.prob(with_nans)
    242 
    243       is_nan = math_ops.is_nan(pdf).eval()
    244       self.assertFalse(is_nan[0])
    245       self.assertTrue(is_nan[1])
    246 
    247   def testUniformSamplePdf(self):
    248     with self.test_session():
    249       a = 10.0
    250       b = [11.0, 100.0]
    251       uniform = uniform_lib.Uniform(a, b)
    252       self.assertTrue(
    253           math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0).eval())
    254 
    255   def testUniformBroadcasting(self):
    256     with self.test_session():
    257       a = 10.0
    258       b = [11.0, 20.0]
    259       uniform = uniform_lib.Uniform(a, b)
    260 
    261       pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
    262       expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
    263       self.assertAllClose(expected_pdf, pdf.eval())
    264 
    265   def testUniformSampleWithShape(self):
    266     with self.test_session():
    267       a = 10.0
    268       b = [11.0, 20.0]
    269       uniform = uniform_lib.Uniform(a, b)
    270 
    271       pdf = uniform.prob(uniform.sample((2, 3)))
    272       # pylint: disable=bad-continuation
    273       expected_pdf = [
    274           [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
    275           [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
    276       ]
    277       # pylint: enable=bad-continuation
    278       self.assertAllClose(expected_pdf, pdf.eval())
    279 
    280       pdf = uniform.prob(uniform.sample())
    281       expected_pdf = [1.0, 0.1]
    282       self.assertAllClose(expected_pdf, pdf.eval())
    283 
    284 
    285 if __name__ == "__main__":
    286   test.main()
    287