# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution class initialized with a full covariance."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import mvn_tril
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops


__all__ = [
    "MultivariateNormalFullCovariance",
]


class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `covariance_matrix` matrices that are the covariance.
  This is different from the other multivariate normals, which are
  parameterized by a matrix more akin to the standard deviation.

  #### Mathematical Details

  The probability density function (pdf) is, with `@` as matrix multiplication,

  ```none
  pdf(x; loc, covariance_matrix) = exp(-0.5 y) / Z,
  y = (x - loc)^T @ inv(covariance_matrix) @ (x - loc)
  Z = (2 pi)**(0.5 k) |det(covariance_matrix)|**(0.5).
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `covariance_matrix` is an `R^{k x k}` symmetric positive definite matrix,
  * `y` is the quadratic form (squared Mahalanobis distance) of `x - loc`
    under `inv(covariance_matrix)`, and,
  * `Z` denotes the normalization constant.

  Additional leading dimensions (if any) in `loc` and `covariance_matrix` allow
  for batch dimensions.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed e.g. as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  scale = Cholesky(covariance_matrix)
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 3-variate Gaussian.
  mu = [1., 2, 3]
  cov = [[ 0.36,  0.12,  0.06],
         [ 0.12,  0.29, -0.13],
         [ 0.06, -0.13,  0.26]]
  mvn = tfd.MultivariateNormalFullCovariance(
      loc=mu,
      covariance_matrix=cov)

  mvn.mean().eval()
  # ==> [1., 2, 3]

  # Covariance agrees with covariance_matrix.
  mvn.covariance().eval()
  # ==> [[ 0.36,  0.12,  0.06],
  #      [ 0.12,  0.29, -0.13],
  #      [ 0.06, -0.13,  0.26]]

  # Compute the pdf of an observation in `R^3` ; return a scalar.
  mvn.prob([-1., 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians.
  mu = [[1., 2, 3],
        [11, 22, 33]]              # shape: [2, 3]
  covariance_matrix = ...  # shape: [2, 3, 3], symmetric, positive definite.
  mvn = tfd.MultivariateNormalFullCovariance(
      loc=mu,
      covariance_matrix=covariance_matrix)

  # Compute the pdf of two `R^3` observations; return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]

  ```

  """

  def __init__(self,
               loc=None,
               covariance_matrix=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalFullCovariance"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and
    `covariance_matrix` arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `covariance_matrix`. The last dimension of `loc` (if provided) must
    broadcast with this.

    A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
    definite matrix.  In other words it is (real) symmetric with all eigenvalues
    strictly positive.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
        same `dtype` as `loc`.  The strict upper triangle of `covariance_matrix`
        is ignored, so if `covariance_matrix` is not symmetric no error will be
        raised (unless `validate_args is True`).  `covariance_matrix` has shape
        `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `covariance_matrix` are specified.
    """
    # Take a *copy* of the constructor arguments. `locals()` inside a function
    # returns the frame's locals mapping, which CPython may refresh in place
    # (later `locals()` calls, debuggers/tracers). Keeping that live mapping
    # would let `scale_tril` and the Tensor-converted `covariance_matrix` leak
    # into the recorded parameters.
    parameters = dict(locals())

    # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
    with ops.name_scope(name):
      with ops.name_scope("init", values=[loc, covariance_matrix]):
        if covariance_matrix is None:
          scale_tril = None
        else:
          covariance_matrix = ops.convert_to_tensor(
              covariance_matrix, name="covariance_matrix")
          if validate_args:
            # Gate all downstream uses of `covariance_matrix` on a symmetry
            # check (matrix approximately equal to its transpose).
            covariance_matrix = control_flow_ops.with_dependencies([
                check_ops.assert_near(
                    covariance_matrix,
                    array_ops.matrix_transpose(covariance_matrix),
                    message="Matrix was not symmetric")], covariance_matrix)
          # No need to validate that covariance_matrix is non-singular.
          # LinearOperatorLowerTriangular has an assert_non_singular method that
          # is called by the Bijector.
          # However, cholesky() ignores the upper triangular part, so we do need
          # to separately assert symmetric.
          scale_tril = linalg_ops.cholesky(covariance_matrix)
        super(MultivariateNormalFullCovariance, self).__init__(
            loc=loc,
            scale_tril=scale_tril,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    # Record the user-facing constructor arguments (covariance form), not the
    # scale_tril form passed to the parent class.
    self._parameters = parameters