1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for initializers.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import importlib 22 import math 23 24 import numpy as np 25 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import ops 29 from tensorflow.python.framework import tensor_shape 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import gradients_impl 32 from tensorflow.python.ops import nn_ops 33 from tensorflow.python.ops import variables 34 from tensorflow.python.ops.distributions import kullback_leibler 35 from tensorflow.python.ops.distributions import normal as normal_lib 36 from tensorflow.python.platform import test 37 from tensorflow.python.platform import tf_logging 38 39 40 def try_import(name): # pylint: disable=invalid-name 41 module = None 42 try: 43 module = importlib.import_module(name) 44 except ImportError as e: 45 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 46 return module 47 48 stats = try_import("scipy.stats") 49 50 51 class NormalTest(test.TestCase): 52 53 def setUp(self): 54 self._rng = np.random.RandomState(123) 55 56 def assertAllFinite(self, tensor): 57 is_finite = np.isfinite(tensor.eval()) 58 all_true = np.ones_like(is_finite, dtype=np.bool) 59 self.assertAllEqual(all_true, is_finite) 60 61 def _testParamShapes(self, sample_shape, expected): 62 with self.test_session(): 63 param_shapes = normal_lib.Normal.param_shapes(sample_shape) 64 mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] 65 self.assertAllEqual(expected, mu_shape.eval()) 66 self.assertAllEqual(expected, sigma_shape.eval()) 67 mu = array_ops.zeros(mu_shape) 68 sigma = array_ops.ones(sigma_shape) 69 self.assertAllEqual( 70 expected, 71 array_ops.shape(normal_lib.Normal(mu, sigma).sample()).eval()) 72 73 def _testParamStaticShapes(self, sample_shape, expected): 74 param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) 75 mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] 76 self.assertEqual(expected, mu_shape) 77 self.assertEqual(expected, sigma_shape) 78 79 def testParamShapes(self): 80 sample_shape = [10, 3, 4] 81 self._testParamShapes(sample_shape, sample_shape) 82 self._testParamShapes(constant_op.constant(sample_shape), sample_shape) 83 84 def testParamStaticShapes(self): 85 sample_shape = [10, 3, 4] 86 self._testParamStaticShapes(sample_shape, sample_shape) 87 self._testParamStaticShapes( 88 tensor_shape.TensorShape(sample_shape), sample_shape) 89 90 def testNormalWithSoftplusScale(self): 91 with self.test_session(): 92 mu = array_ops.zeros((10, 3)) 93 rho = array_ops.ones((10, 3)) * -2. 94 normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) 95 self.assertAllEqual(mu.eval(), normal.loc.eval()) 96 self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.scale.eval()) 97 98 def testNormalLogPDF(self): 99 with self.test_session(): 100 batch_size = 6 101 mu = constant_op.constant([3.0] * batch_size) 102 sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) 103 x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) 104 normal = normal_lib.Normal(loc=mu, scale=sigma) 105 106 log_pdf = normal.log_prob(x) 107 self.assertAllEqual(normal.batch_shape_tensor().eval(), 108 log_pdf.get_shape()) 109 self.assertAllEqual(normal.batch_shape_tensor().eval(), 110 log_pdf.eval().shape) 111 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) 112 self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape) 113 114 pdf = normal.prob(x) 115 self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) 116 self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.eval().shape) 117 self.assertAllEqual(normal.batch_shape, pdf.get_shape()) 118 self.assertAllEqual(normal.batch_shape, pdf.eval().shape) 119 120 if not stats: 121 return 122 expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) 123 self.assertAllClose(expected_log_pdf, log_pdf.eval()) 124 self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) 125 126 def testNormalLogPDFMultidimensional(self): 127 with self.test_session(): 128 batch_size = 6 129 mu = constant_op.constant([[3.0, -3.0]] * batch_size) 130 sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * 131 batch_size) 132 x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T 133 normal = normal_lib.Normal(loc=mu, scale=sigma) 134 135 log_pdf = normal.log_prob(x) 136 log_pdf_values = log_pdf.eval() 137 self.assertEqual(log_pdf.get_shape(), (6, 2)) 138 self.assertAllEqual(normal.batch_shape_tensor().eval(), 139 log_pdf.get_shape()) 140 self.assertAllEqual(normal.batch_shape_tensor().eval(), 141 log_pdf.eval().shape) 142 self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) 143 self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape) 144 145 pdf = normal.prob(x) 146 pdf_values = pdf.eval() 147 self.assertEqual(pdf.get_shape(), (6, 2)) 148 self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape()) 149 self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf_values.shape) 150 self.assertAllEqual(normal.batch_shape, pdf.get_shape()) 151 self.assertAllEqual(normal.batch_shape, pdf_values.shape) 152 153 if not stats: 154 return 155 expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x) 156 self.assertAllClose(expected_log_pdf, log_pdf_values) 157 self.assertAllClose(np.exp(expected_log_pdf), pdf_values) 158 159 def testNormalCDF(self): 160 with self.test_session(): 161 batch_size = 50 162 mu = self._rng.randn(batch_size) 163 sigma = self._rng.rand(batch_size) + 1.0 164 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) 165 166 normal = normal_lib.Normal(loc=mu, scale=sigma) 167 cdf = normal.cdf(x) 168 self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) 169 self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) 170 self.assertAllEqual(normal.batch_shape, cdf.get_shape()) 171 self.assertAllEqual(normal.batch_shape, cdf.eval().shape) 172 if not stats: 173 return 174 expected_cdf = stats.norm(mu, sigma).cdf(x) 175 self.assertAllClose(expected_cdf, cdf.eval(), atol=0) 176 177 def testNormalSurvivalFunction(self): 178 with self.test_session(): 179 batch_size = 50 180 mu = self._rng.randn(batch_size) 181 sigma = self._rng.rand(batch_size) + 1.0 182 x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) 183 184 normal = normal_lib.Normal(loc=mu, scale=sigma) 185 186 sf = normal.survival_function(x) 187 self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) 188 self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) 189 self.assertAllEqual(normal.batch_shape, sf.get_shape()) 190 self.assertAllEqual(normal.batch_shape, sf.eval().shape) 191 if not stats: 192 return 193 expected_sf = stats.norm(mu, sigma).sf(x) 194 self.assertAllClose(expected_sf, sf.eval(), atol=0) 195 196 def testNormalLogCDF(self): 197 with self.test_session(): 198 batch_size = 50 199 mu = self._rng.randn(batch_size) 200 sigma = self._rng.rand(batch_size) + 1.0 201 x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) 202 203 normal = normal_lib.Normal(loc=mu, scale=sigma) 204 205 cdf = normal.log_cdf(x) 206 self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape()) 207 self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape) 208 self.assertAllEqual(normal.batch_shape, cdf.get_shape()) 209 self.assertAllEqual(normal.batch_shape, cdf.eval().shape) 210 211 if not stats: 212 return 213 expected_cdf = stats.norm(mu, sigma).logcdf(x) 214 self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) 215 216 def testFiniteGradientAtDifficultPoints(self): 217 for dtype in [np.float32, np.float64]: 218 g = ops.Graph() 219 with g.as_default(): 220 mu = variables.Variable(dtype(0.0)) 221 sigma = variables.Variable(dtype(1.0)) 222 dist = normal_lib.Normal(loc=mu, scale=sigma) 223 x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) 224 for func in [ 225 dist.cdf, dist.log_cdf, dist.survival_function, 226 dist.log_survival_function, dist.log_prob, dist.prob 227 ]: 228 value = func(x) 229 grads = gradients_impl.gradients(value, [mu, sigma]) 230 with self.test_session(graph=g): 231 variables.global_variables_initializer().run() 232 self.assertAllFinite(value) 233 self.assertAllFinite(grads[0]) 234 self.assertAllFinite(grads[1]) 235 236 def testNormalLogSurvivalFunction(self): 237 with self.test_session(): 238 batch_size = 50 239 mu = self._rng.randn(batch_size) 240 sigma = self._rng.rand(batch_size) + 1.0 241 x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) 242 243 normal = normal_lib.Normal(loc=mu, scale=sigma) 244 245 sf = normal.log_survival_function(x) 246 self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape()) 247 self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape) 248 self.assertAllEqual(normal.batch_shape, sf.get_shape()) 249 self.assertAllEqual(normal.batch_shape, sf.eval().shape) 250 251 if not stats: 252 return 253 expected_sf = stats.norm(mu, sigma).logsf(x) 254 self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) 255 256 def testNormalEntropyWithScalarInputs(self): 257 # Scipy.stats.norm cannot deal with the shapes in the other test. 258 with self.test_session(): 259 mu_v = 2.34 260 sigma_v = 4.56 261 normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) 262 263 entropy = normal.entropy() 264 self.assertAllEqual(normal.batch_shape_tensor().eval(), 265 entropy.get_shape()) 266 self.assertAllEqual(normal.batch_shape_tensor().eval(), 267 entropy.eval().shape) 268 self.assertAllEqual(normal.batch_shape, entropy.get_shape()) 269 self.assertAllEqual(normal.batch_shape, entropy.eval().shape) 270 # scipy.stats.norm cannot deal with these shapes. 271 if not stats: 272 return 273 expected_entropy = stats.norm(mu_v, sigma_v).entropy() 274 self.assertAllClose(expected_entropy, entropy.eval()) 275 276 def testNormalEntropy(self): 277 with self.test_session(): 278 mu_v = np.array([1.0, 1.0, 1.0]) 279 sigma_v = np.array([[1.0, 2.0, 3.0]]).T 280 normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) 281 282 # scipy.stats.norm cannot deal with these shapes. 283 sigma_broadcast = mu_v * sigma_v 284 expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast** 285 2) 286 entropy = normal.entropy() 287 np.testing.assert_allclose(expected_entropy, entropy.eval()) 288 self.assertAllEqual(normal.batch_shape_tensor().eval(), 289 entropy.get_shape()) 290 self.assertAllEqual(normal.batch_shape_tensor().eval(), 291 entropy.eval().shape) 292 self.assertAllEqual(normal.batch_shape, entropy.get_shape()) 293 self.assertAllEqual(normal.batch_shape, entropy.eval().shape) 294 295 def testNormalMeanAndMode(self): 296 with self.test_session(): 297 # Mu will be broadcast to [7, 7, 7]. 298 mu = [7.] 299 sigma = [11., 12., 13.] 300 301 normal = normal_lib.Normal(loc=mu, scale=sigma) 302 303 self.assertAllEqual((3,), normal.mean().get_shape()) 304 self.assertAllEqual([7., 7, 7], normal.mean().eval()) 305 306 self.assertAllEqual((3,), normal.mode().get_shape()) 307 self.assertAllEqual([7., 7, 7], normal.mode().eval()) 308 309 def testNormalQuantile(self): 310 with self.test_session(): 311 batch_size = 52 312 mu = self._rng.randn(batch_size) 313 sigma = self._rng.rand(batch_size) + 1.0 314 p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) 315 # Quantile performs piecewise rational approximation so adding some 316 # special input values to make sure we hit all the pieces. 317 p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) 318 319 normal = normal_lib.Normal(loc=mu, scale=sigma) 320 x = normal.quantile(p) 321 322 self.assertAllEqual(normal.batch_shape_tensor().eval(), x.get_shape()) 323 self.assertAllEqual(normal.batch_shape_tensor().eval(), x.eval().shape) 324 self.assertAllEqual(normal.batch_shape, x.get_shape()) 325 self.assertAllEqual(normal.batch_shape, x.eval().shape) 326 327 if not stats: 328 return 329 expected_x = stats.norm(mu, sigma).ppf(p) 330 self.assertAllClose(expected_x, x.eval(), atol=0.) 331 332 def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): 333 g = ops.Graph() 334 with g.as_default(): 335 mu = variables.Variable(dtype(0.0)) 336 sigma = variables.Variable(dtype(1.0)) 337 dist = normal_lib.Normal(loc=mu, scale=sigma) 338 p = variables.Variable( 339 np.array([0., 340 np.exp(-32.), np.exp(-2.), 341 1. - np.exp(-2.), 1. - np.exp(-32.), 342 1.]).astype(dtype)) 343 344 value = dist.quantile(p) 345 grads = gradients_impl.gradients(value, [mu, p]) 346 with self.test_session(graph=g): 347 variables.global_variables_initializer().run() 348 self.assertAllFinite(grads[0]) 349 self.assertAllFinite(grads[1]) 350 351 def testQuantileFiniteGradientAtDifficultPointsFloat32(self): 352 self._baseQuantileFiniteGradientAtDifficultPoints(np.float32) 353 354 def testQuantileFiniteGradientAtDifficultPointsFloat64(self): 355 self._baseQuantileFiniteGradientAtDifficultPoints(np.float64) 356 357 def testNormalVariance(self): 358 with self.test_session(): 359 # sigma will be broadcast to [7, 7, 7] 360 mu = [1., 2., 3.] 361 sigma = [7.] 362 363 normal = normal_lib.Normal(loc=mu, scale=sigma) 364 365 self.assertAllEqual((3,), normal.variance().get_shape()) 366 self.assertAllEqual([49., 49, 49], normal.variance().eval()) 367 368 def testNormalStandardDeviation(self): 369 with self.test_session(): 370 # sigma will be broadcast to [7, 7, 7] 371 mu = [1., 2., 3.] 372 sigma = [7.] 373 374 normal = normal_lib.Normal(loc=mu, scale=sigma) 375 376 self.assertAllEqual((3,), normal.stddev().get_shape()) 377 self.assertAllEqual([7., 7, 7], normal.stddev().eval()) 378 379 def testNormalSample(self): 380 with self.test_session(): 381 mu = constant_op.constant(3.0) 382 sigma = constant_op.constant(math.sqrt(3.0)) 383 mu_v = 3.0 384 sigma_v = np.sqrt(3.0) 385 n = constant_op.constant(100000) 386 normal = normal_lib.Normal(loc=mu, scale=sigma) 387 samples = normal.sample(n) 388 sample_values = samples.eval() 389 # Note that the standard error for the sample mean is ~ sigma / sqrt(n). 390 # The sample variance similarly is dependent on sigma and n. 391 # Thus, the tolerances below are very sensitive to number of samples 392 # as well as the variances chosen. 393 self.assertEqual(sample_values.shape, (100000,)) 394 self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) 395 self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) 396 397 expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate( 398 tensor_shape.TensorShape(normal.batch_shape_tensor().eval())) 399 400 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 401 self.assertAllEqual(expected_samples_shape, sample_values.shape) 402 403 expected_samples_shape = (tensor_shape.TensorShape( 404 [n.eval()]).concatenate(normal.batch_shape)) 405 406 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 407 self.assertAllEqual(expected_samples_shape, sample_values.shape) 408 409 def testNormalSampleMultiDimensional(self): 410 with self.test_session(): 411 batch_size = 2 412 mu = constant_op.constant([[3.0, -3.0]] * batch_size) 413 sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] * 414 batch_size) 415 mu_v = [3.0, -3.0] 416 sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] 417 n = constant_op.constant(100000) 418 normal = normal_lib.Normal(loc=mu, scale=sigma) 419 samples = normal.sample(n) 420 sample_values = samples.eval() 421 # Note that the standard error for the sample mean is ~ sigma / sqrt(n). 422 # The sample variance similarly is dependent on sigma and n. 423 # Thus, the tolerances below are very sensitive to number of samples 424 # as well as the variances chosen. 425 self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) 426 self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) 427 self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) 428 self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) 429 self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) 430 431 expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate( 432 tensor_shape.TensorShape(normal.batch_shape_tensor().eval())) 433 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 434 self.assertAllEqual(expected_samples_shape, sample_values.shape) 435 436 expected_samples_shape = (tensor_shape.TensorShape( 437 [n.eval()]).concatenate(normal.batch_shape)) 438 self.assertAllEqual(expected_samples_shape, samples.get_shape()) 439 self.assertAllEqual(expected_samples_shape, sample_values.shape) 440 441 def testNegativeSigmaFails(self): 442 with self.test_session(): 443 normal = normal_lib.Normal( 444 loc=[1.], scale=[-5.], validate_args=True, name="G") 445 with self.assertRaisesOpError("Condition x > 0 did not hold"): 446 normal.mean().eval() 447 448 def testNormalShape(self): 449 with self.test_session(): 450 mu = constant_op.constant([-3.0] * 5) 451 sigma = constant_op.constant(11.0) 452 normal = normal_lib.Normal(loc=mu, scale=sigma) 453 454 self.assertEqual(normal.batch_shape_tensor().eval(), [5]) 455 self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) 456 self.assertAllEqual(normal.event_shape_tensor().eval(), []) 457 self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) 458 459 def testNormalShapeWithPlaceholders(self): 460 mu = array_ops.placeholder(dtype=dtypes.float32) 461 sigma = array_ops.placeholder(dtype=dtypes.float32) 462 normal = normal_lib.Normal(loc=mu, scale=sigma) 463 464 with self.test_session() as sess: 465 # get_batch_shape should return an "<unknown>" tensor. 466 self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) 467 self.assertEqual(normal.event_shape, ()) 468 self.assertAllEqual(normal.event_shape_tensor().eval(), []) 469 self.assertAllEqual( 470 sess.run(normal.batch_shape_tensor(), 471 feed_dict={mu: 5.0, 472 sigma: [1.0, 2.0]}), [2]) 473 474 def testNormalNormalKL(self): 475 with self.test_session() as sess: 476 batch_size = 6 477 mu_a = np.array([3.0] * batch_size) 478 sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5]) 479 mu_b = np.array([-3.0] * batch_size) 480 sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0]) 481 482 n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a) 483 n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b) 484 485 kl = kullback_leibler.kl_divergence(n_a, n_b) 486 kl_val = sess.run(kl) 487 488 kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * ( 489 (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b))) 490 491 self.assertEqual(kl.get_shape(), (batch_size,)) 492 self.assertAllClose(kl_val, kl_expected) 493 494 495 if __name__ == "__main__": 496 test.main() 497