# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quantized distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distributions
from tensorflow.python.ops.distributions import util as distribution_util

__all__ = ["QuantizedDistribution"]


def _logsum_expbig_minus_expsmall(big, small):
  """Stable evaluation of `Log[exp{big} - exp{small}]`.

  To work correctly, we should have the pointwise relation:  `small <= big`.

  Args:
    big: Floating-point `Tensor`
    small: Floating-point `Tensor` with same `dtype` as `big` and broadcastable
      shape.

  Returns:
    `Tensor` of same `dtype` of `big` and broadcast shape.
  """
  with ops.name_scope("logsum_expbig_minus_expsmall", values=[small, big]):
    # Factoring out exp{big} keeps the exp argument <= 0, avoiding overflow.
    return math_ops.log(1. - math_ops.exp(small - big)) + big


_prob_base_note = """
For whole numbers `y`,

```
P[Y = y] := P[X <= low], if y == low,
         := P[X > high - 1], y == high,
         := 0, if y < low or y > high,
         := P[y - 1 < X <= y], all other y.
```

"""

_prob_note = _prob_base_note + """
The base distribution's `cdf` method must be defined on `y - 1`. If the
base distribution has a `survival_function` method, results will be more
accurate for large values of `y`, and in this case the `survival_function` must
also be defined on `y - 1`.
"""

_log_prob_note = _prob_base_note + """
The base distribution's `log_cdf` method must be defined on `y - 1`. If the
base distribution has a `log_survival_function` method results will be more
accurate for large values of `y`, and in this case the `log_survival_function`
must also be defined on `y - 1`.
"""


_cdf_base_note = """

For whole numbers `y`,

```
cdf(y) := P[Y <= y]
        = 1, if y >= high,
        = 0, if y < low,
        = P[X <= y], otherwise.
```

Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`.
This dictates that fractional `y` are first floored to a whole number, and
then above definition applies.
"""

_cdf_note = _cdf_base_note + """
The base distribution's `cdf` method must be defined on `y - 1`.
"""

_log_cdf_note = _cdf_base_note + """
The base distribution's `log_cdf` method must be defined on `y - 1`.
"""


_sf_base_note = """

For whole numbers `y`,

```
survival_function(y) := P[Y > y]
                      = 0, if y >= high,
                      = 1, if y < low,
                      = P[X > y], otherwise.
```

Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`.
This dictates that fractional `y` are first floored to a whole number, and
then above definition applies.
"""

_sf_note = _sf_base_note + """
The base distribution's `cdf` method must be defined on `y - 1`.
"""

_log_sf_note = _sf_base_note + """
The base distribution's `log_cdf` method must be defined on `y - 1`.
"""


