Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Multivariate Normal distribution class initialized with a full covariance."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.distributions.python.ops import mvn_tril
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.ops import array_ops
     24 from tensorflow.python.ops import check_ops
     25 from tensorflow.python.ops import control_flow_ops
     26 from tensorflow.python.ops import linalg_ops
     27 
     28 
     29 __all__ = [
     30     "MultivariateNormalFullCovariance",
     31 ]
     32 
     33 
     34 class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
     35   """The multivariate normal distribution on `R^k`.
     36 
     37   The Multivariate Normal distribution is defined over `R^k` and parameterized
     38   by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
     39   `covariance_matrix` matrices that are the covariance.
     40   This is different than the other multivariate normals, which are parameterized
     41   by a matrix more akin to the standard deviation.
     42 
     43   #### Mathematical Details
     44 
     45   The probability density function (pdf) is, with `@` as matrix multiplication,
     46 
     47   ```none
     48   pdf(x; loc, covariance_matrix) = exp(-0.5 ||y||**2) / Z,
     49   y = (x - loc)^T @ inv(covariance_matrix) @ (x - loc)
     50   Z = (2 pi)**(0.5 k) |det(covariance_matrix)|**(0.5).
     51   ```
     52 
     53   where:
     54 
     55   * `loc` is a vector in `R^k`,
     56   * `covariance_matrix` is an `R^{k x k}` symmetric positive definite matrix,
     57   * `Z` denotes the normalization constant, and,
     58   * `||y||**2` denotes the squared Euclidean norm of `y`.
     59 
     60   Additional leading dimensions (if any) in `loc` and `covariance_matrix` allow
     61   for batch dimensions.
     62 
     63   The MultivariateNormal distribution is a member of the [location-scale
     64   family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
     65   constructed e.g. as,
     66 
     67   ```none
     68   X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
     69   scale = Cholesky(covariance_matrix)
     70   Y = scale @ X + loc
     71   ```
     72 
     73   #### Examples
     74 
     75   ```python
     76   tfd = tf.contrib.distributions
     77 
     78   # Initialize a single 3-variate Gaussian.
     79   mu = [1., 2, 3]
     80   cov = [[ 0.36,  0.12,  0.06],
     81          [ 0.12,  0.29, -0.13],
     82          [ 0.06, -0.13,  0.26]]
     83   mvn = tfd.MultivariateNormalFullCovariance(
     84       loc=mu,
     85       covariance_matrix=cov)
     86 
     87   mvn.mean().eval()
     88   # ==> [1., 2, 3]
     89 
     90   # Covariance agrees with covariance_matrix.
     91   mvn.covariance().eval()
     92   # ==> [[ 0.36,  0.12,  0.06],
     93   #      [ 0.12,  0.29, -0.13],
     94   #      [ 0.06, -0.13,  0.26]]
     95 
     96   # Compute the pdf of an observation in `R^3` ; return a scalar.
     97   mvn.prob([-1., 0, 1]).eval()  # shape: []
     98 
     99   # Initialize a 2-batch of 3-variate Gaussians.
    100   mu = [[1., 2, 3],
    101         [11, 22, 33]]              # shape: [2, 3]
    102   covariance_matrix = ...  # shape: [2, 3, 3], symmetric, positive definite.
    103   mvn = tfd.MultivariateNormalFullCovariance(
    104       loc=mu,
    105       covariance=covariance_matrix)
    106 
    107   # Compute the pdf of two `R^3` observations; return a length-2 vector.
    108   x = [[-0.9, 0, 0.1],
    109        [-10, 0, 9]]     # shape: [2, 3]
    110   mvn.prob(x).eval()    # shape: [2]
    111 
    112   ```
    113 
    114   """
    115 
    116   def __init__(self,
    117                loc=None,
    118                covariance_matrix=None,
    119                validate_args=False,
    120                allow_nan_stats=True,
    121                name="MultivariateNormalFullCovariance"):
    122     """Construct Multivariate Normal distribution on `R^k`.
    123 
    124     The `batch_shape` is the broadcast shape between `loc` and
    125     `covariance_matrix` arguments.
    126 
    127     The `event_shape` is given by last dimension of the matrix implied by
    128     `covariance_matrix`. The last dimension of `loc` (if provided) must
    129     broadcast with this.
    130 
    131     A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
    132     definite matrix.  In other words it is (real) symmetric with all eigenvalues
    133     strictly positive.
    134 
    135     Additional leading dimensions (if any) will index batches.
    136 
    137     Args:
    138       loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
    139         implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
    140         `b >= 0` and `k` is the event size.
    141       covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
    142         same `dtype` as `loc`.  The strict upper triangle of `covariance_matrix`
    143         is ignored, so if `covariance_matrix` is not symmetric no error will be
    144         raised (unless `validate_args is True`).  `covariance_matrix` has shape
    145         `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
    146       validate_args: Python `bool`, default `False`. When `True` distribution
    147         parameters are checked for validity despite possibly degrading runtime
    148         performance. When `False` invalid inputs may silently render incorrect
    149         outputs.
    150       allow_nan_stats: Python `bool`, default `True`. When `True`,
    151         statistics (e.g., mean, mode, variance) use the value "`NaN`" to
    152         indicate the result is undefined. When `False`, an exception is raised
    153         if one or more of the statistic's batch members are undefined.
    154       name: Python `str` name prefixed to Ops created by this class.
    155 
    156     Raises:
    157       ValueError: if neither `loc` nor `covariance_matrix` are specified.
    158     """
    159     parameters = locals()
    160 
    161     # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
    162     with ops.name_scope(name):
    163       with ops.name_scope("init", values=[loc, covariance_matrix]):
    164         if covariance_matrix is None:
    165           scale_tril = None
    166         else:
    167           covariance_matrix = ops.convert_to_tensor(
    168               covariance_matrix, name="covariance_matrix")
    169           if validate_args:
    170             covariance_matrix = control_flow_ops.with_dependencies([
    171                 check_ops.assert_near(
    172                     covariance_matrix,
    173                     array_ops.matrix_transpose(covariance_matrix),
    174                     message="Matrix was not symmetric")], covariance_matrix)
    175           # No need to validate that covariance_matrix is non-singular.
    176           # LinearOperatorLowerTriangular has an assert_non_singular method that
    177           # is called by the Bijector.
    178           # However, cholesky() ignores the upper triangular part, so we do need
    179           # to separately assert symmetric.
    180           scale_tril = linalg_ops.cholesky(covariance_matrix)
    181         super(MultivariateNormalFullCovariance, self).__init__(
    182             loc=loc,
    183             scale_tril=scale_tril,
    184             validate_args=validate_args,
    185             allow_nan_stats=allow_nan_stats,
    186             name=name)
    187     self._parameters = parameters
    188