Home | History | Annotate | Download | only in distributions
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Multinomial distribution class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import dtypes
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import check_ops
     25 from tensorflow.python.ops import control_flow_ops
     26 from tensorflow.python.ops import map_fn
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import nn_ops
     29 from tensorflow.python.ops import random_ops
     30 from tensorflow.python.ops.distributions import distribution
     31 from tensorflow.python.ops.distributions import util as distribution_util
     32 from tensorflow.python.util import deprecation
     33 from tensorflow.python.util.tf_export import tf_export
     34 
     35 
# Public API of this module.
__all__ = [
    "Multinomial",
]


# Docstring fragment appended to the `prob`/`log_prob` methods via
# `distribution_util.AppendDocstring` (see `_log_prob` below). This is a
# runtime string that becomes user-visible documentation, not a comment,
# so its exact text is part of the public docs.
_multinomial_sample_note = """For each batch of counts, `value = [n_0, ...
,n_{k-1}]`, `P[value]` is the probability that after sampling `self.total_count`
draws from this Multinomial distribution, the number of draws falling in class
`j` is `n_j`. Since this definition is [exchangeable](
https://en.wikipedia.org/wiki/Exchangeable_random_variables); different
sequences have the same counts so the probability includes a combinatorial
coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and such that
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.probs` and `self.total_count`."""
     53 
     54 
