# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The DirichletMultinomial distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "DirichletMultinomial",
]


# NOTE: This string is appended verbatim to the docstrings of `_log_prob` and
# `_prob` via `distribution_util.AppendDocstring`; it is runtime-visible text
# and must not be edited casually.
_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Since this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables);
different sequences have the same counts so the probability includes a
combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`."""


@tf_export("distributions.DirichletMultinomial")
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vectors (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial, and,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows.

    1. Choose class probabilities:
       `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
    2. Draw integers:
       `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`

  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
  `concentration`, `total_count` and `counts` are broadcast to the same shape.
  The last dimension of `counts` corresponds to a single Dirichlet-Multinomial
  distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution, with the 3rd class most likely to be drawn.
  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```

  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the
        same as `total_count` with shape broadcastable to `[N1,..., Nm, K]`
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[total_count, concentration]):
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1], and
      #   we use the last dimension for the distribution, whereas
      #   the batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and concentration.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        # Check that `total_count` is non-negative and has no fractional part.
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      # alpha_0 = sum_j alpha_j; cached because several statistics use it.
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    # Batch shape is the shape of the per-batch scalar `total_concentration`.
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    # Event is a length-`K` count vector; `K` is the last dim of concentration.
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on total_concentration, not "n".
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    """Draws `n` samples via the Gamma -> Multinomial compounding trick."""
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    # Gamma(alpha_j) draws, when normalized, are Dirichlet(alpha). Since
    # `multinomial` normalizes its logits internally, the unnormalized
    # log-gamma draws serve directly as multinomial logits.
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    # Salt the seed so the multinomial draws are independent of the gamma ones.
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    # Convert per-draw class indices into per-class counts.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    # log Beta(alpha + n) - log Beta(alpha): the exchangeable-sequence
    # probability; the combinatorial coefficient accounts for orderings.
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    # E[X_j] = n * alpha_j / alpha_0.
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between elements in a batch is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
      ```
      """)
  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    # Off-diagonal entries are the negated outer product; the diagonal is
    # overwritten with the exact per-component variances.
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    # sqrt((n + alpha_0) / (n * (1 + alpha_0))); squared and combined with the
    # mean in `_variance`/`_covariance` this yields the documented formulas.
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    # Also validates that `K` fits in the dtype/index limits (see Pitfalls).
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)