# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Independent distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.util import deprecation


class Independent(distribution_lib.Distribution):
  """Independent distribution from batch of distributions.

  This distribution is useful for regarding a collection of independent,
  non-identical distributions as a single random variable. For example, the
  `Independent` distribution composed of a collection of `Bernoulli`
  distributions might define a distribution over an image (where each
  `Bernoulli` is a distribution over one pixel).

  More precisely, a collection of `B` (independent) `E`-variate random
  variables (rvs) `{X_1, ..., X_B}` can be regarded as a `[B, E]`-variate
  random variable `(X_1, ..., X_B)` with probability
  `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)`, where `p_b(x_b)` is the
  probability of the `b`-th rv. More generally, `B, E` can be arbitrary shapes.

  Similarly, the `Independent` distribution specifies a distribution over
  `[B, E]`-shaped events. It operates by reinterpreting the rightmost batch
  dims as part of the event dims. The `reinterpreted_batch_ndims` parameter
  controls the number of batch dims that are absorbed as event dims; it must
  satisfy `reinterpreted_batch_ndims <= len(batch_shape)`. For example, the
  `log_prob` function entails a `reduce_sum` over the rightmost
  `reinterpreted_batch_ndims` dims after calling the base distribution's
  `log_prob`. In other words, since the batch dimension(s) index independent
  distributions, the resultant multivariate will have independent components.

  #### Mathematical Details

  The probability function is,

  ```none
  prob(x; reinterpreted_batch_ndims) = tf.reduce_prod(
      dist.prob(x),
      axis=-1-range(reinterpreted_batch_ndims))
  ```
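
  For example, `reinterpreted_batch_ndims=2` yields `axis=[-1, -2]`; the
  product then runs over the two rightmost batch dims.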

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Make independent distribution from a 2-batch Normal.
  ind = tfd.Independent(
      distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
      reinterpreted_batch_ndims=1)

  # All batch dims have been "absorbed" into event dims.
  ind.batch_shape  # ==> []
  ind.event_shape  # ==> [2]

  # Make independent distribution from a 2-batch bivariate Normal.
  ind = tfd.Independent(
      distribution=tfd.MultivariateNormalDiag(
          loc=[[-1., 1], [1, -1]],
          scale_identity_multiplier=[1., 0.5]),
      reinterpreted_batch_ndims=1)

  # All batch dims have been "absorbed" into event dims.
  ind.batch_shape  # ==> []
  ind.event_shape  # ==> [2, 2]
  ```
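
  The `log_prob` of an `Independent` is the base distribution's `log_prob`
  summed over the reinterpreted dims. A sketch using `ind` and `tfd` from the
  first example above, and assuming the usual `import tensorflow as tf`:

  ```python
  x = ind.sample()  # shape: [2]
  # Both expressions below evaluate to the same scalar:
  ind.log_prob(x)
  tf.reduce_sum(ind.distribution.log_prob(x), axis=-1)
  ```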

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(
      self, distribution, reinterpreted_batch_ndims=None,
      validate_args=False, name=None):
    110     """Construct a `Independent` distribution.
    111 
    112     Args:
    113       distribution: The base distribution instance to transform. Typically an
    114         instance of `Distribution`.
    115       reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
    116         which will be regarded as event dims. When `None` all but the first
    117         batch axis (batch axis 0) will be transferred to event dimensions
    118         (analogous to `tf.layers.flatten`).
    119       validate_args: Python `bool`.  Whether to validate input with asserts.
    120         If `validate_args` is `False`, and the inputs are invalid,
    121         correct behavior is not guaranteed.
    122       name: The name for ops managed by the distribution.
    123         Default value: `Independent + distribution.name`.
    124 
    125     Raises:
    126       ValueError: if `reinterpreted_batch_ndims` exceeds
    127         `distribution.batch_ndims`
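
    For example (a sketch of the default behavior, with `tfd` and `tf` as in
    the class docstring examples):

    ```python
    # Base distribution has batch_shape [3, 2]; the default
    # reinterpreted_batch_ndims=None absorbs all but the leftmost batch dim.
    ind = tfd.Independent(tfd.Normal(loc=tf.zeros([3, 2]), scale=1.))
    ind.batch_shape  # ==> [3]
    ind.event_shape  # ==> [2]
    ```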
    """
    parameters = dict(locals())
    name = name or "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name) as name:
      if reinterpreted_batch_ndims is None:
        reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
            distribution)
      reinterpreted_batch_ndims = ops.convert_to_tensor(
          reinterpreted_batch_ndims,
          dtype=dtypes.int32,
          name="reinterpreted_batch_ndims")
      self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
      # Prefer the statically known value, when available, so downstream
      # shape logic can also remain static.
      self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
          reinterpreted_batch_ndims)
      if self._static_reinterpreted_batch_ndims is not None:
        self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
      super(Independent, self).__init__(
          dtype=self._distribution.dtype,
          reparameterization_type=self._distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=self._distribution.allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              [reinterpreted_batch_ndims] +
              distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
      self._runtime_assertions = self._make_runtime_assertions(
          distribution, reinterpreted_batch_ndims, validate_args)

  @property
  def distribution(self):
    return self._distribution

  @property
  def reinterpreted_batch_ndims(self):
    return self._reinterpreted_batch_ndims

  def _batch_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      batch_shape = self.distribution.batch_shape_tensor()
      # Use the statically known number of batch dims when available;
      # otherwise fall back to a dynamic shape op.
      dim0 = tensor_shape.dimension_value(
          batch_shape.shape.with_rank_at_least(1)[0])
      batch_ndims = (dim0
                     if dim0 is not None
                     else array_ops.shape(batch_shape)[0])
      return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims]

  def _batch_shape(self):
    batch_shape = self.distribution.batch_shape
    if (self._static_reinterpreted_batch_ndims is None
        or batch_shape.ndims is None):
      return tensor_shape.TensorShape(None)
    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
    return batch_shape[:d]

  def _event_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      batch_shape = self.distribution.batch_shape_tensor()
      dim0 = tensor_shape.dimension_value(
          batch_shape.shape.with_rank_at_least(1)[0])
      batch_ndims = (dim0
                     if dim0 is not None
                     else array_ops.shape(batch_shape)[0])
      return array_ops.concat([
          batch_shape[batch_ndims - self.reinterpreted_batch_ndims:],
          self.distribution.event_shape_tensor(),
      ], axis=0)

  def _event_shape(self):
    batch_shape = self.distribution.batch_shape
    if (self._static_reinterpreted_batch_ndims is None
        or batch_shape.ndims is None):
      return tensor_shape.TensorShape(None)
    d = batch_shape.ndims - self._static_reinterpreted_batch_ndims
    return batch_shape[d:].concatenate(self.distribution.event_shape)

  def _sample_n(self, n, seed):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.sample(sample_shape=n, seed=seed)

  def _log_prob(self, x):
    with ops.control_dependencies(self._runtime_assertions):
      return self._reduce_sum(self.distribution.log_prob(x))

  def _entropy(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self._reduce_sum(self.distribution.entropy())

  def _mean(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.mean()

  def _variance(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.variance()

  def _stddev(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.stddev()

  def _mode(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.distribution.mode()

  def _make_runtime_assertions(
      self, distribution, reinterpreted_batch_ndims, validate_args):
    assertions = []
    static_reinterpreted_batch_ndims = tensor_util.constant_value(
        reinterpreted_batch_ndims)
    batch_ndims = distribution.batch_shape.ndims
    if batch_ndims is not None and static_reinterpreted_batch_ndims is not None:
      if static_reinterpreted_batch_ndims > batch_ndims:
        raise ValueError("reinterpreted_batch_ndims({}) cannot exceed "
                         "distribution.batch_ndims({})".format(
                             static_reinterpreted_batch_ndims, batch_ndims))
    elif validate_args:
      batch_shape = distribution.batch_shape_tensor()
      dim0 = tensor_shape.dimension_value(
          batch_shape.shape.with_rank_at_least(1)[0])
      batch_ndims = (
          dim0
          if dim0 is not None
          else array_ops.shape(batch_shape)[0])
      assertions.append(check_ops.assert_less_equal(
          reinterpreted_batch_ndims, batch_ndims,
          message=("reinterpreted_batch_ndims cannot exceed "
                   "distribution.batch_ndims")))
    return assertions

  def _reduce_sum(self, stat):
    # Sum over the reinterpreted dims, i.e.
    # axis = [-1, -2, ..., -reinterpreted_batch_ndims].
    if self._static_reinterpreted_batch_ndims is None:
      range_ = math_ops.range(self._reinterpreted_batch_ndims)
    else:
      range_ = np.arange(self._static_reinterpreted_batch_ndims)
    return math_ops.reduce_sum(stat, axis=-1-range_)

  def _get_default_reinterpreted_batch_ndims(self, distribution):
    """Computes the default for the reinterpreted_batch_ndims __init__ arg."""
    ndims = distribution.batch_shape.ndims
    if ndims is None:
      which_maximum = math_ops.maximum
      ndims = array_ops.shape(distribution.batch_shape_tensor())[0]
    else:
      which_maximum = np.maximum
    return which_maximum(0, ndims - 1)


@kullback_leibler.RegisterKL(Independent, Independent)
@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def _kl_independent(a, b, name="kl_independent"):
    286   """Batched KL divergence `KL(a || b)` for Independent distributions.
    287 
    288   We can leverage the fact that
    289   ```
    290   KL(Independent(a) || Independent(b)) = sum(KL(a || b))
    291   ```
    292   where the sum is over the `reinterpreted_batch_ndims`.
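
  For instance (a sketch; `tfd` abbreviates `tfp.distributions`, as in the
  `Independent` class docstring):

  ```python
  a = tfd.Independent(tfd.Normal(loc=[0., 0.], scale=1.), 1)
  b = tfd.Independent(tfd.Normal(loc=[1., 1.], scale=1.), 1)
  # The result is a scalar: the sum of the two component Normal KLs.
  kl = tfd.kl_divergence(a, b)
  ```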

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event shapes of `a` and `b`, or of their underlying
      distributions, don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non-)batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and then `reduce_sum` over the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return math_ops.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with ops.control_dependencies([
        check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
        check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
    ]):
      # Number of reinterpreted dims: rank(a.event_shape) - rank(p.event_shape).
      num_reduce_dims = (
          array_ops.shape(a.event_shape_tensor())[0] -
          array_ops.shape(p.event_shape_tensor())[0])
      # Sum over axes [-num_reduce_dims, ..., -1], matching the static branch.
      reduce_dims = math_ops.range(-num_reduce_dims, 0, 1)
      return math_ops.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
    337