1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Quantized distribution.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import ops 24 from tensorflow.python.ops import array_ops 25 from tensorflow.python.ops import check_ops 26 from tensorflow.python.ops import control_flow_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.ops.distributions import distribution as distributions 29 from tensorflow.python.ops.distributions import util as distribution_util 30 from tensorflow.python.util import deprecation 31 32 __all__ = ["QuantizedDistribution"] 33 34 35 @deprecation.deprecated( 36 "2018-10-01", 37 "The TensorFlow Distributions library has moved to " 38 "TensorFlow Probability " 39 "(https://github.com/tensorflow/probability). You " 40 "should update all references to use `tfp.distributions` " 41 "instead of `tf.contrib.distributions`.", 42 warn_once=True) 43 def _logsum_expbig_minus_expsmall(big, small): 44 """Stable evaluation of `Log[exp{big} - exp{small}]`. 45 46 To work correctly, we should have the pointwise relation: `small <= big`. 47 48 Args: 49 big: Floating-point `Tensor` 50 small: Floating-point `Tensor` with same `dtype` as `big` and broadcastable 51 shape. 52 53 Returns: 54 `Tensor` of same `dtype` of `big` and broadcast shape. 55 """ 56 with ops.name_scope("logsum_expbig_minus_expsmall", values=[small, big]): 57 return math_ops.log(1. - math_ops.exp(small - big)) + big 58 59 60 _prob_base_note = """ 61 For whole numbers `y`, 62 63 ``` 64 P[Y = y] := P[X <= low], if y == low, 65 := P[X > high - 1], y == high, 66 := 0, if j < low or y > high, 67 := P[y - 1 < X <= y], all other y. 68 ``` 69 70 """ 71 72 _prob_note = _prob_base_note + """ 73 The base distribution's `cdf` method must be defined on `y - 1`. If the 74 base distribution has a `survival_function` method, results will be more 75 accurate for large values of `y`, and in this case the `survival_function` must 76 also be defined on `y - 1`. 77 """ 78 79 _log_prob_note = _prob_base_note + """ 80 The base distribution's `log_cdf` method must be defined on `y - 1`. If the 81 base distribution has a `log_survival_function` method results will be more 82 accurate for large values of `y`, and in this case the `log_survival_function` 83 must also be defined on `y - 1`. 84 """ 85 86 87 _cdf_base_note = """ 88 89 For whole numbers `y`, 90 91 ``` 92 cdf(y) := P[Y <= y] 93 = 1, if y >= high, 94 = 0, if y < low, 95 = P[X <= y], otherwise. 96 ``` 97 98 Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`. 99 This dictates that fractional `y` are first floored to a whole number, and 100 then above definition applies. 101 """ 102 103 _cdf_note = _cdf_base_note + """ 104 The base distribution's `cdf` method must be defined on `y - 1`. 105 """ 106 107 _log_cdf_note = _cdf_base_note + """ 108 The base distribution's `log_cdf` method must be defined on `y - 1`. 109 """ 110 111 112 _sf_base_note = """ 113 114 For whole numbers `y`, 115 116 ``` 117 survival_function(y) := P[Y > y] 118 = 0, if y >= high, 119 = 1, if y < low, 120 = P[X <= y], otherwise. 121 ``` 122 123 Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`. 124 This dictates that fractional `y` are first floored to a whole number, and 125 then above definition applies. 126 """ 127 128 _sf_note = _sf_base_note + """ 129 The base distribution's `cdf` method must be defined on `y - 1`. 130 """ 131 132 _log_sf_note = _sf_base_note + """ 133 The base distribution's `log_cdf` method must be defined on `y - 1`. 134 """ 135 136 137 class QuantizedDistribution(distributions.Distribution): 138 """Distribution representing the quantization `Y = ceiling(X)`. 139 140 #### Definition in Terms of Sampling 141 142 ``` 143 1. Draw X 144 2. Set Y <-- ceiling(X) 145 3. If Y < low, reset Y <-- low 146 4. If Y > high, reset Y <-- high 147 5. Return Y 148 ``` 149 150 #### Definition in Terms of the Probability Mass Function 151 152 Given scalar random variable `X`, we define a discrete random variable `Y` 153 supported on the integers as follows: 154 155 ``` 156 P[Y = j] := P[X <= low], if j == low, 157 := P[X > high - 1], j == high, 158 := 0, if j < low or j > high, 159 := P[j - 1 < X <= j], all other j. 160 ``` 161 162 Conceptually, without cutoffs, the quantization process partitions the real 163 line `R` into half open intervals, and identifies an integer `j` with the 164 right endpoints: 165 166 ``` 167 R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ... 168 j = ... -1 0 1 2 3 4 ... 169 ``` 170 171 `P[Y = j]` is the mass of `X` within the `jth` interval. 172 If `low = 0`, and `high = 2`, then the intervals are redrawn 173 and `j` is re-assigned: 174 175 ``` 176 R = (-infty, 0](0, 1](1, infty) 177 j = 0 1 2 178 ``` 179 180 `P[Y = j]` is still the mass of `X` within the `jth` interval. 181 182 #### Examples 183 184 We illustrate a mixture of discretized logistic distributions 185 [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit 186 audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in 187 a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures 188 `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints. 189 The lowest value has probability `P(X <= 0.5)` and the highest value has 190 probability `P(2**16 - 1.5 < X)`. 191 192 Below we assume a `wavenet` function. It takes as `input` right-shifted audio 193 samples of shape `[..., sequence_length]`. It returns a real-valued tensor of 194 shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and 195 `scale` parameter belonging to the logistic distribution, and a `logits` 196 parameter determining the unnormalized probability of that component. 197 198 ```python 199 import tensorflow_probability as tfp 200 tfd = tfp.distributions 201 tfb = tfp.bijectors 202 203 net = wavenet(inputs) 204 loc, unconstrained_scale, logits = tf.split(net, 205 num_or_size_splits=3, 206 axis=-1) 207 scale = tf.nn.softplus(unconstrained_scale) 208 209 # Form mixture of discretized logistic distributions. Note we shift the 210 # logistic distribution by -0.5. This lets the quantization capture "rounding" 211 # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`. 212 discretized_logistic_dist = tfd.QuantizedDistribution( 213 distribution=tfd.TransformedDistribution( 214 distribution=tfd.Logistic(loc=loc, scale=scale), 215 bijector=tfb.AffineScalar(shift=-0.5)), 216 low=0., 217 high=2**16 - 1.) 218 mixture_dist = tfd.MixtureSameFamily( 219 mixture_distribution=tfd.Categorical(logits=logits), 220 components_distribution=discretized_logistic_dist) 221 222 neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets)) 223 train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood) 224 ``` 225 226 After instantiating `mixture_dist`, we illustrate maximum likelihood by 227 calculating its log-probability of audio samples as `target` and optimizing. 228 229 #### References 230 231 [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma. 232 PixelCNN++: Improving the PixelCNN with discretized logistic mixture 233 likelihood and other modifications. 234 _International Conference on Learning Representations_, 2017. 235 https://arxiv.org/abs/1701.05517 236 [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech 237 Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. 238 https://arxiv.org/abs/1711.10433 239 """ 240 241 @deprecation.deprecated( 242 "2018-10-01", 243 "The TensorFlow Distributions library has moved to " 244 "TensorFlow Probability " 245 "(https://github.com/tensorflow/probability). You " 246 "should update all references to use `tfp.distributions` " 247 "instead of `tf.contrib.distributions`.", 248 warn_once=True) 249 def __init__(self, 250 distribution, 251 low=None, 252 high=None, 253 validate_args=False, 254 name="QuantizedDistribution"): 255 """Construct a Quantized Distribution representing `Y = ceiling(X)`. 256 257 Some properties are inherited from the distribution defining `X`. Example: 258 `allow_nan_stats` is determined for this `QuantizedDistribution` by reading 259 the `distribution`. 260 261 Args: 262 distribution: The base distribution class to transform. Typically an 263 instance of `Distribution`. 264 low: `Tensor` with same `dtype` as this distribution and shape 265 able to be added to samples. Should be a whole number. Default `None`. 266 If provided, base distribution's `prob` should be defined at 267 `low`. 268 high: `Tensor` with same `dtype` as this distribution and shape 269 able to be added to samples. Should be a whole number. Default `None`. 270 If provided, base distribution's `prob` should be defined at 271 `high - 1`. 272 `high` must be strictly greater than `low`. 273 validate_args: Python `bool`, default `False`. When `True` distribution 274 parameters are checked for validity despite possibly degrading runtime 275 performance. When `False` invalid inputs may silently render incorrect 276 outputs. 277 name: Python `str` name prefixed to Ops created by this class. 278 279 Raises: 280 TypeError: If `dist_cls` is not a subclass of 281 `Distribution` or continuous. 282 NotImplementedError: If the base distribution does not implement `cdf`. 283 """ 284 parameters = dict(locals()) 285 values = ( 286 list(distribution.parameters.values()) + 287 [low, high]) 288 with ops.name_scope(name, values=values) as name: 289 self._dist = distribution 290 291 if low is not None: 292 low = ops.convert_to_tensor(low, name="low") 293 if high is not None: 294 high = ops.convert_to_tensor(high, name="high") 295 check_ops.assert_same_float_dtype( 296 tensors=[self.distribution, low, high]) 297 298 # We let QuantizedDistribution access _graph_parents since this class is 299 # more like a baseclass. 300 graph_parents = self._dist._graph_parents # pylint: disable=protected-access 301 302 checks = [] 303 if validate_args and low is not None and high is not None: 304 message = "low must be strictly less than high." 305 checks.append( 306 check_ops.assert_less( 307 low, high, message=message)) 308 self._validate_args = validate_args # self._check_integer uses this. 309 with ops.control_dependencies(checks if validate_args else []): 310 if low is not None: 311 self._low = self._check_integer(low) 312 graph_parents += [self._low] 313 else: 314 self._low = None 315 if high is not None: 316 self._high = self._check_integer(high) 317 graph_parents += [self._high] 318 else: 319 self._high = None 320 321 super(QuantizedDistribution, self).__init__( 322 dtype=self._dist.dtype, 323 reparameterization_type=distributions.NOT_REPARAMETERIZED, 324 validate_args=validate_args, 325 allow_nan_stats=self._dist.allow_nan_stats, 326 parameters=parameters, 327 graph_parents=graph_parents, 328 name=name) 329 330 @property 331 def distribution(self): 332 """Base distribution, p(x).""" 333 return self._dist 334 335 @property 336 def low(self): 337 """Lowest value that quantization returns.""" 338 return self._low 339 340 @property 341 def high(self): 342 """Highest value that quantization returns.""" 343 return self._high 344 345 def _batch_shape_tensor(self): 346 return self.distribution.batch_shape_tensor() 347 348 def _batch_shape(self): 349 return self.distribution.batch_shape 350 351 def _event_shape_tensor(self): 352 return self.distribution.event_shape_tensor() 353 354 def _event_shape(self): 355 return self.distribution.event_shape 356 357 def _sample_n(self, n, seed=None): 358 low = self._low 359 high = self._high 360 with ops.name_scope("transform"): 361 n = ops.convert_to_tensor(n, name="n") 362 x_samps = self.distribution.sample(n, seed=seed) 363 ones = array_ops.ones_like(x_samps) 364 365 # Snap values to the intervals (j - 1, j]. 366 result_so_far = math_ops.ceil(x_samps) 367 368 if low is not None: 369 result_so_far = array_ops.where(result_so_far < low, 370 low * ones, result_so_far) 371 372 if high is not None: 373 result_so_far = array_ops.where(result_so_far > high, 374 high * ones, result_so_far) 375 376 return result_so_far 377 378 @distribution_util.AppendDocstring(_log_prob_note) 379 def _log_prob(self, y): 380 if not hasattr(self.distribution, "_log_cdf"): 381 raise NotImplementedError( 382 "'log_prob' not implemented unless the base distribution implements " 383 "'log_cdf'") 384 y = self._check_integer(y) 385 try: 386 return self._log_prob_with_logsf_and_logcdf(y) 387 except NotImplementedError: 388 return self._log_prob_with_logcdf(y) 389 390 def _log_prob_with_logcdf(self, y): 391 return _logsum_expbig_minus_expsmall(self.log_cdf(y), self.log_cdf(y - 1)) 392 393 def _log_prob_with_logsf_and_logcdf(self, y): 394 """Compute log_prob(y) using log survival_function and cdf together.""" 395 # There are two options that would be equal if we had infinite precision: 396 # Log[ sf(y - 1) - sf(y) ] 397 # = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ] 398 # Log[ cdf(y) - cdf(y - 1) ] 399 # = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ] 400 logsf_y = self.log_survival_function(y) 401 logsf_y_minus_1 = self.log_survival_function(y - 1) 402 logcdf_y = self.log_cdf(y) 403 logcdf_y_minus_1 = self.log_cdf(y - 1) 404 405 # Important: Here we use select in a way such that no input is inf, this 406 # prevents the troublesome case where the output of select can be finite, 407 # but the output of grad(select) will be NaN. 408 409 # In either case, we are doing Log[ exp{big} - exp{small} ] 410 # We want to use the sf items precisely when we are on the right side of the 411 # median, which occurs when logsf_y < logcdf_y. 412 big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y) 413 small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1) 414 415 return _logsum_expbig_minus_expsmall(big, small) 416 417 @distribution_util.AppendDocstring(_prob_note) 418 def _prob(self, y): 419 if not hasattr(self.distribution, "_cdf"): 420 raise NotImplementedError( 421 "'prob' not implemented unless the base distribution implements " 422 "'cdf'") 423 y = self._check_integer(y) 424 try: 425 return self._prob_with_sf_and_cdf(y) 426 except NotImplementedError: 427 return self._prob_with_cdf(y) 428 429 def _prob_with_cdf(self, y): 430 return self.cdf(y) - self.cdf(y - 1) 431 432 def _prob_with_sf_and_cdf(self, y): 433 # There are two options that would be equal if we had infinite precision: 434 # sf(y - 1) - sf(y) 435 # cdf(y) - cdf(y - 1) 436 sf_y = self.survival_function(y) 437 sf_y_minus_1 = self.survival_function(y - 1) 438 cdf_y = self.cdf(y) 439 cdf_y_minus_1 = self.cdf(y - 1) 440 441 # sf_prob has greater precision iff we're on the right side of the median. 442 return array_ops.where( 443 sf_y < cdf_y, # True iff we're on the right side of the median. 444 sf_y_minus_1 - sf_y, 445 cdf_y - cdf_y_minus_1) 446 447 @distribution_util.AppendDocstring(_log_cdf_note) 448 def _log_cdf(self, y): 449 low = self._low 450 high = self._high 451 452 # Recall the promise: 453 # cdf(y) := P[Y <= y] 454 # = 1, if y >= high, 455 # = 0, if y < low, 456 # = P[X <= y], otherwise. 457 458 # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in 459 # between. 460 j = math_ops.floor(y) 461 462 result_so_far = self.distribution.log_cdf(j) 463 464 # Broadcast, because it's possible that this is a single distribution being 465 # evaluated on a number of samples, or something like that. 466 j += array_ops.zeros_like(result_so_far) 467 468 # Re-define values at the cutoffs. 469 if low is not None: 470 neg_inf = -np.inf * array_ops.ones_like(result_so_far) 471 result_so_far = array_ops.where(j < low, neg_inf, result_so_far) 472 if high is not None: 473 result_so_far = array_ops.where(j >= high, 474 array_ops.zeros_like(result_so_far), 475 result_so_far) 476 477 return result_so_far 478 479 @distribution_util.AppendDocstring(_cdf_note) 480 def _cdf(self, y): 481 low = self._low 482 high = self._high 483 484 # Recall the promise: 485 # cdf(y) := P[Y <= y] 486 # = 1, if y >= high, 487 # = 0, if y < low, 488 # = P[X <= y], otherwise. 489 490 # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in 491 # between. 492 j = math_ops.floor(y) 493 494 # P[X <= j], used when low < X < high. 495 result_so_far = self.distribution.cdf(j) 496 497 # Broadcast, because it's possible that this is a single distribution being 498 # evaluated on a number of samples, or something like that. 499 j += array_ops.zeros_like(result_so_far) 500 501 # Re-define values at the cutoffs. 502 if low is not None: 503 result_so_far = array_ops.where(j < low, 504 array_ops.zeros_like(result_so_far), 505 result_so_far) 506 if high is not None: 507 result_so_far = array_ops.where(j >= high, 508 array_ops.ones_like(result_so_far), 509 result_so_far) 510 511 return result_so_far 512 513 @distribution_util.AppendDocstring(_log_sf_note) 514 def _log_survival_function(self, y): 515 low = self._low 516 high = self._high 517 518 # Recall the promise: 519 # survival_function(y) := P[Y > y] 520 # = 0, if y >= high, 521 # = 1, if y < low, 522 # = P[X > y], otherwise. 523 524 # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in 525 # between. 526 j = math_ops.ceil(y) 527 528 # P[X > j], used when low < X < high. 529 result_so_far = self.distribution.log_survival_function(j) 530 531 # Broadcast, because it's possible that this is a single distribution being 532 # evaluated on a number of samples, or something like that. 533 j += array_ops.zeros_like(result_so_far) 534 535 # Re-define values at the cutoffs. 536 if low is not None: 537 result_so_far = array_ops.where(j < low, 538 array_ops.zeros_like(result_so_far), 539 result_so_far) 540 if high is not None: 541 neg_inf = -np.inf * array_ops.ones_like(result_so_far) 542 result_so_far = array_ops.where(j >= high, neg_inf, result_so_far) 543 544 return result_so_far 545 546 @distribution_util.AppendDocstring(_sf_note) 547 def _survival_function(self, y): 548 low = self._low 549 high = self._high 550 551 # Recall the promise: 552 # survival_function(y) := P[Y > y] 553 # = 0, if y >= high, 554 # = 1, if y < low, 555 # = P[X > y], otherwise. 556 557 # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in 558 # between. 559 j = math_ops.ceil(y) 560 561 # P[X > j], used when low < X < high. 562 result_so_far = self.distribution.survival_function(j) 563 564 # Broadcast, because it's possible that this is a single distribution being 565 # evaluated on a number of samples, or something like that. 566 j += array_ops.zeros_like(result_so_far) 567 568 # Re-define values at the cutoffs. 569 if low is not None: 570 result_so_far = array_ops.where(j < low, 571 array_ops.ones_like(result_so_far), 572 result_so_far) 573 if high is not None: 574 result_so_far = array_ops.where(j >= high, 575 array_ops.zeros_like(result_so_far), 576 result_so_far) 577 578 return result_so_far 579 580 def _check_integer(self, value): 581 with ops.name_scope("check_integer", values=[value]): 582 value = ops.convert_to_tensor(value, name="value") 583 if not self.validate_args: 584 return value 585 dependencies = [distribution_util.assert_integer_form( 586 value, message="value has non-integer components.")] 587 return control_flow_ops.with_dependencies(dependencies, value) 588