# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The same-family Mixture distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util


class MixtureSameFamily(distribution.Distribution):
  """Mixture (same-family) distribution.

  The `MixtureSameFamily` distribution implements a (batch of) mixture
  distribution where all components are from different parameterizations of the
  same distribution type. It is parameterized by a `Categorical` "selecting
  distribution" (over `k` components) and a components distribution, i.e., a
  `Distribution` with a rightmost batch shape (equal to `[k]`) which indexes
  each (batch of) component.

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  ### Create a mixture of two scalar Gaussians:

  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=[0.3, 0.7]),
      components_distribution=tfd.Normal(
          loc=[-1., 1],       # One for each component.
          scale=[0.1, 0.5]))  # And same here.

  gm.mean()
  # ==> 0.4

  gm.variance()
  # ==> 1.018

  # Plot PDF.
  x = np.linspace(-2., 3., int(1e4), dtype=np.float32)
  import matplotlib.pyplot as plt
  plt.plot(x, gm.prob(x).eval());

  ### Create a mixture of two Bivariate Gaussians:

  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=[0.3, 0.7]),
      components_distribution=tfd.MultivariateNormalDiag(
          loc=[[-1., 1],  # component 1
               [1, -1]],  # component 2
          scale_identity_multiplier=[.3, .6]))

  gm.mean()
  # ==> array([ 0.4, -0.4], dtype=float32)

  gm.covariance()
  # ==> array([[ 1.119, -0.84],
  #            [-0.84,  1.119]], dtype=float32)

  # Plot PDF contours.
  def meshgrid(x, y=None):
    if y is None:
      y = x
    [gx, gy] = np.meshgrid(x, y, indexing='ij')
    gx, gy = np.float32(gx), np.float32(gy)
    grid = np.concatenate([gx.ravel()[None, :], gy.ravel()[None, :]], axis=0)
    return grid.T.reshape(x.size, y.size, 2)
  grid = meshgrid(np.linspace(-2, 2, 100, dtype=np.float32))
  plt.contour(grid[..., 0], grid[..., 1], gm.prob(grid).eval());
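
  ### Create a batch of three mixtures of two scalar Gaussians. This is an
  ### illustrative sketch (the parameter values are arbitrary); it shows how
  ### the rightmost parameter dimension indexes the `k = 2` components while
  ### the remaining batch dimension indexes three independent mixtures.

  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=[[0.5, 0.5],
                 [0.2, 0.8],
                 [0.9, 0.1]]),        # shape [3, 2]: batch of 3, k = 2
      components_distribution=tfd.Normal(
          loc=[[-1., 1],
               [0., 2],
               [-2., 0]],             # shape [3, 2]: batch of 3, k = 2
          scale=[[0.1, 0.5],
                 [0.2, 0.2],
                 [1., 1.]]))

  gm.batch_shape
  # ==> [3]

  gm.event_shape
  # ==> []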

  ```

  """

  def __init__(self,
               mixture_distribution,
               components_distribution,
               validate_args=False,
               allow_nan_stats=True,
               name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tf.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tf.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if not `mixture_distribution.dtype.is_integer`.
      ValueError: if `mixture_distribution` does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if the number of `mixture_distribution` categories does not
        equal the rightmost batch dimension of `components_distribution`.
    """
    parameters = locals()
    with ops.name_scope(name):
      self._mixture_distribution = mixture_distribution
      self._components_distribution = components_distribution
      self._runtime_assertions = []

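      # The number of event dims is known statically when the rank of
      # `event_shape_tensor()` is known at graph-construction time; otherwise
      # fall back to a dynamic shape op.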
      s = components_distribution.event_shape_tensor()
      self._event_ndims = (s.shape[0].value
                           if s.shape.with_rank_at_least(1)[0].value is not None
                           else array_ops.shape(s)[0])

      if not mixture_distribution.dtype.is_integer:
        raise ValueError(
            "`mixture_distribution.dtype` ({}) is not over integers".format(
                mixture_distribution.dtype.name))

      if (mixture_distribution.event_shape.ndims is not None
          and mixture_distribution.event_shape.ndims != 0):
        raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            control_flow_ops.assert_has_rank(
                mixture_distribution.event_shape_tensor(), 0,
                message="`mixture_distribution` must have scalar `event_dim`s"),
        ]

      mdbs = mixture_distribution.batch_shape
      cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
      if mdbs.is_fully_defined() and cdbs.is_fully_defined():
        if mdbs.ndims != 0 and mdbs != cdbs:
          raise ValueError(
              "`mixture_distribution.batch_shape` (`{}`) is not "
              "compatible with `components_distribution.batch_shape` "
              "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
      elif validate_args:
        mdbs = mixture_distribution.batch_shape_tensor()
        cdbs = components_distribution.batch_shape_tensor()[:-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                distribution_util.pick_vector(
                    mixture_distribution.is_scalar_batch(), cdbs, mdbs),
                cdbs,
                message=(
                    "`mixture_distribution.batch_shape` is not "
                    "compatible with `components_distribution.batch_shape`"))]

      km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
      kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
      if km is not None and kc is not None and km != kc:
        raise ValueError("`mixture_distribution` components ({}) does not "
                         "equal `components_distribution.batch_shape[-1]` "
                         "({})".format(km, kc))
      elif validate_args:
        km = array_ops.shape(mixture_distribution.logits)[-1]
        kc = components_distribution.batch_shape_tensor()[-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                km, kc,
                message=("`mixture_distribution` components does not equal "
                         "`components_distribution.batch_shape[-1]`")),
        ]
      elif km is None:
        km = array_ops.shape(mixture_distribution.logits)[-1]

      self._num_components = km

      super(MixtureSameFamily, self).__init__(
          dtype=self._components_distribution.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              self._mixture_distribution._graph_parents  # pylint: disable=protected-access
              + self._components_distribution._graph_parents),  # pylint: disable=protected-access
          name=name)

  @property
  def mixture_distribution(self):
    return self._mixture_distribution

  @property
  def components_distribution(self):
    return self._components_distribution

  def _batch_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.components_distribution.batch_shape_tensor()[:-1]

  def _batch_shape(self):
    return self.components_distribution.batch_shape.with_rank_at_least(1)[:-1]

  def _event_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.components_distribution.event_shape_tensor()

  def _event_shape(self):
    return self.components_distribution.event_shape

  def _sample_n(self, n, seed):
    with ops.control_dependencies(self._runtime_assertions):
      x = self.components_distribution.sample(n)             # [n, B, k, E]
      # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
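      # Sampling marginalizes the component choice by brute force: draw `n`
      # samples from every component, draw `n` component indices from the
      # mixture distribution, then use a one-hot mask (broadcast across the
      # event dims) to zero out all but the selected component before summing
      # over the component axis.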
      npdt = x.dtype.as_numpy_dtype
      mask = array_ops.one_hot(
          indices=self.mixture_distribution.sample(n),       # [n, B]
          depth=self._num_components,                        # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))                # [n, B, k]
      mask = distribution_utils.pad_mixture_dimensions(
          mask, self, self.mixture_distribution,
          self._event_shape().ndims)                         # [n, B, k, [1]*e]
      return math_ops.reduce_sum(
          x * mask, axis=-1 - self._event_ndims)             # [n, B, E]

  def _log_prob(self, x):
    with ops.control_dependencies(self._runtime_assertions):
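      # Marginalize out the component indicator:
      #   log p(x) = logsumexp_k(log softmax(logits)[..., k]
      #                          + log p(x | component k))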
      x = self._pad_sample_dims(x)
      log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
      log_mix_prob = nn_ops.log_softmax(
          self.mixture_distribution.logits, axis=-1)         # [B, k]
      return math_ops.reduce_logsumexp(
          log_prob_x + log_mix_prob, axis=-1)                # [S, B]

  def _mean(self):
    with ops.control_dependencies(self._runtime_assertions):
      probs = distribution_utils.pad_mixture_dimensions(
          self.mixture_distribution.probs, self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, [1]*e]
      return math_ops.reduce_sum(
          probs * self.components_distribution.mean(),
          axis=-1 - self._event_ndims)                       # [B, E]

  def _log_cdf(self, x):
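    # The CDF of a mixture is the probability-weighted sum of the component
    # CDFs; computing it with logsumexp keeps the evaluation numerically
    # stable in log-space.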
    x = self._pad_sample_dims(x)
    log_cdf_x = self.components_distribution.log_cdf(x)      # [S, B, k]
    log_mix_prob = nn_ops.log_softmax(
        self.mixture_distribution.logits, axis=-1)           # [B, k]
    return math_ops.reduce_logsumexp(
        log_cdf_x + log_mix_prob, axis=-1)                   # [S, B]

  def _variance(self):
    with ops.control_dependencies(self._runtime_assertions):
      # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
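      # With mixture weights pi_k, component means mu_k, and component
      # variances sigma_k**2 (all elementwise over the event shape):
      #   E[Var(Y|X)] = sum_k pi_k * sigma_k**2
      #   Var(E[Y|X]) = sum_k pi_k * (mu_k - mu)**2,  mu = sum_k pi_k * mu_k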
      probs = distribution_utils.pad_mixture_dimensions(
          self.mixture_distribution.probs, self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, [1]*e]
      mean_cond_var = math_ops.reduce_sum(
          probs * self.components_distribution.variance(),
          axis=-1 - self._event_ndims)                       # [B, E]
      var_cond_mean = math_ops.reduce_sum(
          probs * math_ops.squared_difference(
              self.components_distribution.mean(),
              self._pad_sample_dims(self._mean())),
          axis=-1 - self._event_ndims)                       # [B, E]
      return mean_cond_var + var_cond_mean                   # [B, E]

  def _covariance(self):
    static_event_ndims = self.event_shape.ndims
    if static_event_ndims != 1:
      # Covariance is defined only for vector distributions.
      raise NotImplementedError("covariance is not implemented")

    with ops.control_dependencies(self._runtime_assertions):
      # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
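      # Same decomposition as in `_variance`, but with full matrices:
      # E[Var(Y|X)] mixes the component covariance matrices, while
      # Var(E[Y|X]) = sum_k pi_k * (mu_k - mu)(mu_k - mu)^T captures the
      # spread of the component means around the mixture mean mu.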
      probs = distribution_utils.pad_mixture_dimensions(
          distribution_utils.pad_mixture_dimensions(
              self.mixture_distribution.probs, self, self.mixture_distribution,
              self._event_shape().ndims),
          self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, 1, 1]
      mean_cond_var = math_ops.reduce_sum(
          probs * self.components_distribution.covariance(),
          axis=-3)                                           # [B, e, e]
      var_cond_mean = math_ops.reduce_sum(
          probs * _outer_squared_difference(
              self.components_distribution.mean(),
              self._pad_sample_dims(self._mean())),
          axis=-3)                                           # [B, e, e]
      return mean_cond_var + var_cond_mean                   # [B, e, e]

  def _pad_sample_dims(self, x):
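    """Inserts a size-1 component axis into `x` before its event dims."""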
    with ops.name_scope("pad_sample_dims", values=[x]):
      ndims = x.shape.ndims if x.shape.ndims is not None else array_ops.rank(x)
      shape = array_ops.shape(x)
      d = ndims - self._event_ndims
      x = array_ops.reshape(x, shape=array_ops.concat([
          shape[:d], [1], shape[d:]], axis=0))
      return x


def _outer_squared_difference(x, y):
  """Outer-product analog of tf.squared_difference: (x - y) (x - y)^T."""
  z = x - y
  return z[..., array_ops.newaxis, :] * z[..., array_ops.newaxis]