@tf_export(v1=["distributions.Multinomial"])
class Multinomial(distribution.Distribution):
  """Multinomial distribution.

  This Multinomial distribution is parameterized by `probs`, a (batch of)
  length-`K` `prob` (probability) vectors (`K > 1`) such that
  `tf.reduce_sum(probs, -1) = 1`, and a `total_count` number of trials, i.e.,
  the number of trials per draw from the Multinomial. It is defined over a
  (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Multinomial is identically the
  Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Multinomial is a distribution over `K`-class counts, i.e., a length-`K`
  vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; pi, N) = prod_j (pi_j)**n_j / Z
  Z = (prod_j n_j!) / N!
  ```

  where:
  * `probs = pi = [pi_0, ..., pi_{K-1}]`, `pi_j > 0`, `sum_j pi_j = 1`,
  * `total_count = N`, `N` a positive integer,
  * `Z` is the normalization constant, and,
  * `N!` denotes `N` factorial.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.

  #### Examples

  Create a 3-class distribution, with the 3rd class most likely to be drawn,
  using logits.

  ```python
  logits = [-50., -43, 0]
  dist = Multinomial(total_count=4., logits=logits)
  ```

  Create a 3-class distribution, with the 3rd class most likely to be drawn.

  ```python
  p = [.2, .3, .5]
  dist = Multinomial(total_count=4., probs=p)
  ```

  The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as p.
  counts = [1., 0, 3]
  dist.prob(counts)  # Shape []

  # p will be broadcast to [[.2, .3, .5], [.2, .3, .5]] to match counts.
  counts = [[1., 2, 1], [2, 2, 0]]
  dist.prob(counts)  # Shape [2]

  # p will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Create a 2-batch of 3-class distributions.

  ```python
  p = [[.1, .2, .7], [.3, .3, .4]]  # Shape [2, 3]
  dist = Multinomial(total_count=[4., 5], probs=p)

  counts = [[2., 1, 1], [3, 1, 1]]
  dist.prob(counts)  # Shape [2]

  dist.sample(5) # Shape [5, 2, 3]
  ```
  """

  @deprecation.deprecated(
      "2019-01-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.distributions`.",
      warn_once=True)
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        # Graph-level assertion that total_count is a non-negative,
        # integer-valued float tensor; only attached when requested since it
        # costs runtime checks.
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      # Exactly one of logits/probs was given; derive the other so both
      # parameterizations are always available.
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      # Precompute the mean N * p. Beyond serving `mean()`, this tensor
      # carries the fully-broadcast batch/event shape used by the shape
      # methods below.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def logits(self):
    """Vector of coordinatewise logits."""
    return self._logits

  @property
  def probs(self):
    """Probability of drawing a `1` in that coordinate."""
    return self._probs

  def _batch_shape_tensor(self):
    # All but the last dimension of the (broadcast) mean is batch shape.
    return array_ops.shape(self._mean_val)[:-1]

  def _batch_shape(self):
    # Static counterpart of `_batch_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    # The event is the length-K counts vector: the last dimension.
    return array_ops.shape(self._mean_val)[-1:]

  def _event_shape(self):
    # Static counterpart of `_event_shape_tensor`.
    return self._mean_val.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    """Draws `n` multinomial count vectors per batch member.

    Strategy: broadcast `total_count` and `logits` to a common batch shape,
    flatten the batch, and for each flat batch member draw all
    `n * total_count` categorical indices with one `tf.multinomial` call,
    then bucket them into counts with `one_hot` + `reduce_sum`.
    """
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]

    # broadcast the total_count and logits to same shape
    n_draws = array_ops.ones_like(
        self.logits[..., 0], dtype=n_draws.dtype) * n_draws
    logits = array_ops.ones_like(
        n_draws[..., array_ops.newaxis], dtype=self.logits.dtype) * self.logits

    # flatten the total_count and logits
    flat_logits = array_ops.reshape(logits, [-1, k])  # [B1B2...Bm, k]
    flat_ndraws = n * array_ops.reshape(n_draws, [-1])  # [B1B2...Bm]

    # computes each total_count and logits situation by map_fn
    def _sample_single(args):
      # One flat batch member: draw n*n_draw category indices at once, then
      # reduce each group of n_draw indices to a length-k count vector.
      logits, n_draw = args[0], args[1]  # [K], []
      x = random_ops.multinomial(logits[array_ops.newaxis, ...], n_draw,
                                 seed)  # [1, n*n_draw]
      x = array_ops.reshape(x, shape=[n, -1])  # [n, n_draw]
      x = math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)  # [n, k]
      return x

    x = map_fn.map_fn(
        _sample_single, [flat_logits, flat_ndraws],
        dtype=self.dtype)  # [B1B2...Bm, n, k]

    # reshape the results to proper shape
    x = array_ops.transpose(x, perm=[1, 0, 2])
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)  # [n, B1, B2,..., Bm, k]
    return x

  @distribution_util.AppendDocstring(_multinomial_sample_note)
  def _log_prob(self, counts):
    # log pmf = unnormalized log prob + log of the multinomial coefficient
    # (the normalization is returned negated by `_log_normalization`).
    return self._log_unnormalized_prob(counts) - self._log_normalization(counts)

  def _log_unnormalized_prob(self, counts):
    # sum_j n_j * log(pi_j), with pi from log_softmax(logits) for stability.
    counts = self._maybe_assert_valid_sample(counts)
    return math_ops.reduce_sum(counts * nn_ops.log_softmax(self.logits), -1)

  def _log_normalization(self, counts):
    # Negative log multinomial coefficient: -log(N! / prod_j n_j!).
    counts = self._maybe_assert_valid_sample(counts)
    return -distribution_util.log_combinations(self.total_count, counts)

  def _mean(self):
    # Mean N * p was precomputed in __init__.
    return array_ops.identity(self._mean_val)

  def _covariance(self):
    # Covariance of a multinomial: off-diagonal -N p_i p_j, diagonal
    # N p_j (1 - p_j). Built as the negated outer product of mean and p,
    # with the diagonal overwritten by the variance.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return array_ops.matrix_set_diag(
        -math_ops.matmul(self._mean_val[..., array_ops.newaxis],
                         p[..., array_ops.newaxis, :]),  # outer product
        self._variance())

  def _variance(self):
    # Per-class variance N p_j (1 - p_j), with p broadcast against
    # total_count's batch shape.
    p = self.probs * array_ops.ones_like(
        self.total_count)[..., array_ops.newaxis]
    return self._mean_val - self._mean_val * p

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts must sum to `self.total_count`"),
    ], counts)
    317