      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Quantized distribution."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import check_ops
     26 from tensorflow.python.ops import control_flow_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops.distributions import distribution as distributions
     29 from tensorflow.python.ops.distributions import util as distribution_util
     30 
     31 __all__ = ["QuantizedDistribution"]
     32 
     33 
     34 def _logsum_expbig_minus_expsmall(big, small):
     35   """Stable evaluation of `Log[exp{big} - exp{small}]`.
     36 
      37   To work correctly, the inputs must satisfy the pointwise relation `small <= big`.
     38 
     39   Args:
     40     big: Floating-point `Tensor`
     41     small: Floating-point `Tensor` with same `dtype` as `big` and broadcastable
     42       shape.
     43 
     44   Returns:
      45     `Tensor` of the same `dtype` as `big` and broadcast shape.
     46   """
     47   with ops.name_scope("logsum_expbig_minus_expsmall", values=[small, big]):
     48     return math_ops.log(1. - math_ops.exp(small - big)) + big
     49 
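# For intuition, the identity used above is
#
#   Log[exp(big) - exp(small)] = big + Log[1 - exp(small - big)],
#
# which never exponentiates `big` directly and therefore cannot overflow. A
# small numeric sketch (illustrative only, using the `np` module imported
# above rather than TensorFlow ops):
#
#   big, small = 1000., 999.
#   np.log(np.exp(big) - np.exp(small))     # exp overflows: inf - inf -> nan
#   big + np.log(1. - np.exp(small - big))  # ~999.5413, the correct value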
     50 
     51 _prob_base_note = """
     52 For whole numbers `y`,
     53 
     54 ```
     55 P[Y = y] := P[X <= low],  if y == low,
     56          := P[X > high - 1],  y == high,
      57          := 0, if y < low or y > high,
     58          := P[y - 1 < X <= y],  all other y.
     59 ```
     60 
     61 """
     62 
     63 _prob_note = _prob_base_note + """
     64 The base distribution's `cdf` method must be defined on `y - 1`. If the
     65 base distribution has a `survival_function` method, results will be more
     66 accurate for large values of `y`, and in this case the `survival_function` must
     67 also be defined on `y - 1`.
     68 """
     69 
     70 _log_prob_note = _prob_base_note + """
     71 The base distribution's `log_cdf` method must be defined on `y - 1`. If the
     72 base distribution has a `log_survival_function` method results will be more
     73 accurate for large values of `y`, and in this case the `log_survival_function`
     74 must also be defined on `y - 1`.
     75 """
     76 
     77 
     78 _cdf_base_note = """
     79 
     80 For whole numbers `y`,
     81 
     82 ```
     83 cdf(y) := P[Y <= y]
     84         = 1, if y >= high,
     85         = 0, if y < low,
     86         = P[X <= y], otherwise.
     87 ```
     88 
     89 Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`.
     90 This dictates that fractional `y` are first floored to a whole number, and
      91 then the above definition applies.
     92 """
     93 
     94 _cdf_note = _cdf_base_note + """
     95 The base distribution's `cdf` method must be defined on `y - 1`.
     96 """
     97 
     98 _log_cdf_note = _cdf_base_note + """
     99 The base distribution's `log_cdf` method must be defined on `y - 1`.
    100 """
    101 
    102 
    103 _sf_base_note = """
    104 
    105 For whole numbers `y`,
    106 
    107 ```
    108 survival_function(y) := P[Y > y]
    109                       = 0, if y >= high,
    110                       = 1, if y < low,
     111                       = P[X > y], otherwise.
    112 ```
    113 
     114 Since `Y` only has mass at whole numbers, fractional `y` are first rounded up
     115 to a whole number via `ceiling`, and then the above definition applies.
     116 
    117 """
    118 
    119 _sf_note = _sf_base_note + """
    120 The base distribution's `cdf` method must be defined on `y - 1`.
    121 """
    122 
    123 _log_sf_note = _sf_base_note + """
    124 The base distribution's `log_cdf` method must be defined on `y - 1`.
    125 """
    126 
    127 
    128 class QuantizedDistribution(distributions.Distribution):
    129   """Distribution representing the quantization `Y = ceiling(X)`.
    130 
    131   #### Definition in terms of sampling.
    132 
    133   ```
    134   1. Draw X
    135   2. Set Y <-- ceiling(X)
    136   3. If Y < low, reset Y <-- low
    137   4. If Y > high, reset Y <-- high
    138   5. Return Y
    139   ```
    140 
    141   #### Definition in terms of the probability mass function.
    142 
    143   Given scalar random variable `X`, we define a discrete random variable `Y`
    144   supported on the integers as follows:
    145 
    146   ```
    147   P[Y = j] := P[X <= low],  if j == low,
    148            := P[X > high - 1],  j == high,
    149            := 0, if j < low or j > high,
    150            := P[j - 1 < X <= j],  all other j.
    151   ```
    152 
    153   Conceptually, without cutoffs, the quantization process partitions the real
     154   line `R` into half-open intervals, and identifies an integer `j` with the
     155   right endpoint of each:
    156 
    157   ```
    158   R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ...
    159   j = ...      -1      0     1     2     3     4  ...
    160   ```
    161 
    162   `P[Y = j]` is the mass of `X` within the `jth` interval.
    163   If `low = 0`, and `high = 2`, then the intervals are redrawn
    164   and `j` is re-assigned:
    165 
    166   ```
    167   R = (-infty, 0](0, 1](1, infty)
    168   j =          0     1     2
    169   ```
    170 
    171   `P[Y = j]` is still the mass of `X` within the `jth` interval.
    172 
    173   #### Caveats
    174 
    175   Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than
    176   a closed form function such as for a Poisson), computations such as mean and
    177   entropy are better done with samples or approximations, and are not
    178   implemented by this class.
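
  #### Examples

  A minimal sketch of quantizing a standard Normal onto the integers in
  `[-3, 3]`, assuming `import tensorflow as tf` and the TF 1.x
  `tf.contrib.distributions` namespace (illustrative only):

  ```
  tfd = tf.contrib.distributions

  q = tfd.QuantizedDistribution(
      distribution=tfd.Normal(loc=0., scale=1.),
      low=-3.,
      high=3.)

  # P[Y = 0] = P[-1 < X <= 0] for X ~ Normal(0, 1).
  pmf_at_zero = q.prob(0.)

  # Moments are not implemented in closed form (see Caveats above); approximate
  # them with samples, e.g. a Monte Carlo estimate of the mean:
  approx_mean = tf.reduce_mean(q.sample(10000), axis=0)
  ```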
    179   """
    180 
    181   def __init__(self,
    182                distribution,
    183                low=None,
    184                high=None,
    185                validate_args=False,
    186                name="QuantizedDistribution"):
    187     """Construct a Quantized Distribution representing `Y = ceiling(X)`.
    188 
    189     Some properties are inherited from the distribution defining `X`. Example:
    190     `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    191     the `distribution`.
    192 
    193     Args:
     194       distribution:  The base distribution to transform. Typically an
    195         instance of `Distribution`.
    196       low: `Tensor` with same `dtype` as this distribution and shape
    197         able to be added to samples. Should be a whole number. Default `None`.
    198         If provided, base distribution's `prob` should be defined at
    199         `low`.
    200       high: `Tensor` with same `dtype` as this distribution and shape
    201         able to be added to samples. Should be a whole number. Default `None`.
    202         If provided, base distribution's `prob` should be defined at
    203         `high - 1`.
    204         `high` must be strictly greater than `low`.
    205       validate_args: Python `bool`, default `False`. When `True` distribution
    206         parameters are checked for validity despite possibly degrading runtime
    207         performance. When `False` invalid inputs may silently render incorrect
    208         outputs.
    209       name: Python `str` name prefixed to Ops created by this class.
    210 
    211     Raises:
     212       TypeError: If `distribution` is not an instance of `Distribution`, or
     213           if it is not a continuous distribution.
    214       NotImplementedError:  If the base distribution does not implement `cdf`.
    215     """
    216     parameters = locals()
    217     values = (
    218         list(distribution.parameters.values()) +
    219         [low, high])
    220     with ops.name_scope(name, values=values):
    221       self._dist = distribution
    222 
    223       if low is not None:
    224         low = ops.convert_to_tensor(low, name="low")
    225       if high is not None:
    226         high = ops.convert_to_tensor(high, name="high")
    227       check_ops.assert_same_float_dtype(
    228           tensors=[self.distribution, low, high])
    229 
    230       # We let QuantizedDistribution access _graph_parents since this class is
    231       # more like a baseclass.
    232       graph_parents = self._dist._graph_parents  # pylint: disable=protected-access
    233 
    234       checks = []
    235       if validate_args and low is not None and high is not None:
    236         message = "low must be strictly less than high."
    237         checks.append(
    238             check_ops.assert_less(
    239                 low, high, message=message))
    240       self._validate_args = validate_args  # self._check_integer uses this.
    241       with ops.control_dependencies(checks if validate_args else []):
    242         if low is not None:
    243           self._low = self._check_integer(low)
    244           graph_parents += [self._low]
    245         else:
    246           self._low = None
    247         if high is not None:
    248           self._high = self._check_integer(high)
    249           graph_parents += [self._high]
    250         else:
    251           self._high = None
    252 
    253     super(QuantizedDistribution, self).__init__(
    254         dtype=self._dist.dtype,
    255         reparameterization_type=distributions.NOT_REPARAMETERIZED,
    256         validate_args=validate_args,
    257         allow_nan_stats=self._dist.allow_nan_stats,
    258         parameters=parameters,
    259         graph_parents=graph_parents,
    260         name=name)
    261 
    262   def _batch_shape_tensor(self):
    263     return self.distribution.batch_shape_tensor()
    264 
    265   def _batch_shape(self):
    266     return self.distribution.batch_shape
    267 
    268   def _event_shape_tensor(self):
    269     return self.distribution.event_shape_tensor()
    270 
    271   def _event_shape(self):
    272     return self.distribution.event_shape
    273 
    274   def _sample_n(self, n, seed=None):
    275     low = self._low
    276     high = self._high
    277     with ops.name_scope("transform"):
    278       n = ops.convert_to_tensor(n, name="n")
    279       x_samps = self.distribution.sample(n, seed=seed)
    280       ones = array_ops.ones_like(x_samps)
    281 
    282       # Snap values to the intervals (j - 1, j].
    283       result_so_far = math_ops.ceil(x_samps)
    284 
    285       if low is not None:
    286         result_so_far = array_ops.where(result_so_far < low,
    287                                         low * ones, result_so_far)
    288 
    289       if high is not None:
    290         result_so_far = array_ops.where(result_so_far > high,
    291                                         high * ones, result_so_far)
    292 
    293       return result_so_far
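    # Note (illustrative only): elementwise, the transform above is equivalent
    # to the NumPy expression
    #
    #   np.clip(np.ceil(x_samps), low, high)
    #
    # with the corresponding bound omitted when `low` or `high` is None.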
    294 
    295   @distribution_util.AppendDocstring(_log_prob_note)
    296   def _log_prob(self, y):
    297     if not hasattr(self.distribution, "_log_cdf"):
    298       raise NotImplementedError(
    299           "'log_prob' not implemented unless the base distribution implements "
    300           "'log_cdf'")
    301     y = self._check_integer(y)
    302     try:
    303       return self._log_prob_with_logsf_and_logcdf(y)
    304     except NotImplementedError:
    305       return self._log_prob_with_logcdf(y)
    306 
    307   def _log_prob_with_logcdf(self, y):
    308     return _logsum_expbig_minus_expsmall(self.log_cdf(y), self.log_cdf(y - 1))
    309 
    310   def _log_prob_with_logsf_and_logcdf(self, y):
    311     """Compute log_prob(y) using log survival_function and cdf together."""
    312     # There are two options that would be equal if we had infinite precision:
    313     # Log[ sf(y - 1) - sf(y) ]
    314     #   = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ]
    315     # Log[ cdf(y) - cdf(y - 1) ]
    316     #   = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ]
    317     logsf_y = self.log_survival_function(y)
    318     logsf_y_minus_1 = self.log_survival_function(y - 1)
    319     logcdf_y = self.log_cdf(y)
    320     logcdf_y_minus_1 = self.log_cdf(y - 1)
    321 
     322     # Important:  Here we use select in a way such that no input is inf; this
    323     # prevents the troublesome case where the output of select can be finite,
    324     # but the output of grad(select) will be NaN.
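    #
    # A concrete sketch of that hazard (illustrative only, in TF 1.x terms,
    # assuming `import tensorflow as tf`): gradients flow through *both*
    # branches of `where`, so an inf in the unselected branch still poisons
    # the gradient via 0 * inf = NaN.
    #
    #   x = tf.constant(0.)
    #   y = tf.where(x > 1., tf.log(x), x)  # value is 0., which is fine
    #   tf.gradients(y, x)                  # -> [nan], because log(0) = -inf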
    325 
    326     # In either case, we are doing Log[ exp{big} - exp{small} ]
    327     # We want to use the sf items precisely when we are on the right side of the
    328     # median, which occurs when logsf_y < logcdf_y.
    329     big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y)
    330     small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1)
    331 
    332     return _logsum_expbig_minus_expsmall(big, small)
    333 
    334   @distribution_util.AppendDocstring(_prob_note)
    335   def _prob(self, y):
    336     if not hasattr(self.distribution, "_cdf"):
    337       raise NotImplementedError(
    338           "'prob' not implemented unless the base distribution implements "
    339           "'cdf'")
    340     y = self._check_integer(y)
    341     try:
    342       return self._prob_with_sf_and_cdf(y)
    343     except NotImplementedError:
    344       return self._prob_with_cdf(y)
    345 
    346   def _prob_with_cdf(self, y):
    347     return self.cdf(y) - self.cdf(y - 1)
    348 
    349   def _prob_with_sf_and_cdf(self, y):
    350     # There are two options that would be equal if we had infinite precision:
    351     # sf(y - 1) - sf(y)
    352     # cdf(y) - cdf(y - 1)
    353     sf_y = self.survival_function(y)
    354     sf_y_minus_1 = self.survival_function(y - 1)
    355     cdf_y = self.cdf(y)
    356     cdf_y_minus_1 = self.cdf(y - 1)
    357 
    358     # sf_prob has greater precision iff we're on the right side of the median.
    359     return array_ops.where(
    360         sf_y < cdf_y,  # True iff we're on the right side of the median.
    361         sf_y_minus_1 - sf_y,
    362         cdf_y - cdf_y_minus_1)
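    # Numeric sketch of why the sf branch matters (illustrative only): for a
    # standard Normal base distribution and y = 10, float64 rounds both cdf(10)
    # and cdf(9) to exactly 1.0, so cdf(y) - cdf(y - 1) is 0. The survival
    # function keeps the mass: sf(9) ~= 1.13e-19 and sf(10) ~= 7.6e-24, so
    # sf(y - 1) - sf(y) ~= 1.13e-19.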
    363 
    364   @distribution_util.AppendDocstring(_log_cdf_note)
    365   def _log_cdf(self, y):
    366     low = self._low
    367     high = self._high
    368 
    369     # Recall the promise:
    370     # cdf(y) := P[Y <= y]
    371     #         = 1, if y >= high,
    372     #         = 0, if y < low,
    373     #         = P[X <= y], otherwise.
    374 
     375     # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in
    376     # between.
    377     j = math_ops.floor(y)
    378 
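    # Log[P[X <= j]], used when low < X < high.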
    379     result_so_far = self.distribution.log_cdf(j)
    380 
     381     # Broadcast `j` to the shape of the result, e.g. because a single
     382     # distribution may be evaluated on a batch of samples.
    383     j += array_ops.zeros_like(result_so_far)
    384 
    385     # Re-define values at the cutoffs.
    386     if low is not None:
    387       neg_inf = -np.inf * array_ops.ones_like(result_so_far)
    388       result_so_far = array_ops.where(j < low, neg_inf, result_so_far)
    389     if high is not None:
    390       result_so_far = array_ops.where(j >= high,
    391                                       array_ops.zeros_like(result_so_far),
    392                                       result_so_far)
    393 
    394     return result_so_far
    395 
    396   @distribution_util.AppendDocstring(_cdf_note)
    397   def _cdf(self, y):
    398     low = self._low
    399     high = self._high
    400 
    401     # Recall the promise:
    402     # cdf(y) := P[Y <= y]
    403     #         = 1, if y >= high,
    404     #         = 0, if y < low,
    405     #         = P[X <= y], otherwise.
    406 
     407     # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in
    408     # between.
    409     j = math_ops.floor(y)
    410 
    411     # P[X <= j], used when low < X < high.
    412     result_so_far = self.distribution.cdf(j)
    413 
     414     # Broadcast `j` to the shape of the result, e.g. because a single
     415     # distribution may be evaluated on a batch of samples.
    416     j += array_ops.zeros_like(result_so_far)
    417 
    418     # Re-define values at the cutoffs.
    419     if low is not None:
    420       result_so_far = array_ops.where(j < low,
    421                                       array_ops.zeros_like(result_so_far),
    422                                       result_so_far)
    423     if high is not None:
    424       result_so_far = array_ops.where(j >= high,
    425                                       array_ops.ones_like(result_so_far),
    426                                       result_so_far)
    427 
    428     return result_so_far
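    # Concrete illustration (using the Normal(0, 1) example with low=-3,
    # high=3 from the class docstring; illustrative only):
    #   cdf(3.7) = 1.          (y >= high)
    #   cdf(-4.) = 0.          (y < low)
    #   cdf(1.5) = P[X <= 1.]  (fractional y is floored first)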
    429 
    430   @distribution_util.AppendDocstring(_log_sf_note)
    431   def _log_survival_function(self, y):
    432     low = self._low
    433     high = self._high
    434 
    435     # Recall the promise:
    436     # survival_function(y) := P[Y > y]
    437     #                       = 0, if y >= high,
    438     #                       = 1, if y < low,
    439     #                       = P[X > y], otherwise.
    440 
     441     # Fractional `y` are rounded up via `ceiling`; for whole-number `y`,
     442     # P[Y > y] = P[X > y] since mass is only at integers, not in between.
    443     j = math_ops.ceil(y)
    444 
    445     # P[X > j], used when low < X < high.
    446     result_so_far = self.distribution.log_survival_function(j)
    447 
     448     # Broadcast `j` to the shape of the result, e.g. because a single
     449     # distribution may be evaluated on a batch of samples.
    450     j += array_ops.zeros_like(result_so_far)
    451 
    452     # Re-define values at the cutoffs.
    453     if low is not None:
    454       result_so_far = array_ops.where(j < low,
    455                                       array_ops.zeros_like(result_so_far),
    456                                       result_so_far)
    457     if high is not None:
    458       neg_inf = -np.inf * array_ops.ones_like(result_so_far)
    459       result_so_far = array_ops.where(j >= high, neg_inf, result_so_far)
    460 
    461     return result_so_far
    462 
    463   @distribution_util.AppendDocstring(_sf_note)
    464   def _survival_function(self, y):
    465     low = self._low
    466     high = self._high
    467 
    468     # Recall the promise:
    469     # survival_function(y) := P[Y > y]
    470     #                       = 0, if y >= high,
    471     #                       = 1, if y < low,
    472     #                       = P[X > y], otherwise.
    473 
     474     # Fractional `y` are rounded up via `ceiling`; for whole-number `y`,
     475     # P[Y > y] = P[X > y] since mass is only at integers, not in between.
    476     j = math_ops.ceil(y)
    477 
    478     # P[X > j], used when low < X < high.
    479     result_so_far = self.distribution.survival_function(j)
    480 
     481     # Broadcast `j` to the shape of the result, e.g. because a single
     482     # distribution may be evaluated on a batch of samples.
    483     j += array_ops.zeros_like(result_so_far)
    484 
    485     # Re-define values at the cutoffs.
    486     if low is not None:
    487       result_so_far = array_ops.where(j < low,
    488                                       array_ops.ones_like(result_so_far),
    489                                       result_so_far)
    490     if high is not None:
    491       result_so_far = array_ops.where(j >= high,
    492                                       array_ops.zeros_like(result_so_far),
    493                                       result_so_far)
    494 
    495     return result_so_far
    496 
    497   def _check_integer(self, value):
    498     with ops.name_scope("check_integer", values=[value]):
    499       value = ops.convert_to_tensor(value, name="value")
    500       if not self.validate_args:
    501         return value
    502       dependencies = [distribution_util.assert_integer_form(
    503           value, message="value has non-integer components.")]
    504       return control_flow_ops.with_dependencies(dependencies, value)
    505 
    506   @property
    507   def distribution(self):
    508     """Base distribution, p(x)."""
    509     return self._dist
    510