# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Beta distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Beta",
    "BetaWithSoftplusConcentration",
]


# Shared docstring appendix for the sample-dependent methods (prob, cdf, ...).
_beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""


@tf_export("distributions.Beta")
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Examples

  ```python
  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]      # Shape [2, 1]
  beta = [3., 4, 5]        # Shape [3]
  dist = Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1,],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  """

  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[concentration1, concentration0]):
      # Each parameter gets positivity assertions attached only when
      # `validate_args` is set; otherwise it is passed through unchanged.
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      # Both parameters must share a single floating-point dtype, which
      # becomes `self.dtype` below.
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      # The sum (alpha + beta); its shape is the broadcast batch shape, and
      # it appears in the normalizer, mean, variance, entropy, and mode.
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        # Sampling goes through two gamma draws (see `_sample_n`), which are
        # not reparameterized here.
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both concentration parameters share the sample shape.
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    # `total_concentration` already carries the broadcast of both parameters.
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    # Scalar event: each sample is a single value in (0, 1).
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Broadcast each parameter against the full batch shape so the two gamma
    # draws below have matching shapes.
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    # Gamma-ratio construction: for independent G1 ~ Gamma(alpha, 1) and
    # G0 ~ Gamma(beta, 1), the ratio G1 / (G1 + G0) ~ Beta(alpha, beta).
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        # Derive a distinct seed so the two gamma draws are not identical
        # when a fixed seed is supplied.
        seed=distribution_util.gen_new_seed(seed, "beta"))
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    # log pdf = (alpha-1) log(x) + (beta-1) log1p(-x) - log Z.
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    # Computed as log of the regularized incomplete beta function below.
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    # The Beta CDF is the regularized incomplete beta function I_x(a, b).
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    # Unnormalized log density; `log1p(-x)` is more accurate than
    # `log(1 - x)` for x near 0.
    x = self._maybe_assert_valid_sample(x)
    return ((self.concentration1 - 1.) * math_ops.log(x)
            + (self.concentration0 - 1.) * math_ops.log1p(-x))

  def _log_normalization(self):
    # log Z = lgamma(alpha) + lgamma(beta) - lgamma(alpha + beta),
    # i.e. the log of the beta function B(alpha, beta).
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    # Standard Beta entropy:
    #   log B(a, b) - (a-1) digamma(a) - (b-1) digamma(b)
    #   + (a+b-2) digamma(a+b).
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    # E[X] = alpha / (alpha + beta).
    return self._concentration1 / self._total_concentration

  def _variance(self):
    # Var[X] = mean * (1 - mean) / (alpha + beta + 1).
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    # mode = (alpha - 1) / (alpha + beta - 2); only valid when both
    # concentrations exceed 1 (otherwise the density has no interior peak).
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      return array_ops.where(is_defined, mode, nan)
    # allow_nan_stats=False: fail loudly if any batch member's mode is
    # undefined.
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter."""
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample."""
    if not self.validate_args:
      return x
    # Support of the Beta distribution is the open interval (0, 1).
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)


class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

def __init__(self, 319 concentration1, 320 concentration0, 321 validate_args=False, 322 allow_nan_stats=True, 323 name="BetaWithSoftplusConcentration"): 324 parameters = locals() 325 with ops.name_scope(name, values=[concentration1, 326 concentration0]) as ns: 327 super(BetaWithSoftplusConcentration, self).__init__( 328 concentration1=nn.softplus(concentration1, 329 name="softplus_concentration1"), 330 concentration0=nn.softplus(concentration0, 331 name="softplus_concentration0"), 332 validate_args=validate_args, 333 allow_nan_stats=allow_nan_stats, 334 name=ns) 335 self._parameters = parameters 336 337 338 @kullback_leibler.RegisterKL(Beta, Beta) 339 def _kl_beta_beta(d1, d2, name=None): 340 """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta. 341 342 Args: 343 d1: instance of a Beta distribution object. 344 d2: instance of a Beta distribution object. 345 name: (optional) Name to use for created operations. 346 default is "kl_beta_beta". 347 348 Returns: 349 Batchwise KL(d1 || d2) 350 """ 351 def delta(fn, is_property=True): 352 fn1 = getattr(d1, fn) 353 fn2 = getattr(d2, fn) 354 return (fn2 - fn1) if is_property else (fn2() - fn1()) 355 with ops.name_scope(name, "kl_beta_beta", values=[ 356 d1.concentration1, 357 d1.concentration0, 358 d1.total_concentration, 359 d2.concentration1, 360 d2.concentration0, 361 d2.total_concentration, 362 ]): 363 return (delta("_log_normalization", is_property=False) 364 - math_ops.digamma(d1.concentration1) * delta("concentration1") 365 - math_ops.digamma(d1.concentration0) * delta("concentration0") 366 + (math_ops.digamma(d1.total_concentration) 367 * delta("total_concentration"))) 368