1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """The Independent distribution class.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import tensor_shape 26 from tensorflow.python.framework import tensor_util 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import check_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops.distributions import distribution as distribution_lib 31 from tensorflow.python.ops.distributions import kullback_leibler 32 from tensorflow.python.util import deprecation 33 34 35 class Independent(distribution_lib.Distribution): 36 """Independent distribution from batch of distributions. 37 38 This distribution is useful for regarding a collection of independent, 39 non-identical distributions as a single random variable. For example, the 40 `Independent` distribution composed of a collection of `Bernoulli` 41 distributions might define a distribution over an image (where each 42 `Bernoulli` is a distribution over each pixel). 43 44 More precisely, a collection of `B` (independent) `E`-variate random variables 45 (rv) `{X_1, ..., X_B}`, can be regarded as a `[B, E]`-variate random variable 46 `(X_1, ..., X_B)` with probability 47 `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the 48 probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes. 49 50 Similarly, the `Independent` distribution specifies a distribution over `[B, 51 E]`-shaped events. It operates by reinterpreting the rightmost batch dims as 52 part of the event dimensions. The `reinterpreted_batch_ndims` parameter 53 controls the number of batch dims which are absorbed as event dims; 54 `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob` 55 function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims` 56 after calling the base distribution's `log_prob`. In other words, since the 57 batch dimension(s) index independent distributions, the resultant multivariate 58 will have independent components. 59 60 #### Mathematical Details 61 62 The probability function is, 63 64 ```none 65 prob(x; reinterpreted_batch_ndims) = tf.reduce_prod( 66 dist.prob(x), 67 axis=-1-range(reinterpreted_batch_ndims)) 68 ``` 69 70 #### Examples 71 72 ```python 73 import tensorflow_probability as tfp 74 tfd = tfp.distributions 75 76 # Make independent distribution from a 2-batch Normal. 77 ind = tfd.Independent( 78 distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), 79 reinterpreted_batch_ndims=1) 80 81 # All batch dims have been "absorbed" into event dims. 82 ind.batch_shape # ==> [] 83 ind.event_shape # ==> [2] 84 85 # Make independent distribution from a 2-batch bivariate Normal. 86 ind = tfd.Independent( 87 distribution=tfd.MultivariateNormalDiag( 88 loc=[[-1., 1], [1, -1]], 89 scale_identity_multiplier=[1., 0.5]), 90 reinterpreted_batch_ndims=1) 91 92 # All batch dims have been "absorbed" into event dims. 93 ind.batch_shape # ==> [] 94 ind.event_shape # ==> [2, 2] 95 ``` 96 97 """ 98 99 @deprecation.deprecated( 100 "2018-10-01", 101 "The TensorFlow Distributions library has moved to " 102 "TensorFlow Probability " 103 "(https://github.com/tensorflow/probability). You " 104 "should update all references to use `tfp.distributions` " 105 "instead of `tf.contrib.distributions`.", 106 warn_once=True) 107 def __init__( 108 self, distribution, reinterpreted_batch_ndims=None, 109 validate_args=False, name=None): 110 """Construct a `Independent` distribution. 111 112 Args: 113 distribution: The base distribution instance to transform. Typically an 114 instance of `Distribution`. 115 reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims 116 which will be regarded as event dims. When `None` all but the first 117 batch axis (batch axis 0) will be transferred to event dimensions 118 (analogous to `tf.layers.flatten`). 119 validate_args: Python `bool`. Whether to validate input with asserts. 120 If `validate_args` is `False`, and the inputs are invalid, 121 correct behavior is not guaranteed. 122 name: The name for ops managed by the distribution. 123 Default value: `Independent + distribution.name`. 124 125 Raises: 126 ValueError: if `reinterpreted_batch_ndims` exceeds 127 `distribution.batch_ndims` 128 """ 129 parameters = dict(locals()) 130 name = name or "Independent" + distribution.name 131 self._distribution = distribution 132 with ops.name_scope(name) as name: 133 if reinterpreted_batch_ndims is None: 134 reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims( 135 distribution) 136 reinterpreted_batch_ndims = ops.convert_to_tensor( 137 reinterpreted_batch_ndims, 138 dtype=dtypes.int32, 139 name="reinterpreted_batch_ndims") 140 self._reinterpreted_batch_ndims = reinterpreted_batch_ndims 141 self._static_reinterpreted_batch_ndims = tensor_util.constant_value( 142 reinterpreted_batch_ndims) 143 if self._static_reinterpreted_batch_ndims is not None: 144 self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims 145 super(Independent, self).__init__( 146 dtype=self._distribution.dtype, 147 reparameterization_type=self._distribution.reparameterization_type, 148 validate_args=validate_args, 149 allow_nan_stats=self._distribution.allow_nan_stats, 150 parameters=parameters, 151 graph_parents=( 152 [reinterpreted_batch_ndims] + 153 distribution._graph_parents), # pylint: disable=protected-access 154 name=name) 155 self._runtime_assertions = self._make_runtime_assertions( 156 distribution, reinterpreted_batch_ndims, validate_args) 157 158 @property 159 def distribution(self): 160 return self._distribution 161 162 @property 163 def reinterpreted_batch_ndims(self): 164 return self._reinterpreted_batch_ndims 165 166 def _batch_shape_tensor(self): 167 with ops.control_dependencies(self._runtime_assertions): 168 batch_shape = self.distribution.batch_shape_tensor() 169 dim0 = tensor_shape.dimension_value( 170 batch_shape.shape.with_rank_at_least(1)[0]) 171 batch_ndims = (dim0 172 if dim0 is not None 173 else array_ops.shape(batch_shape)[0]) 174 return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims] 175 176 def _batch_shape(self): 177 batch_shape = self.distribution.batch_shape 178 if (self._static_reinterpreted_batch_ndims is None 179 or batch_shape.ndims is None): 180 return tensor_shape.TensorShape(None) 181 d = batch_shape.ndims - self._static_reinterpreted_batch_ndims 182 return batch_shape[:d] 183 184 def _event_shape_tensor(self): 185 with ops.control_dependencies(self._runtime_assertions): 186 batch_shape = self.distribution.batch_shape_tensor() 187 dim0 = tensor_shape.dimension_value( 188 batch_shape.shape.with_rank_at_least(1)[0]) 189 batch_ndims = (dim0 190 if dim0 is not None 191 else array_ops.shape(batch_shape)[0]) 192 return array_ops.concat([ 193 batch_shape[batch_ndims - self.reinterpreted_batch_ndims:], 194 self.distribution.event_shape_tensor(), 195 ], axis=0) 196 197 def _event_shape(self): 198 batch_shape = self.distribution.batch_shape 199 if (self._static_reinterpreted_batch_ndims is None 200 or batch_shape.ndims is None): 201 return tensor_shape.TensorShape(None) 202 d = batch_shape.ndims - self._static_reinterpreted_batch_ndims 203 return batch_shape[d:].concatenate(self.distribution.event_shape) 204 205 def _sample_n(self, n, seed): 206 with ops.control_dependencies(self._runtime_assertions): 207 return self.distribution.sample(sample_shape=n, seed=seed) 208 209 def _log_prob(self, x): 210 with ops.control_dependencies(self._runtime_assertions): 211 return self._reduce_sum(self.distribution.log_prob(x)) 212 213 def _entropy(self): 214 with ops.control_dependencies(self._runtime_assertions): 215 return self._reduce_sum(self.distribution.entropy()) 216 217 def _mean(self): 218 with ops.control_dependencies(self._runtime_assertions): 219 return self.distribution.mean() 220 221 def _variance(self): 222 with ops.control_dependencies(self._runtime_assertions): 223 return self.distribution.variance() 224 225 def _stddev(self): 226 with ops.control_dependencies(self._runtime_assertions): 227 return self.distribution.stddev() 228 229 def _mode(self): 230 with ops.control_dependencies(self._runtime_assertions): 231 return self.distribution.mode() 232 233 def _make_runtime_assertions( 234 self, distribution, reinterpreted_batch_ndims, validate_args): 235 assertions = [] 236 static_reinterpreted_batch_ndims = tensor_util.constant_value( 237 reinterpreted_batch_ndims) 238 batch_ndims = distribution.batch_shape.ndims 239 if batch_ndims is not None and static_reinterpreted_batch_ndims is not None: 240 if static_reinterpreted_batch_ndims > batch_ndims: 241 raise ValueError("reinterpreted_batch_ndims({}) cannot exceed " 242 "distribution.batch_ndims({})".format( 243 static_reinterpreted_batch_ndims, batch_ndims)) 244 elif validate_args: 245 batch_shape = distribution.batch_shape_tensor() 246 dim0 = tensor_shape.dimension_value( 247 batch_shape.shape.with_rank_at_least(1)[0]) 248 batch_ndims = ( 249 dim0 250 if dim0 is not None 251 else array_ops.shape(batch_shape)[0]) 252 assertions.append(check_ops.assert_less_equal( 253 reinterpreted_batch_ndims, batch_ndims, 254 message=("reinterpreted_batch_ndims cannot exceed " 255 "distribution.batch_ndims"))) 256 return assertions 257 258 def _reduce_sum(self, stat): 259 if self._static_reinterpreted_batch_ndims is None: 260 range_ = math_ops.range(self._reinterpreted_batch_ndims) 261 else: 262 range_ = np.arange(self._static_reinterpreted_batch_ndims) 263 return math_ops.reduce_sum(stat, axis=-1-range_) 264 265 def _get_default_reinterpreted_batch_ndims(self, distribution): 266 """Computes the default value for reinterpreted_batch_ndim __init__ arg.""" 267 ndims = distribution.batch_shape.ndims 268 if ndims is None: 269 which_maximum = math_ops.maximum 270 ndims = array_ops.shape(distribution.batch_shape_tensor())[0] 271 else: 272 which_maximum = np.maximum 273 return which_maximum(0, ndims - 1) 274 275 276 @kullback_leibler.RegisterKL(Independent, Independent) 277 @deprecation.deprecated( 278 "2018-10-01", 279 "The TensorFlow Distributions library has moved to " 280 "TensorFlow Probability " 281 "(https://github.com/tensorflow/probability). You " 282 "should update all references to use `tfp.distributions` " 283 "instead of `tf.contrib.distributions`.", 284 warn_once=True) 285 def _kl_independent(a, b, name="kl_independent"): 286 """Batched KL divergence `KL(a || b)` for Independent distributions. 287 288 We can leverage the fact that 289 ``` 290 KL(Independent(a) || Independent(b)) = sum(KL(a || b)) 291 ``` 292 where the sum is over the `reinterpreted_batch_ndims`. 293 294 Args: 295 a: Instance of `Independent`. 296 b: Instance of `Independent`. 297 name: (optional) name to use for created ops. Default "kl_independent". 298 299 Returns: 300 Batchwise `KL(a || b)`. 301 302 Raises: 303 ValueError: If the event space for `a` and `b`, or their underlying 304 distributions don't match. 305 """ 306 p = a.distribution 307 q = b.distribution 308 309 # The KL between any two (non)-batched distributions is a scalar. 310 # Given that the KL between two factored distributions is the sum, i.e. 311 # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute 312 # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions. 313 if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined(): 314 if a.event_shape == b.event_shape: 315 if p.event_shape == q.event_shape: 316 num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims 317 reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)] 318 319 return math_ops.reduce_sum( 320 kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) 321 else: 322 raise NotImplementedError("KL between Independents with different " 323 "event shapes not supported.") 324 else: 325 raise ValueError("Event shapes do not match.") 326 else: 327 with ops.control_dependencies([ 328 check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()), 329 check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor()) 330 ]): 331 num_reduce_dims = ( 332 array_ops.shape(a.event_shape_tensor()[0]) - 333 array_ops.shape(p.event_shape_tensor()[0])) 334 reduce_dims = math_ops.range(-num_reduce_dims - 1, -1, 1) 335 return math_ops.reduce_sum( 336 kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) 337