# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops


__all__ = [
    "MultivariateNormalDiagPlusLowRank",
]


class MultivariateNormalDiagPlusLowRank(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k)) +
      scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
  ```

  where:

  * `scale_diag.shape = [k]`,
  * `scale_identity_multiplier.shape = []`,
  * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
  * `scale_perturb_diag.shape = [r]`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then
  `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
  # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
  # a rank-2 update.
  mu = [-0.5, 0, 0.5]    # shape: [3]
  d = [1.5, 0.5, 2]      # shape: [3]
  U = [[1., 2],
       [-1, 1],
       [2, -0.5]]        # shape: [3, 2]
  m = [4., 5]            # shape: [2]
  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_diag=d,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  # Evaluate this on an observation in `R^3`, returning a scalar.
  mvn.prob([-1, 0, 1]).eval()  # shape: []

  # Initialize a 2-batch of 3-variate Gaussians. No diagonal term is given, so
  # the base scale is the identity: `S = I + U @ diag(m) @ U.T`.
  mu = [[1.,  2,  3],
        [11, 22, 33]]      # shape: [b, k] = [2, 3]
  U = [[[1., 2],
        [3,  4],
        [5,  6]],
       [[0.5, 0.75],
        [1.0, 0.25],
        [1.5, 1.25]]]      # shape: [b, k, r] = [2, 3, 2]
  m = [[0.1, 0.2],
       [0.4, 0.5]]         # shape: [b, r] = [2, 2]

  mvn = tfd.MultivariateNormalDiagPlusLowRank(
      loc=mu,
      scale_perturb_factor=U,
      scale_perturb_diag=m)

  mvn.covariance().eval()   # shape: [2, 3, 3]
  # ==> [[[  15.63   31.57    48.51]
  #       [  31.57   69.31   105.05]
  #       [  48.51  105.05   162.59]]
  #
  #      [[   2.59    1.41    3.35]
  #       [   1.41    2.71    3.34]
  #       [   3.35    3.34    8.35]]]

  # Compute the pdf of two `R^3` observations (one from each batch);
  # return a length-2 vector.
  x = [[-0.9, 0, 0.1],
       [-10, 0, 9]]     # shape: [2, 3]
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagPlusLowRank"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k)) +
        scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    ```

    where:

    * `scale_diag.shape = [k]`,
    * `scale_identity_multiplier.shape = []`,
    * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    * `scale_perturb_diag.shape = [r]`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
        perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
        `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
        When `None`, no rank-`r` update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
        inside the rank-`r` perturbation added to `scale`. May have shape
        `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
        diagonal matrices inside the perturbation added to `scale`. When
        `None`, an identity matrix is used inside the perturbation. Can only be
        specified if `scale_perturb_factor` is also specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified.
      ValueError: if `scale_perturb_diag` is specified but
        `scale_perturb_factor` is `None`.
    """
    # Must stay the first statement so only the constructor arguments are
    # captured. `dict(...)` takes a private copy: on CPython before 3.13 the
    # mapping returned by `locals()` is shared with the frame and is refreshed
    # by any later `locals()` call or by a tracer/debugger, which would leak
    # intermediates such as `scale` into the stored parameters.
    parameters = dict(locals())

    # Enforce the documented contract up front; otherwise `u=None` would be
    # handed to the low-rank operator and fail with an unrelated message.
    if scale_perturb_diag is not None and scale_perturb_factor is None:
      raise ValueError(
          "`scale_perturb_diag` can only be specified if "
          "`scale_perturb_factor` is also specified.")

    def _convert_to_tensor(x, name):
      """Converts `x` to a `Tensor`, passing `None` through unchanged."""
      return None if x is None else ops.convert_to_tensor(x, name=name)

    with ops.name_scope(name):
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
          scale_perturb_diag]):
        # After the guard above, a perturbation exists exactly when the factor
        # was supplied.
        has_low_rank = scale_perturb_factor is not None
        # Diagonal part of `scale`; positivity of the diagonal is requested
        # only when a low-rank update will be stacked on top of it.
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=validate_args,
            assert_positive=has_low_rank)
        scale_perturb_factor = _convert_to_tensor(
            scale_perturb_factor,
            name="scale_perturb_factor")
        scale_perturb_diag = _convert_to_tensor(
            scale_perturb_diag,
            name="scale_perturb_diag")
        if has_low_rank:
          scale = linalg.LinearOperatorLowRankUpdate(
              scale,
              u=scale_perturb_factor,
              diag_update=scale_perturb_diag,
              # Only the implicit (omitted) diagonal is flagged as positive;
              # nothing is claimed about a user-supplied one.
              is_diag_update_positive=scale_perturb_diag is None,
              is_non_singular=True,  # Implied by is_positive_definite=True.
              is_self_adjoint=True,
              is_positive_definite=True,
              is_square=True)
    super(MultivariateNormalDiagPlusLowRank, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Overwrite whatever the base class recorded so `parameters` reflects this
    # subclass's own constructor arguments.
    self._parameters = parameters