# Home | History | Annotate | Download | only in distributions
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Beta distribution class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import tensor_shape
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import check_ops
     29 from tensorflow.python.ops import control_flow_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.ops import nn
     32 from tensorflow.python.ops import random_ops
     33 from tensorflow.python.ops.distributions import distribution
     34 from tensorflow.python.ops.distributions import kullback_leibler
     35 from tensorflow.python.ops.distributions import util as distribution_util
     36 from tensorflow.python.util.tf_export import tf_export
     37 
     38 
     39 __all__ = [
     40     "Beta",
     41     "BetaWithSoftplusConcentration",
     42 ]
     43 
     44 
     45 _beta_sample_note = """Note: `x` must have dtype `self.dtype` and be in
     46 `[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
     47 
     48 
@tf_export("distributions.Beta")
class Beta(distribution.Distribution):
  """Beta distribution.

  The Beta distribution is defined over the `(0, 1)` interval using parameters
  `concentration1` (aka "alpha") and `concentration0` (aka "beta").

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
  Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
  ```

  where:

  * `concentration1 = alpha`,
  * `concentration0 = beta`,
  * `Z` is the normalization constant, and,
  * `Gamma` is the [gamma function](
    https://en.wikipedia.org/wiki/Gamma_function).

  The concentration parameters represent mean total counts of a `1` or a `0`,
  i.e.,

  ```none
  concentration1 = alpha = mean * total_concentration
  concentration0 = beta  = (1. - mean) * total_concentration
  ```

  where `mean` in `(0, 1)` and `total_concentration` is a positive real number
  representing a mean `total_count = concentration1 + concentration0`.

  Distribution parameters are automatically broadcast in all functions; see
  examples for details.

  #### Examples

  ```python
  # Create a batch of three Beta distributions.
  alpha = [1, 2, 3]
  beta = [1, 2, 3]
  dist = Beta(alpha, beta)

  dist.sample([4, 5])  # Shape [4, 5, 3]

  # `x` has three batch entries, each with two samples.
  x = [[.1, .4, .5],
       [.2, .3, .5]]
  # Calculate the probability of each pair of samples under the corresponding
  # distribution in `dist`.
  dist.prob(x)         # Shape [2, 3]
  ```

  ```python
  # Create batch_shape=[2, 3] via parameter broadcast:
  alpha = [[1.], [2]]      # Shape [2, 1]
  beta = [3., 4, 5]        # Shape [3]
  dist = Beta(alpha, beta)

  # alpha broadcast as: [[1., 1, 1],
  #                      [2, 2, 2]]
  # beta broadcast as:  [[3., 4, 5],
  #                      [3, 4, 5]]
  # batch_shape [2, 3]
  dist.sample([4, 5])  # Shape [4, 5, 2, 3]

  x = [.2, .3, .5]
  # x will be broadcast as [[.2, .3, .5],
  #                         [.2, .3, .5]],
  # thus matching batch_shape [2, 3].
  dist.prob(x)         # Shape [2, 3]
  ```

  """

  def __init__(self,
               concentration1=None,
               concentration0=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Capture the constructor arguments before any other locals are bound, so
    # the base class can record them as `self.parameters`.
    parameters = locals()
    with ops.name_scope(name, values=[concentration1, concentration0]):
      # Positivity assertions are attached only when `validate_args` is True.
      self._concentration1 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration1, name="concentration1"),
          validate_args)
      self._concentration0 = self._maybe_assert_valid_concentration(
          ops.convert_to_tensor(concentration0, name="concentration0"),
          validate_args)
      # Both parameters must share one floating-point dtype; this raises at
      # graph-construction time otherwise.
      check_ops.assert_same_float_dtype([
          self._concentration1, self._concentration0])
      # alpha + beta; the sum's (broadcast) shape defines the batch shape.
      self._total_concentration = self._concentration1 + self._concentration0
    super(Beta, self).__init__(
        dtype=self._total_concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        # Sampling goes through `random_gamma`, whose gradient does not flow
        # to the parameters here, hence NOT_REPARAMETERIZED.
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration1,
                       self._concentration0,
                       self._total_concentration],
        name=name)

  @staticmethod
  def _param_shapes(sample_shape):
    # Both parameters are batch-shaped (scalar event), so each parameter's
    # shape equals the requested sample/batch shape.
    return dict(zip(
        ["concentration1", "concentration0"],
        [ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))

  @property
  def concentration1(self):
    """Concentration parameter associated with a `1` outcome."""
    return self._concentration1

  @property
  def concentration0(self):
    """Concentration parameter associated with a `0` outcome."""
    return self._concentration0

  @property
  def total_concentration(self):
    """Sum of concentration parameters."""
    return self._total_concentration

  def _batch_shape_tensor(self):
    # `total_concentration` already carries the broadcast of both parameters.
    return array_ops.shape(self.total_concentration)

  def _batch_shape(self):
    return self.total_concentration.get_shape()

  def _event_shape_tensor(self):
    # Scalar event: each draw is a single number in (0, 1).
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Sample via the classic two-Gamma construction:
    #   G1 ~ Gamma(alpha), G2 ~ Gamma(beta)  =>  G1 / (G1 + G2) ~ Beta(alpha,
    #   beta).
    # Multiplying by ones_like(total_concentration) broadcasts each parameter
    # up to the full batch shape before sampling.
    expanded_concentration1 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration1
    expanded_concentration0 = array_ops.ones_like(
        self.total_concentration, dtype=self.dtype) * self.concentration0
    gamma1_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration1,
        dtype=self.dtype,
        seed=seed)
    # Derive a distinct seed for the second draw so the two Gamma samples are
    # independent even when a seed is fixed.
    gamma2_sample = random_ops.random_gamma(
        shape=[n],
        alpha=expanded_concentration0,
        dtype=self.dtype,
        seed=distribution_util.gen_new_seed(seed, "beta"))
    beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
    return beta_sample

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _prob(self, x):
    return math_ops.exp(self._log_prob(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _log_cdf(self, x):
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_beta_sample_note)
  def _cdf(self, x):
    # The Beta CDF is the regularized incomplete beta function I_x(alpha,
    # beta), computed natively by `betainc`.
    return math_ops.betainc(self.concentration1, self.concentration0, x)

  def _log_unnormalized_prob(self, x):
    x = self._maybe_assert_valid_sample(x)
    # (alpha - 1) log(x) + (beta - 1) log(1 - x); log1p(-x) is more accurate
    # than log(1 - x) for x near 0.
    return ((self.concentration1 - 1.) * math_ops.log(x)
            + (self.concentration0 - 1.) * math_ops.log1p(-x))

  def _log_normalization(self):
    # log Z = log Beta(alpha, beta)
    #       = lgamma(alpha) + lgamma(beta) - lgamma(alpha + beta).
    return (math_ops.lgamma(self.concentration1)
            + math_ops.lgamma(self.concentration0)
            - math_ops.lgamma(self.total_concentration))

  def _entropy(self):
    # H = log Beta(alpha, beta) - (alpha - 1) digamma(alpha)
    #     - (beta - 1) digamma(beta) + (alpha + beta - 2) digamma(alpha +
    #     beta).
    return (
        self._log_normalization()
        - (self.concentration1 - 1.) * math_ops.digamma(self.concentration1)
        - (self.concentration0 - 1.) * math_ops.digamma(self.concentration0)
        + ((self.total_concentration - 2.) *
           math_ops.digamma(self.total_concentration)))

  def _mean(self):
    # E[X] = alpha / (alpha + beta).
    return self._concentration1 / self._total_concentration

  def _variance(self):
    # Var[X] = mean (1 - mean) / (alpha + beta + 1).
    return self._mean() * (1. - self._mean()) / (1. + self.total_concentration)

  @distribution_util.AppendDocstring(
      """Note: The mode is undefined when `concentration1 <= 1` or
      `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN`
      is used for undefined modes. If `self.allow_nan_stats` is `False` an
      exception is raised when one or more modes are undefined.""")
  def _mode(self):
    # mode = (alpha - 1) / (alpha + beta - 2), valid only for alpha > 1 and
    # beta > 1.
    mode = (self.concentration1 - 1.) / (self.total_concentration - 2.)
    if self.allow_nan_stats:
      nan = array_ops.fill(
          self.batch_shape_tensor(),
          np.array(np.nan, dtype=self.dtype.as_numpy_dtype()),
          name="nan")
      is_defined = math_ops.logical_and(self.concentration1 > 1.,
                                        self.concentration0 > 1.)
      # Per-batch-member selection: NaN wherever the mode is undefined.
      return array_ops.where(is_defined, mode, nan)
    # allow_nan_stats=False: fail at runtime if any batch member's mode is
    # undefined.
    return control_flow_ops.with_dependencies([
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration1,
            message="Mode undefined for concentration1 <= 1."),
        check_ops.assert_less(
            array_ops.ones([], dtype=self.dtype),
            self.concentration0,
            message="Mode undefined for concentration0 <= 1.")
    ], mode)

  def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of a concentration parameter.

    Args:
      concentration: `Tensor` concentration parameter to (optionally) check.
      validate_args: Python `bool`; when `False` the input is returned as-is.

    Returns:
      `concentration`, gated on a runtime positivity assertion when
      `validate_args` is `True`.
    """
    if not validate_args:
      return concentration
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)

  def _maybe_assert_valid_sample(self, x):
    """Checks the validity of a sample.

    Returns `x` unchanged unless `self.validate_args` is `True`, in which
    case `x` is gated on runtime assertions enforcing `0 < x < 1`.
    """
    if not self.validate_args:
      return x
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(x, message="sample must be positive"),
        check_ops.assert_less(
            x,
            array_ops.ones([], self.dtype),
            message="sample must be less than `1`."),
    ], x)
    314 
class BetaWithSoftplusConcentration(Beta):
  """Beta with softplus transform of `concentration1` and `concentration0`."""

  def __init__(self,
               concentration1,
               concentration0,
               validate_args=False,
               allow_nan_stats=True,
               name="BetaWithSoftplusConcentration"):
    """Initialize a batch of Beta distributions from unconstrained inputs.

    Args:
      concentration1: Unconstrained floating-point `Tensor`; passed through
        `softplus` to produce the positive "alpha" of the underlying `Beta`.
      concentration0: Unconstrained floating-point `Tensor`; passed through
        `softplus` to produce the positive "beta" of the underlying `Beta`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        use "`NaN`" for undefined values instead of raising.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Capture the caller-supplied (pre-softplus) arguments for `parameters`.
    parameters = locals()
    # `ns` is the resolved scope name; passing it as `name` makes the parent
    # constructor reuse this exact scope rather than open a nested one.
    with ops.name_scope(name, values=[concentration1,
                                      concentration0]) as ns:
      super(BetaWithSoftplusConcentration, self).__init__(
          concentration1=nn.softplus(concentration1,
                                     name="softplus_concentration1"),
          concentration0=nn.softplus(concentration0,
                                     name="softplus_concentration0"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=ns)
    # Overwrite the parent-recorded parameters (which hold the softplus'd
    # tensors) with the original constructor arguments.
    self._parameters = parameters
    336 
    337 
    338 @kullback_leibler.RegisterKL(Beta, Beta)
    339 def _kl_beta_beta(d1, d2, name=None):
    340   """Calculate the batchwise KL divergence KL(d1 || d2) with d1 and d2 Beta.
    341 
    342   Args:
    343     d1: instance of a Beta distribution object.
    344     d2: instance of a Beta distribution object.
    345     name: (optional) Name to use for created operations.
    346       default is "kl_beta_beta".
    347 
    348   Returns:
    349     Batchwise KL(d1 || d2)
    350   """
    351   def delta(fn, is_property=True):
    352     fn1 = getattr(d1, fn)
    353     fn2 = getattr(d2, fn)
    354     return (fn2 - fn1) if is_property else (fn2() - fn1())
    355   with ops.name_scope(name, "kl_beta_beta", values=[
    356       d1.concentration1,
    357       d1.concentration0,
    358       d1.total_concentration,
    359       d2.concentration1,
    360       d2.concentration0,
    361       d2.total_concentration,
    362   ]):
    363     return (delta("_log_normalization", is_property=False)
    364             - math_ops.digamma(d1.concentration1) * delta("concentration1")
    365             - math_ops.digamma(d1.concentration0) * delta("concentration0")
    366             + (math_ops.digamma(d1.total_concentration)
    367                * delta("total_concentration")))
    368