class QuantizedDistribution(distributions.Distribution):
  """Distribution representing the quantization `Y = ceiling(X)`.

  #### Definition in terms of sampling.

  ```
  1. Draw X
  2. Set Y <-- ceiling(X)
  3. If Y < low, reset Y <-- low
  4. If Y > high, reset Y <-- high
  5. Return Y
  ```

  #### Definition in terms of the probability mass function.

  Given scalar random variable `X`, we define a discrete random variable `Y`
  supported on the integers as follows:

  ```
  P[Y = j] := P[X <= low], if j == low,
           := P[X > high - 1], j == high,
           := 0, if j < low or j > high,
           := P[j - 1 < X <= j], all other j.
  ```

  Conceptually, without cutoffs, the quantization process partitions the real
  line `R` into half open intervals, and identifies an integer `j` with the
  right endpoints:

  ```
  R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ...
  j = ...      -1      0     1     2     3     4  ...
  ```

  `P[Y = j]` is the mass of `X` within the `jth` interval.
  If `low = 0`, and `high = 2`, then the intervals are redrawn
  and `j` is re-assigned:

  ```
  R = (-infty, 0](0, 1](1, infty)
  j =          0     1     2
  ```

  `P[Y = j]` is still the mass of `X` within the `jth` interval.

  #### Caveats

  Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
  a closed form function such as for a Poisson), computations such as mean and
  entropy are better done with samples or approximations, and are not
  implemented by this class.
  """

  def __init__(self,
               distribution,
               low=None,
               high=None,
               validate_args=False,
               name="QuantizedDistribution"):
    """Construct a Quantized Distribution representing `Y = ceiling(X)`.

    Some properties are inherited from the distribution defining `X`. Example:
    `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    the `distribution`.

    Args:
      distribution: The base distribution class to transform. Typically an
        instance of `Distribution`.
      low: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `low`.
      high: `Tensor` with same `dtype` as this distribution and shape
        able to be added to samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `high - 1`.
        `high` must be strictly greater than `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: If `distribution` is not a subclass of
        `Distribution` or continuous.
      NotImplementedError: If the base distribution does not implement `cdf`.
    """
    parameters = locals()
    values = (
        list(distribution.parameters.values()) +
        [low, high])
    with ops.name_scope(name, values=values):
      self._dist = distribution

      if low is not None:
        low = ops.convert_to_tensor(low, name="low")
      if high is not None:
        high = ops.convert_to_tensor(high, name="high")
      check_ops.assert_same_float_dtype(
          tensors=[self.distribution, low, high])

      # We let QuantizedDistribution access _graph_parents since this class is
      # more like a baseclass.
      graph_parents = self._dist._graph_parents  # pylint: disable=protected-access

      checks = []
      if validate_args and low is not None and high is not None:
        message = "low must be strictly less than high."
        checks.append(
            check_ops.assert_less(
                low, high, message=message))
      self._validate_args = validate_args  # self._check_integer uses this.
      with ops.control_dependencies(checks if validate_args else []):
        if low is not None:
          self._low = self._check_integer(low)
          graph_parents += [self._low]
        else:
          self._low = None
        if high is not None:
          self._high = self._check_integer(high)
          graph_parents += [self._high]
        else:
          self._high = None

    super(QuantizedDistribution, self).__init__(
        dtype=self._dist.dtype,
        reparameterization_type=distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=self._dist.allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)

  def _batch_shape_tensor(self):
    return self.distribution.batch_shape_tensor()

  def _batch_shape(self):
    return self.distribution.batch_shape

  def _event_shape_tensor(self):
    return self.distribution.event_shape_tensor()

  def _event_shape(self):
    return self.distribution.event_shape

  def _sample_n(self, n, seed=None):
    """Sample `n` draws of `X`, then quantize and clip to `[low, high]`."""
    low = self._low
    high = self._high
    with ops.name_scope("transform"):
      n = ops.convert_to_tensor(n, name="n")
      x_samps = self.distribution.sample(n, seed=seed)
      ones = array_ops.ones_like(x_samps)

      # Snap values to the intervals (j - 1, j].
      result_so_far = math_ops.ceil(x_samps)

      if low is not None:
        result_so_far = array_ops.where(result_so_far < low,
                                        low * ones, result_so_far)

      if high is not None:
        result_so_far = array_ops.where(result_so_far > high,
                                        high * ones, result_so_far)

      return result_so_far

  @distribution_util.AppendDocstring(_log_prob_note)
  def _log_prob(self, y):
    if not hasattr(self.distribution, "_log_cdf"):
      raise NotImplementedError(
          "'log_prob' not implemented unless the base distribution implements "
          "'log_cdf'")
    y = self._check_integer(y)
    try:
      return self._log_prob_with_logsf_and_logcdf(y)
    except NotImplementedError:
      return self._log_prob_with_logcdf(y)

  def _log_prob_with_logcdf(self, y):
    """Compute log_prob(y) from the log cdf alone: Log[cdf(y) - cdf(y - 1)]."""
    return _logsum_expbig_minus_expsmall(self.log_cdf(y), self.log_cdf(y - 1))

  def _log_prob_with_logsf_and_logcdf(self, y):
    """Compute log_prob(y) using log survival_function and cdf together."""
    # There are two options that would be equal if we had infinite precision:
    # Log[ sf(y - 1) - sf(y) ]
    #   = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ]
    # Log[ cdf(y) - cdf(y - 1) ]
    #   = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ]
    logsf_y = self.log_survival_function(y)
    logsf_y_minus_1 = self.log_survival_function(y - 1)
    logcdf_y = self.log_cdf(y)
    logcdf_y_minus_1 = self.log_cdf(y - 1)

    # Important:  Here we use select in a way such that no input is inf, this
    # prevents the troublesome case where the output of select can be finite,
    # but the output of grad(select) will be NaN.

    # In either case, we are doing Log[ exp{big} - exp{small} ]
    # We want to use the sf items precisely when we are on the right side of the
    # median, which occurs when logsf_y < logcdf_y.
    big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y)
    small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1)

    return _logsum_expbig_minus_expsmall(big, small)

  @distribution_util.AppendDocstring(_prob_note)
  def _prob(self, y):
    if not hasattr(self.distribution, "_cdf"):
      raise NotImplementedError(
          "'prob' not implemented unless the base distribution implements "
          "'cdf'")
    y = self._check_integer(y)
    try:
      return self._prob_with_sf_and_cdf(y)
    except NotImplementedError:
      return self._prob_with_cdf(y)

  def _prob_with_cdf(self, y):
    """Compute prob(y) from the cdf alone: cdf(y) - cdf(y - 1)."""
    return self.cdf(y) - self.cdf(y - 1)

  def _prob_with_sf_and_cdf(self, y):
    """Compute prob(y) using survival_function and cdf together."""
    # There are two options that would be equal if we had infinite precision:
    # sf(y - 1) - sf(y)
    # cdf(y) - cdf(y - 1)
    sf_y = self.survival_function(y)
    sf_y_minus_1 = self.survival_function(y - 1)
    cdf_y = self.cdf(y)
    cdf_y_minus_1 = self.cdf(y - 1)

    # sf_prob has greater precision iff we're on the right side of the median.
    return array_ops.where(
        sf_y < cdf_y,  # True iff we're on the right side of the median.
        sf_y_minus_1 - sf_y,
        cdf_y - cdf_y_minus_1)

  @distribution_util.AppendDocstring(_log_cdf_note)
  def _log_cdf(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # cdf(y) := P[Y <= y]
    #         = 1, if y >= high,
    #         = 0, if y < low,
    #         = P[X <= y], otherwise.

    # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
    # between.
    j = math_ops.floor(y)

    result_so_far = self.distribution.log_cdf(j)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      neg_inf = -np.inf * array_ops.ones_like(result_so_far)
      result_so_far = array_ops.where(j < low, neg_inf, result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_cdf_note)
  def _cdf(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # cdf(y) := P[Y <= y]
    #         = 1, if y >= high,
    #         = 0, if y < low,
    #         = P[X <= y], otherwise.

    # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
    # between.
    j = math_ops.floor(y)

    # P[X <= j], used when low < X < high.
    result_so_far = self.distribution.cdf(j)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.ones_like(result_so_far),
                                      result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_log_sf_note)
  def _log_survival_function(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # survival_function(y) := P[Y > y]
    #                       = 0, if y >= high,
    #                       = 1, if y < low,
    #                       = P[X > y], otherwise.

    # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
    # between.
    j = math_ops.ceil(y)

    # P[X > j], used when low < X < high.
    result_so_far = self.distribution.log_survival_function(j)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)
    if high is not None:
      neg_inf = -np.inf * array_ops.ones_like(result_so_far)
      result_so_far = array_ops.where(j >= high, neg_inf, result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_sf_note)
  def _survival_function(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # survival_function(y) := P[Y > y]
    #                       = 0, if y >= high,
    #                       = 1, if y < low,
    #                       = P[X > y], otherwise.

    # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
    # between.
    j = math_ops.ceil(y)

    # P[X > j], used when low < X < high.
    result_so_far = self.distribution.survival_function(j)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.ones_like(result_so_far),
                                      result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)

    return result_so_far

  def _check_integer(self, value):
    """Return `value`, asserting whole-number components if validating args."""
    with ops.name_scope("check_integer", values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      if not self.validate_args:
        return value
      dependencies = [distribution_util.assert_integer_form(
          value, message="value has non-integer components.")]
      return control_flow_ops.with_dependencies(dependencies, value)

  @property
  def distribution(self):
    """Base distribution, p(x)."""
    return self._dist