# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The same-family Mixture distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util


class MixtureSameFamily(distribution.Distribution):
  """Mixture (same-family) distribution.

  The `MixtureSameFamily` distribution implements a (batch of) mixture
  distribution where all components are from different parameterizations of the
  same distribution type. It is parameterized by a `Categorical` "selecting
  distribution" (over `k` components) and a components distribution, i.e., a
  `Distribution` with a rightmost batch shape (equal to `[k]`) which indexes
  each (batch of) component.

  #### Examples

  ```python
  tfd = tf.contrib.distributions

  ### Create a mixture of two scalar Gaussians:

  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=[0.3, 0.7]),
      components_distribution=tfd.Normal(
        loc=[-1., 1],       # One for each component.
        scale=[0.1, 0.5]))  # And same here.

  gm.mean()
  # ==> 0.4

  gm.variance()
  # ==> 1.018

  # Plot PDF.
  x = np.linspace(-2., 3., int(1e4), dtype=np.float32)
  import matplotlib.pyplot as plt
  plt.plot(x, gm.prob(x).eval());

  ### Create a mixture of two Bivariate Gaussians:

  gm = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=[0.3, 0.7]),
      components_distribution=tfd.MultivariateNormalDiag(
          loc=[[-1., 1],  # component 1
               [1, -1]],  # component 2
          scale_identity_multiplier=[.3, .6]))

  gm.mean()
  # ==> array([ 0.4, -0.4], dtype=float32)

  gm.covariance()
  # ==> array([[ 1.119, -0.84],
  #            [-0.84,  1.119]], dtype=float32)

  # Plot PDF contours.
  def meshgrid(x, y=x):
    [gx, gy] = np.meshgrid(x, y, indexing='ij')
    gx, gy = np.float32(gx), np.float32(gy)
    grid = np.concatenate([gx.ravel()[None, :], gy.ravel()[None, :]], axis=0)
    return grid.T.reshape(x.size, y.size, 2)
  grid = meshgrid(np.linspace(-2, 2, 100, dtype=np.float32))
  plt.contour(grid[..., 0], grid[..., 1], gm.prob(grid).eval());

  ```

  """

  def __init__(self,
               mixture_distribution,
               components_distribution,
               validate_args=False,
               allow_nan_stats=True,
               name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tf.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tf.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `not mixture_distribution.dtype.is_integer`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if the number of `mixture_distribution` categories does not
        equal the `components_distribution` rightmost batch shape.
    """
    # Must be the first statement so it captures only the constructor args.
    parameters = locals()
    with ops.name_scope(name):
      self._mixture_distribution = mixture_distribution
      self._components_distribution = components_distribution
      # Checks that could not be done statically (and only when
      # `validate_args=True`) are collected here; the methods below run
      # their ops under these control dependencies.
      self._runtime_assertions = []

      # Number of event dimensions of the components: a Python int when the
      # length of the event-shape vector is statically known, otherwise a
      # scalar `Tensor`.
      s = components_distribution.event_shape_tensor()
      self._event_ndims = (s.shape[0].value
                           if s.shape.with_rank_at_least(1)[0].value is not None
                           else array_ops.shape(s)[0])

      if not mixture_distribution.dtype.is_integer:
        raise ValueError(
            "`mixture_distribution.dtype` ({}) is not over integers".format(
                mixture_distribution.dtype.name))

      # The selecting distribution must emit a single (scalar) index per draw.
      if (mixture_distribution.event_shape.ndims is not None
          and mixture_distribution.event_shape.ndims != 0):
        raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            control_flow_ops.assert_has_rank(
                mixture_distribution.event_shape_tensor(), 0,
                message="`mixture_distribution` must have scalar `event_dim`s"),
        ]

      # The mixture batch shape must be scalar (shared mixing weights) or
      # equal to the components' batch shape without the `k` dimension.
      mdbs = mixture_distribution.batch_shape
      cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
      if mdbs.is_fully_defined() and cdbs.is_fully_defined():
        if mdbs.ndims != 0 and mdbs != cdbs:
          raise ValueError(
              "`mixture_distribution.batch_shape` (`{}`) is not "
              "compatible with `components_distribution.batch_shape` "
              "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
      elif validate_args:
        mdbs = mixture_distribution.batch_shape_tensor()
        cdbs = components_distribution.batch_shape_tensor()[:-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                # A scalar mixture batch is always compatible, so compare
                # `cdbs` with itself in that case.
                distribution_util.pick_vector(
                    mixture_distribution.is_scalar_batch(), cdbs, mdbs),
                cdbs,
                message=(
                    "`mixture_distribution.batch_shape` is not "
                    "compatible with `components_distribution.batch_shape`"))]

      # Number of categories (km) must equal the components' rightmost batch
      # dimension (kc).
      km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
      kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
      if km is not None and kc is not None and km != kc:
        raise ValueError("`mixture_distribution components` ({}) does not "
                         "equal `components_distribution.batch_shape[-1]` "
                         "({})".format(km, kc))
      elif validate_args:
        km = array_ops.shape(mixture_distribution.logits)[-1]
        kc = components_distribution.batch_shape_tensor()[-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                km, kc,
                message=("`mixture_distribution components` does not equal "
                         "`components_distribution.batch_shape[-1:]`")),
        ]
      elif km is None:
        km = array_ops.shape(mixture_distribution.logits)[-1]

      # Python int when static, otherwise a scalar `Tensor`; used as the
      # one-hot depth in `_sample_n`.
      self._num_components = km

      super(MixtureSameFamily, self).__init__(
          dtype=self._components_distribution.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              self._mixture_distribution._graph_parents  # pylint: disable=protected-access
              + self._components_distribution._graph_parents),  # pylint: disable=protected-access
          name=name)

  @property
  def mixture_distribution(self):
    return self._mixture_distribution

  @property
  def components_distribution(self):
    return self._components_distribution

  # Batch shape is the components' batch shape with the trailing `k` removed.
  def _batch_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.components_distribution.batch_shape_tensor()[:-1]

  def _batch_shape(self):
    return self.components_distribution.batch_shape.with_rank_at_least(1)[:-1]

  # Event shape is inherited unchanged from the components.
  def _event_shape_tensor(self):
    with ops.control_dependencies(self._runtime_assertions):
      return self.components_distribution.event_shape_tensor()

  def _event_shape(self):
    return self.components_distribution.event_shape

  # Samples every component, then keeps one component per draw via a one-hot
  # mask built from categorical draws, summing out the `k` axis.
  def _sample_n(self, n, seed):
    with ops.control_dependencies(self._runtime_assertions):
      # Forward `seed` so seeded sampling is reproducible. The categorical draw
      # uses a seed derived from it (None stays None) so the two draws are not
      # seeded identically.
      x = self.components_distribution.sample(n, seed=seed)  # [n, B, k, E]
      mixture_seed = distribution_util.gen_new_seed(
          seed, salt="MixtureSameFamily")
      # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
      npdt = x.dtype.as_numpy_dtype
      mask = array_ops.one_hot(
          indices=self.mixture_distribution.sample(
              n, seed=mixture_seed),  # [n, B]
          depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
      mask = distribution_utils.pad_mixture_dimensions(
          mask, self, self.mixture_distribution,
          self._event_shape().ndims)                   # [n, B, k, [1]*e]
      return math_ops.reduce_sum(
          x * mask, axis=-1 - self._event_ndims)  # [n, B, E]

  # log p(x) = logsumexp_k(log p_k(x) + log pi_k).
  def _log_prob(self, x):
    with ops.control_dependencies(self._runtime_assertions):
      x = self._pad_sample_dims(x)
      log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
      log_mix_prob = nn_ops.log_softmax(
          self.mixture_distribution.logits, axis=-1)  # [B, k]
      return math_ops.reduce_logsumexp(
          log_prob_x + log_mix_prob, axis=-1)  # [S, B]

  # E[Y] = sum_k pi_k * E[Y | X=k].
  def _mean(self):
    with ops.control_dependencies(self._runtime_assertions):
      probs = distribution_utils.pad_mixture_dimensions(
          self.mixture_distribution.probs, self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, [1]*e]
      return math_ops.reduce_sum(
          probs * self.components_distribution.mean(),
          axis=-1 - self._event_ndims)  # [B, E]

  # log F(x) = logsumexp_k(log F_k(x) + log pi_k). Runs under the runtime
  # assertions like the other statistics.
  def _log_cdf(self, x):
    with ops.control_dependencies(self._runtime_assertions):
      x = self._pad_sample_dims(x)
      log_cdf_x = self.components_distribution.log_cdf(x)  # [S, B, k]
      log_mix_prob = nn_ops.log_softmax(
          self.mixture_distribution.logits, axis=-1)  # [B, k]
      return math_ops.reduce_logsumexp(
          log_cdf_x + log_mix_prob, axis=-1)  # [S, B]

  def _variance(self):
    with ops.control_dependencies(self._runtime_assertions):
      # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
      probs = distribution_utils.pad_mixture_dimensions(
          self.mixture_distribution.probs, self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, [1]*e]
      mean_cond_var = math_ops.reduce_sum(
          probs * self.components_distribution.variance(),
          axis=-1 - self._event_ndims)  # [B, E]
      var_cond_mean = math_ops.reduce_sum(
          probs * math_ops.squared_difference(
              self.components_distribution.mean(),
              self._pad_sample_dims(self._mean())),
          axis=-1 - self._event_ndims)  # [B, E]
      return mean_cond_var + var_cond_mean  # [B, E]

  def _covariance(self):
    static_event_ndims = self.event_shape.ndims
    if static_event_ndims != 1:
      # Covariance is defined only for vector distributions.
      raise NotImplementedError("covariance is not implemented")

    with ops.control_dependencies(self._runtime_assertions):
      # Law of total covariance: Cov(Y) = E[Cov(Y|X)] + Cov(E[Y|X])
      # Pad twice so probs broadcast against the [B, k, e, e] covariances.
      probs = distribution_utils.pad_mixture_dimensions(
          distribution_utils.pad_mixture_dimensions(
              self.mixture_distribution.probs, self, self.mixture_distribution,
              self._event_shape().ndims),
          self, self.mixture_distribution,
          self._event_shape().ndims)                         # [B, k, 1, 1]
      mean_cond_var = math_ops.reduce_sum(
          probs * self.components_distribution.covariance(),
          axis=-3)  # [B, e, e]
      var_cond_mean = math_ops.reduce_sum(
          probs * _outer_squared_difference(
              self.components_distribution.mean(),
              self._pad_sample_dims(self._mean())),
          axis=-3)  # [B, e, e]
      return mean_cond_var + var_cond_mean  # [B, e, e]

  def _pad_sample_dims(self, x):
    """Inserts a size-1 axis immediately left of `x`'s event dimensions.

    The new axis lines up with the components' `k` dimension so `x` can be
    broadcast against per-component quantities.

    Args:
      x: `Tensor` whose rightmost `self._event_ndims` dims are event dims.

    Returns:
      `x` reshaped to `shape[:d] + [1] + shape[d:]`, where
      `d = rank(x) - self._event_ndims`.
    """
    with ops.name_scope("pad_sample_dims", values=[x]):
      ndims = x.shape.ndims if x.shape.ndims is not None else array_ops.rank(x)
      shape = array_ops.shape(x)
      d = ndims - self._event_ndims
      x = array_ops.reshape(x, shape=array_ops.concat([
          shape[:d], [1], shape[d:]], axis=0))
      return x


def _outer_squared_difference(x, y):
  """Convenience function analogous to tf.squared_difference.

  Computes the outer product of `z = x - y` with itself over the last axis,
  i.e. `result[..., i, j] = z[..., i] * z[..., j]`.

  Args:
    x: `Tensor`.
    y: `Tensor` broadcastable with `x`.

  Returns:
    `Tensor` with one more trailing dimension than `x - y`.
  """
  z = x - y
  return z[..., array_ops.newaxis, :] * z[..., array_ops.newaxis]