Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Multivariate Normal distribution classes."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.distributions.python.ops import distribution_util
     22 from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.ops import nn
     25 
     26 
# Public names of this module: the only names exported by a wildcard
# (`from ... import *`) import.
__all__ = [
    "MultivariateNormalDiag",
    "MultivariateNormalDiagWithSoftplusScale",
]
     31 
     32 
     33 class MultivariateNormalDiag(
     34     mvn_linop.MultivariateNormalLinearOperator):
     35   """The multivariate normal distribution on `R^k`.
     36 
     37   The Multivariate Normal distribution is defined over `R^k` and parameterized
     38   by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
     39   `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
     40   matrix-multiplication.
     41 
     42   #### Mathematical Details
     43 
     44   The probability density function (pdf) is,
     45 
     46   ```none
     47   pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
     48   y = inv(scale) @ (x - loc),
     49   Z = (2 pi)**(0.5 k) |det(scale)|,
     50   ```
     51 
     52   where:
     53 
     54   * `loc` is a vector in `R^k`,
     55   * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
     56   * `Z` denotes the normalization constant, and,
     57   * `||y||**2` denotes the squared Euclidean norm of `y`.
     58 
     59   A (non-batch) `scale` matrix is:
     60 
     61   ```none
     62   scale = diag(scale_diag + scale_identity_multiplier * ones(k))
     63   ```
     64 
     65   where:
     66 
     67   * `scale_diag.shape = [k]`, and,
     68   * `scale_identity_multiplier.shape = []`.
     69 
     70   Additional leading dimensions (if any) will index batches.
     71 
     72   If both `scale_diag` and `scale_identity_multiplier` are `None`, then
     73   `scale` is the Identity matrix.
     74 
     75   The MultivariateNormal distribution is a member of the [location-scale
     76   family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
     77   constructed as,
     78 
     79   ```none
     80   X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
     81   Y = scale @ X + loc
     82   ```
     83 
     84   #### Examples
     85 
     86   ```python
     87   tfd = tf.contrib.distributions
     88 
     89   # Initialize a single 2-variate Gaussian.
     90   mvn = tfd.MultivariateNormalDiag(
     91       loc=[1., -1],
     92       scale_diag=[1, 2.])
     93 
     94   mvn.mean().eval()
     95   # ==> [1., -1]
     96 
     97   mvn.stddev().eval()
     98   # ==> [1., 2]
     99 
    100   # Evaluate this on an observation in `R^2`, returning a scalar.
    101   mvn.prob([-1., 0]).eval()  # shape: []
    102 
    103   # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
    104   mvn = tfd.MultivariateNormalDiag(
    105       loc=[1., -1],
    106       scale_identity_multiplier=[1, 2., 3])
    107 
    108   mvn.mean().eval()  # shape: [3, 2]
    109   # ==> [[1., -1]
    110   #      [1, -1],
    111   #      [1, -1]]
    112 
    113   mvn.stddev().eval()  # shape: [3, 2]
    114   # ==> [[1., 1],
    115   #      [2, 2],
    116   #      [3, 3]]
    117 
    118   # Evaluate this on an observation in `R^2`, returning a length-3 vector.
    119   mvn.prob([-1., 0]).eval()  # shape: [3]
    120 
    121   # Initialize a 2-batch of 3-variate Gaussians.
    122   mvn = tfd.MultivariateNormalDiag(
    123       loc=[[1., 2, 3],
    124            [11, 22, 33]]           # shape: [2, 3]
    125       scale_diag=[[1., 2, 3],
    126                   [0.5, 1, 1.5]])  # shape: [2, 3]
    127 
    128   # Evaluate this on a two observations, each in `R^3`, returning a length-2
    129   # vector.
    130   x = [[-1., 0, 1],
    131        [-11, 0, 11.]]   # shape: [2, 3].
    132   mvn.prob(x).eval()    # shape: [2]
    133   ```
    134 
    135   """
    136 
    137   def __init__(self,
    138                loc=None,
    139                scale_diag=None,
    140                scale_identity_multiplier=None,
    141                validate_args=False,
    142                allow_nan_stats=True,
    143                name="MultivariateNormalDiag"):
    144     """Construct Multivariate Normal distribution on `R^k`.
    145 
    146     The `batch_shape` is the broadcast shape between `loc` and `scale`
    147     arguments.
    148 
    149     The `event_shape` is given by last dimension of the matrix implied by
    150     `scale`. The last dimension of `loc` (if provided) must broadcast with this.
    151 
    152     Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
    153 
    154     ```none
    155     scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    156     ```
    157 
    158     where:
    159 
    160     * `scale_diag.shape = [k]`, and,
    161     * `scale_identity_multiplier.shape = []`.
    162 
    163     Additional leading dimensions (if any) will index batches.
    164 
    165     If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    166     `scale` is the Identity matrix.
    167 
    168     Args:
    169       loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
    170         implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
    171         `b >= 0` and `k` is the event size.
    172       scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
    173         matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
    174         and characterizes `b`-batches of `k x k` diagonal matrices added to
    175         `scale`. When both `scale_identity_multiplier` and `scale_diag` are
    176         `None` then `scale` is the `Identity`.
    177       scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
    178         a scaled-identity-matrix added to `scale`. May have shape
    179         `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
    180         `k x k` identity matrices added to `scale`. When both
    181         `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
    182         the `Identity`.
    183       validate_args: Python `bool`, default `False`. When `True` distribution
    184         parameters are checked for validity despite possibly degrading runtime
    185         performance. When `False` invalid inputs may silently render incorrect
    186         outputs.
    187       allow_nan_stats: Python `bool`, default `True`. When `True`,
    188         statistics (e.g., mean, mode, variance) use the value "`NaN`" to
    189         indicate the result is undefined. When `False`, an exception is raised
    190         if one or more of the statistic's batch members are undefined.
    191       name: Python `str` name prefixed to Ops created by this class.
    192 
    193     Raises:
    194       ValueError: if at most `scale_identity_multiplier` is specified.
    195     """
    196     parameters = locals()
    197     with ops.name_scope(name):
    198       with ops.name_scope("init", values=[
    199           loc, scale_diag, scale_identity_multiplier]):
    200         # No need to validate_args while making diag_scale.  The returned
    201         # LinearOperatorDiag has an assert_non_singular method that is called by
    202         # the Bijector.
    203         scale = distribution_util.make_diag_scale(
    204             loc=loc,
    205             scale_diag=scale_diag,
    206             scale_identity_multiplier=scale_identity_multiplier,
    207             validate_args=False,
    208             assert_positive=False)
    209     super(MultivariateNormalDiag, self).__init__(
    210         loc=loc,
    211         scale=scale,
    212         validate_args=validate_args,
    213         allow_nan_stats=allow_nan_stats,
    214         name=name)
    215     self._parameters = parameters
    216 
    217 
    218 class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
    219   """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`."""
    220 
    221   def __init__(self,
    222                loc,
    223                scale_diag,
    224                validate_args=False,
    225                allow_nan_stats=True,
    226                name="MultivariateNormalDiagWithSoftplusScale"):
    227     parameters = locals()
    228     with ops.name_scope(name, values=[scale_diag]):
    229       super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
    230           loc=loc,
    231           scale_diag=nn.softplus(scale_diag),
    232           validate_args=validate_args,
    233           allow_nan_stats=allow_nan_stats,
    234           name=name)
    235     self._parameters = parameters
    236