# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Multinomial distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Multinomial",
]


# Appended to the `prob`/`log_prob` docstrings via
# `distribution_util.AppendDocstring` below.
_multinomial_sample_note = """For each batch of counts, `value = [n_0, ...
,n_{k-1}]`, `P[value]` is the probability that after sampling `self.total_count`
draws from this Multinomial distribution, the number of draws falling in class
`j` is `n_j`. Since this definition is [exchangeable](
https://en.wikipedia.org/wiki/Exchangeable_random_variables); different
sequences have the same counts so the probability includes a combinatorial
coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""


@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identically the
  Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a positive integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
      tf.float16: 2**11,
      tf.float32: 2**24,
      tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution, with the 3rd class most likely to be drawn,
  using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class most likely to be drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5)  # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        # Assert that `total_count` is non-negative and integer-valued
        # (its dtype is still floating point; only the values are checked).
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      # Exactly one of `logits`/`probs` was given; derive the other from it.
      # `multidimensional=True` means normalization is over the last axis.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      # Precompute the mean, `total_count * probs`; it has the fully broadcast
      # batch + event shape `[N1,..., Nm, K]` and is reused for shape queries,
      # `mean`, `variance` and `covariance` below.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Probability of drawing a `1` in that coordinate."""
    return self._probs

  def _batch_shape_tensor(self):
    # Batch shape is everything but the trailing class dimension of the
    # (fully broadcast) mean.
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    # Event shape is `[K]`, the trailing dimension of the mean.
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    # Static counterpart of `_event_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    # `random_ops.multinomial` needs an integer draw count.
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # broadcast the total_count and logits to same shape
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # flatten the total_count and logits
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    flat_ndraws = n * array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    # computes each total_count and logits situation by map_fn
    def _sample_single(args):
      # Draw all `n * n_draw` categorical samples for one batch member in a
      # single `multinomial` call, then histogram them into class counts via
      # one-hot + sum.
      logits, n_draw = args[0], args[1]  # [K], []
      x = random_ops.multinomial(logits[array_ops.newaxis, ...], n_draw,
                                 seed)  # [1, n*n_draw]
      x = array_ops.reshape(x, shape=[n, -1])  # [n, n_draw]
      x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)  # [n, k]
      return x

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # reshape the results to proper shape
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    # sum_j n_j * log(pi_j), computed from logits via log_softmax for
    # numerical stability.
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    # log(Z) = log((prod_j n_j!) / N!) = -log(N! / prod_j n_j!).
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    # Broadcast probs against total_count so `p` carries the full batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # Off-diagonal entries are -N * p_i * p_j (negated outer product of the
    # mean with p); the diagonal is overwritten with the variance.
    return array_ops.matrix_set_diag(
        -math_ops.matmul(self._mean_val[..., array_ops.newaxis],
                         p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    # Broadcast probs against total_count so `p` carries the full batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    # N * p - N * p * p = N * p * (1 - p), elementwise.
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)