# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn


__all__ = [
    "MultivariateNormalDiag",
    "MultivariateNormalDiagWithSoftplusScale",
]


class MultivariateNormalDiag(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  where:

  * `scale_diag.shape = [k]`, and,
  * `scale_identity_multiplier.shape = []`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then
  `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  # Initialize a single 2-variate Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_diag=[1, 2.])

  mvn.mean().eval()
  # ==> [1., -1]

  mvn.stddev().eval()
  # ==> [1., 2]

  # Evaluate this on an observation in `R^2`, returning a scalar.
  mvn.prob([-1., 0]).eval()  # shape: []

  # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_identity_multiplier=[1, 2., 3])

  mvn.mean().eval()  # shape: [3, 2]
  # ==> [[1., -1],
  #      [1, -1],
  #      [1, -1]]

  mvn.stddev().eval()  # shape: [3, 2]
  # ==> [[1., 1],
  #      [2, 2],
  #      [3, 3]]

  # Evaluate this on an observation in `R^2`, returning a length-3 vector.
  mvn.prob([-1., 0]).eval()  # shape: [3]

  # Initialize a 2-batch of 3-variate Gaussians.
  mvn = tfd.MultivariateNormalDiag(
      loc=[[1., 2, 3],
           [11, 22, 33]],          # shape: [2, 3]
      scale_diag=[[1., 2, 3],
                  [0.5, 1, 1.5]])  # shape: [2, 3]

  # Evaluate this on two observations, each in `R^3`, returning a length-2
  # vector.
  x = [[-1., 0, 1],
       [-11, 0, 11.]]   # shape: [2, 3].
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiag"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    where:

    * `scale_diag.shape = [k]`, and,
    * `scale_identity_multiplier.shape = []`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified, i.e.,
        if neither `loc` nor `scale_diag` is given (so the event size `k`
        cannot be determined).
    """
    # Snapshot the constructor arguments. In CPython < 3.13 a bare `locals()`
    # returns the frame's own namespace dict, which is re-synchronized whenever
    # `locals()` / `frame.f_locals` is accessed again (e.g. by a debugger or
    # tracer) and would then also contain `scale` and `parameters` itself.
    # Copying freezes exactly the arguments that were passed in.
    parameters = dict(locals())
    with ops.name_scope(name):
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier]):
        # No need to validate_args while making diag_scale. The returned
        # LinearOperatorDiag has an assert_non_singular method that is called by
        # the Bijector.
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=False,
            assert_positive=False)
    super(MultivariateNormalDiag, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Record this class's own constructor arguments (overriding whatever the
    # parent constructor stored) so they describe how *this* object was built.
    self._parameters = parameters


class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
  """MultivariateNormalDiag with `scale_diag = softplus(scale_diag)`.

  The `scale_diag` argument is treated as unconstrained and passed through
  `softplus` before being used as the diagonal of `scale`. Since
  `softplus(x) = log(1 + exp(x)) > 0` for every real `x`, the resulting
  diagonal is positive (very negative inputs may still underflow to zero in
  floating point). Unlike `MultivariateNormalDiag`, this class does not accept
  a `scale_identity_multiplier`.
  """

  def __init__(self,
               loc,
               scale_diag,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagWithSoftplusScale"):
    """Construct a `MultivariateNormalDiag` whose scale diagonal is softplus'd.

    Args:
      loc: Floating-point `Tensor`, forwarded unchanged as the `loc` of the
        underlying `MultivariateNormalDiag`.
      scale_diag: Floating-point `Tensor` of unconstrained values;
        `softplus(scale_diag)` is used as the `scale_diag` of the underlying
        `MultivariateNormalDiag`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy so that `parameters` records the raw (pre-softplus) arguments and
    # cannot be refreshed later through the frame's live `locals()` dict.
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale_diag]):
      super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
          loc=loc,
          scale_diag=nn.softplus(scale_diag),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters