# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The DirichletMultinomial distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "DirichletMultinomial",
]

_dirichlet_multinomial_sample_note = """For each batch of counts,
`value = [n_0, ..., n_{K-1}]`, `P[value]` is the probability that after
sampling `self.total_count` draws from this Dirichlet-Multinomial distribution,
the number of draws falling in class `j` is `n_j`. Because this definition is
[exchangeable](https://en.wikipedia.org/wiki/Exchangeable_random_variables),
different sequences of draws yield the same counts, so the probability includes
a combinatorial coefficient.

Note: `value` must be a non-negative tensor with dtype `self.dtype`, have no
fractional components, and satisfy
`tf.reduce_sum(value, -1) = self.total_count`. Its shape must be broadcastable
with `self.concentration` and `self.total_count`. For example, with
`total_count = 4.` and `K = 3` classes, `value = [1., 0., 3.]` is valid while
`[1., 0., 2.]` is not (it sums to 3, not 4)."""


@tf_export("distributions.DirichletMultinomial")
class DirichletMultinomial(distribution.Distribution):
  """Dirichlet-Multinomial compound distribution.

  The Dirichlet-Multinomial distribution is parameterized by a (batch of)
  length-`K` `concentration` vector (`K > 1`) and a `total_count` number of
  trials, i.e., the number of trials per draw from the DirichletMultinomial. It
  is defined over a (batch of) length-`K` vector `counts` such that
  `tf.reduce_sum(counts, -1) = total_count`. The Dirichlet-Multinomial is
  identically the Beta-Binomial distribution when `K = 2`.

  #### Mathematical Details

  The Dirichlet-Multinomial is a distribution over `K`-class counts, i.e., a
  length-`K` vector of non-negative integer `counts = n = [n_0, ..., n_{K-1}]`.

  The probability mass function (pmf) is,

  ```none
  pmf(n; alpha, N) = Beta(alpha + n) / (prod_j n_j!) / Z
  Z = Beta(alpha) / N!
  ```

  where:

  * `concentration = alpha = [alpha_0, ..., alpha_{K-1}]`, `alpha_j > 0`,
  * `total_count = N`, `N` a positive integer,
  * `N!` is `N` factorial, and,
  * `Beta(x) = prod_j Gamma(x_j) / Gamma(sum_j x_j)` is the
    [multivariate beta function](
    https://en.wikipedia.org/wiki/Beta_function#Multivariate_beta_function),
    and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).
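
  As a sanity check, the log-pmf can be computed directly from this formula (a
  minimal sketch, assuming `alpha`, `n`, and `N` are float tensors as defined
  above; `tf.lbeta` computes the log multivariate beta function over the last
  axis):

  ```python
  log_pmf = (tf.lbeta(alpha + n) - tf.lbeta(alpha)    # log Beta(alpha+n)/Beta(alpha)
             + tf.lgamma(N + 1.)                      # log N!
             - tf.reduce_sum(tf.lgamma(n + 1.), -1))  # -log prod_j n_j!
  ```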

  Dirichlet-Multinomial is a [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., its
  samples are generated as follows.

    1. Choose class probabilities:
       `probs = [p_0,...,p_{K-1}] ~ Dir(concentration)`
    2. Draw integers:
       `counts = [n_0,...,n_{K-1}] ~ Multinomial(total_count, probs)`
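
  This two-step process can be sketched with the stock `Dirichlet` and
  `Multinomial` distributions (an illustrative sketch of the generative
  process, not a description of how `sample` is implemented internally):

  ```python
  probs = tf.distributions.Dirichlet(concentration).sample()  # Step 1.
  counts = tf.distributions.Multinomial(
      total_count, probs=probs).sample()                      # Step 2.
  ```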

  The last `concentration` dimension parametrizes a single Dirichlet-Multinomial
  distribution. When calling distribution functions (e.g., `dist.prob(counts)`),
  `concentration`, `total_count` and `counts` are broadcast to the same shape.
  The last dimension of `counts` corresponds to a single Dirichlet-Multinomial
  distribution.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Pitfalls

  The number of classes, `K`, must not exceed:
  - the largest integer representable by `self.dtype`, i.e.,
    `2**(mantissa_bits+1)` (IEEE 754),
  - the maximum `Tensor` index, i.e., `2**31-1`.

  In other words,

  ```python
  K <= min(2**31-1, {
    tf.float16: 2**11,
    tf.float32: 2**24,
    tf.float64: 2**53 }[param.dtype])
  ```

  Note: This condition is validated only when `self.validate_args = True`.
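
  For instance, to enable this validation on the exported class (a brief
  illustration):

  ```python
  dist = tf.distributions.DirichletMultinomial(
      total_count=4., concentration=[1., 2., 3.], validate_args=True)
  ```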

  #### Examples

  ```python
  alpha = [1., 2., 3.]
  n = 2.
  dist = DirichletMultinomial(n, alpha)
  ```

  Creates a 3-class distribution, with the 3rd class the most likely to be
  drawn. The distribution functions can be evaluated on counts.

  ```python
  # counts same shape as alpha.
  counts = [0., 0., 2.]
  dist.prob(counts)  # Shape []

  # alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
  counts = [[1., 1., 0.], [1., 0., 1.]]
  dist.prob(counts)  # Shape [2]

  # alpha will be broadcast to shape [5, 7, 3] to match counts.
  counts = [[...]]  # Shape [5, 7, 3]
  dist.prob(counts)  # Shape [5, 7]
  ```

  Creates a 2-batch of 3-class distributions.

  ```python
  alpha = [[1., 2., 3.], [4., 5., 6.]]  # Shape [2, 3]
  n = [3., 3.]
  dist = DirichletMultinomial(n, alpha)

  # counts will be broadcast to [[2., 1., 0.], [2., 1., 0.]] to match alpha.
  counts = [2., 1., 0.]
  dist.prob(counts)  # Shape [2]
  ```
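
  Sampling and moments follow the same batch semantics (a brief illustration,
  continuing the 2-batch example above):

  ```python
  dist.sample(5)  # Shape [5, 2, 3]: 5 draws from each of the 2 distributions.
  dist.mean()     # Shape [2, 3]: total_count * alpha / sum_j alpha_j.
  ```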

  """

  # TODO(b/27419586) Change docstring for dtype of concentration once int
  # allowed.
  def __init__(self,
               total_count,
               concentration,
               validate_args=False,
               allow_nan_stats=True,
               name="DirichletMultinomial"):
    """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count: Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the
        same as `total_count`, with shape broadcastable to `[N1,..., Nm, K]`,
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `K`
        class Dirichlet multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[total_count, concentration]):
      # Broadcasting works because:
      # * The broadcasting convention is to prepend dimensions of size [1], and
      #   we use the last dimension for the distribution, whereas
      #   the batch dimensions are the leading dimensions, which forces the
      #   distribution dimension to be defined explicitly (i.e. it cannot be
      #   created automatically by prepending). This forces enough explicitness.
      # * All calls involving `counts` eventually require a broadcast between
      #   `counts` and `concentration`.
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._concentration = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration,
                                name="concentration"),
          validate_args)
      self._total_concentration = math_ops.reduce_sum(self._concentration, -1)
    super(DirichletMultinomial, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._concentration],
        name=name)

  @property
  def total_count(self):
    """Number of trials used to construct a sample."""
    return self._total_count

  @property
  def concentration(self):
    """Concentration parameter; expected prior counts for that coordinate."""
    return self._concentration

  @property
  def total_concentration(self):
    """Sum of last dim of concentration parameter."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    return array_ops.shape(self.concentration)[-1:]

  def _event_shape(self):
    # Event shape depends only on `concentration`, not `total_count`.
    return self.concentration.get_shape().with_rank_at_least(1)[-1:]

  def _sample_n(self, n, seed=None):
    n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
    k = self.event_shape_tensor()[0]
    # Sample Gamma(concentration) and use the log of the unnormalized draws as
    # multinomial logits; normalizing the Gamma draws would yield Dirichlet
    # probabilities, and `multinomial` normalizes its logits internally anyway.
    unnormalized_logits = array_ops.reshape(
        math_ops.log(random_ops.random_gamma(
            shape=[n],
            alpha=self.concentration,
            dtype=self.dtype,
            seed=seed)),
        shape=[-1, k])
    draws = random_ops.multinomial(
        logits=unnormalized_logits,
        num_samples=n_draws,
        seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
    # Sum one-hot class indicators over the `n_draws` dimension to obtain
    # per-class counts, then restore the [n] + batch + [k] shape.
    x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
    final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
    x = array_ops.reshape(x, final_shape)
    return math_ops.cast(x, self.dtype)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _log_prob(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    # log Beta(concentration + counts) - log Beta(concentration) is the
    # log-probability of one particular ordering of the draws;
    # `log_combinations` adds the log multinomial coefficient over orderings.
    ordered_prob = (
        special_math_ops.lbeta(self.concentration + counts)
        - special_math_ops.lbeta(self.concentration))
    return ordered_prob + distribution_util.log_combinations(
        self.total_count, counts)

  @distribution_util.AppendDocstring(_dirichlet_multinomial_sample_note)
  def _prob(self, counts):
    return math_ops.exp(self._log_prob(counts))

  def _mean(self):
    return self.total_count * (self.concentration /
                               self.total_concentration[..., array_ops.newaxis])

  @distribution_util.AppendDocstring(
      """The covariance for each batch member is defined as the following:

      ```none
      Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
      ```

      where `concentration = alpha` and
      `total_concentration = alpha_0 = sum_j alpha_j`.

      The covariance between count components `i != j` is defined as:

      ```none
      Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
      ```
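
      As an illustrative check (a sketch assuming hypothetical float tensors
      `alpha` with shape `[..., k]` and scalar `n`; not part of this API), the
      variance formula can be evaluated directly:

      ```python
      alpha0 = tf.reduce_sum(alpha, -1, keep_dims=True)
      p = alpha / alpha0
      var = n * p * (1. - p) * (n + alpha0) / (1. + alpha0)
      ```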
      """)
  def _covariance(self):
    x = self._variance_scale_term() * self._mean()
    return array_ops.matrix_set_diag(
        -math_ops.matmul(x[..., array_ops.newaxis],
                         x[..., array_ops.newaxis, :]),  # outer prod
        self._variance())

  def _variance(self):
    scale = self._variance_scale_term()
    x = scale * self._mean()
    return x * (self.total_count * scale - x)

  def _variance_scale_term(self):
    """Helper to `_covariance` and `_variance` which computes a shared scale."""
    # We must take care to expand back the last dim whenever we use the
    # total_concentration.
    c0 = self.total_concentration[..., array_ops.newaxis]
    return math_ops.sqrt((1. + c0 / self.total_count) / (1. + c0))

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
      return concentration
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, counts):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    return control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            self.total_count, math_ops.reduce_sum(counts, -1),
            message="counts last-dimension must sum to `self.total_count`"),
    ], counts)