# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quantized distribution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distributions
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation

__all__ = ["QuantizedDistribution"]


@deprecation.deprecated(
    "2018-10-01",
    "The TensorFlow Distributions library has moved to "
    "TensorFlow Probability "
    "(https://github.com/tensorflow/probability). You "
    "should update all references to use `tfp.distributions` "
    "instead of `tf.contrib.distributions`.",
    warn_once=True)
def _logsum_expbig_minus_expsmall(big, small):
  """Stable evaluation of `Log[exp{big} - exp{small}]`.

  To work correctly, we should have the pointwise relation: `small <= big`.
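
  The implementation uses the identity
  `Log[exp{big} - exp{small}] = big + Log[1 - exp{small - big}]`,
  which never exponentiates `big` directly and therefore cannot overflow.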

  Args:
    big: Floating-point `Tensor`.
    small: Floating-point `Tensor` with same `dtype` as `big` and broadcastable
      shape.

  Returns:
    `Tensor` of the same `dtype` as `big` and broadcast shape.
  """
  with ops.name_scope("logsum_expbig_minus_expsmall", values=[small, big]):
    return math_ops.log(1. - math_ops.exp(small - big)) + big


_prob_base_note = """
For whole numbers `y`,

```
P[Y = y] := P[X <= low],  if y == low,
         := P[X > high - 1],  y == high,
         := 0, if y < low or y > high,
         := P[y - 1 < X <= y],  all other y.
```

"""

_prob_note = _prob_base_note + """
The base distribution's `cdf` method must be defined on `y - 1`. If the
base distribution has a `survival_function` method, results will be more
accurate for large values of `y`, and in this case the `survival_function` must
also be defined on `y - 1`.
"""

_log_prob_note = _prob_base_note + """
The base distribution's `log_cdf` method must be defined on `y - 1`. If the
base distribution has a `log_survival_function` method, results will be more
accurate for large values of `y`, and in this case the `log_survival_function`
must also be defined on `y - 1`.
"""
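
# A hedged illustration of the notes above (cutoff values are hypothetical):
# with `low=0` and `high=2`, `prob(1.)` evaluates `cdf(1.) - cdf(0.)`, i.e.
# `P[0 < X <= 1]`, while the cutoffs absorb the tail mass:
# `prob(0.) = P[X <= 0]` and `prob(2.) = P[X > 1]`.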


_cdf_base_note = """

For whole numbers `y`,

```
cdf(y) := P[Y <= y]
        = 1, if y >= high,
        = 0, if y < low,
        = P[X <= y], otherwise.
```

Since `Y` only has mass at whole numbers, `P[Y <= y] = P[Y <= floor(y)]`.
This dictates that fractional `y` are first floored to a whole number, and
then the above definition applies.
"""
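
# For instance (illustrative only): `cdf(3.7)` floors to `P[X <= 3]` whenever
# the cutoffs satisfy `low <= 3 < high`.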

_cdf_note = _cdf_base_note + """
The base distribution's `cdf` method must be defined on `y - 1`.
"""

_log_cdf_note = _cdf_base_note + """
The base distribution's `log_cdf` method must be defined on `y - 1`.
"""


_sf_base_note = """

For whole numbers `y`,

```
survival_function(y) := P[Y > y]
                      = 0, if y >= high,
                      = 1, if y < low,
                      = P[X > y], otherwise.
```

Since `Y` only has mass at whole numbers, fractional `y` are first rounded up
to a whole number (`ceiling(y)`, matching the implementation below), and then
the above definition applies.
"""
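
# Similarly (illustrative only): `survival_function(3.2)` evaluates the base
# distribution at `ceiling(3.2) = 4`, i.e. `P[X > 4]`, whenever the cutoffs
# satisfy `low <= 4 < high`.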

_sf_note = _sf_base_note + """
The base distribution's `survival_function` method must be defined on `y`.
"""

_log_sf_note = _sf_base_note + """
The base distribution's `log_survival_function` method must be defined on `y`.
"""


class QuantizedDistribution(distributions.Distribution):
  """Distribution representing the quantization `Y = ceiling(X)`.

  #### Definition in Terms of Sampling

  ```
  1. Draw X
  2. Set Y <-- ceiling(X)
  3. If Y < low, reset Y <-- low
  4. If Y > high, reset Y <-- high
  5. Return Y
  ```

  #### Definition in Terms of the Probability Mass Function

  Given scalar random variable `X`, we define a discrete random variable `Y`
  supported on the integers as follows:

  ```
  P[Y = j] := P[X <= low],  if j == low,
           := P[X > high - 1],  j == high,
           := 0, if j < low or j > high,
           := P[j - 1 < X <= j],  all other j.
  ```
  Conceptually, without cutoffs, the quantization process partitions the real
  line `R` into half-open intervals, and identifies an integer `j` with the
  right endpoints:

  ```
  R = ... (-2, -1](-1, 0](0, 1](1, 2](2, 3](3, 4] ...
  j = ...      -1      0     1     2     3     4  ...
  ```

  `P[Y = j]` is the mass of `X` within the `jth` interval.
  If `low = 0` and `high = 2`, then the intervals are redrawn
  and `j` is reassigned:

  ```
  R = (-infty, 0](0, 1](1, infty)
  j =          0     1     2
  ```

  `P[Y = j]` is still the mass of `X` within the `jth` interval.

  #### Examples

  We illustrate a mixture of discretized logistic distributions
  [(Salimans et al., 2017)][1]. This is used, for example, for capturing 16-bit
  audio in WaveNet [(van den Oord et al., 2017)][2]. The values lie in a 1-D
  integer domain of `[0, 2**16-1]`, and the discretization captures
  `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints.
  The lowest value has probability `P(X <= 0.5)` and the highest value has
  probability `P(2**16 - 1.5 < X)`.

  Below we assume a `wavenet` function. It takes as input right-shifted audio
  samples of shape `[..., sequence_length]`. It returns a real-valued tensor of
  shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and
  `scale` parameter belonging to the logistic distribution, and a `logits`
  parameter determining the unnormalized probability of that component.

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions
  tfb = tfp.bijectors

  net = wavenet(inputs)
  loc, unconstrained_scale, logits = tf.split(net,
                                              num_or_size_splits=3,
                                              axis=-1)
  scale = tf.nn.softplus(unconstrained_scale)

  # Form mixture of discretized logistic distributions. Note we shift the
  # logistic distribution by -0.5. This lets the quantization capture "rounding"
  # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
  discretized_logistic_dist = tfd.QuantizedDistribution(
      distribution=tfd.TransformedDistribution(
          distribution=tfd.Logistic(loc=loc, scale=scale),
          bijector=tfb.AffineScalar(shift=-0.5)),
      low=0.,
      high=2**16 - 1.)
  mixture_dist = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(logits=logits),
      components_distribution=discretized_logistic_dist)

  neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets))
  train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood)
  ```

  After instantiating `mixture_dist`, we illustrate maximum likelihood by
  computing its log-probability of the audio samples `targets` and optimizing.
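
  As a simpler sketch (reusing the `tfd` alias above; the parameter values are
  purely illustrative), a logistic quantized to the integers `0, ..., 10`:

  ```python
  dist = tfd.QuantizedDistribution(
      distribution=tfd.Logistic(loc=5., scale=1.),
      low=0.,
      high=10.)

  dist.prob(5.)   # P[4 < X <= 5].
  dist.prob(0.)   # P[X <= 0].
  dist.prob(10.)  # P[X > 9].
  dist.sample(4)  # Integer-valued samples, clipped to [0., 10.].
  ```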

  #### References

  [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma.
       PixelCNN++: Improving the PixelCNN with discretized logistic mixture
       likelihood and other modifications.
       _International Conference on Learning Representations_, 2017.
       https://arxiv.org/abs/1701.05517
  [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
       Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
       https://arxiv.org/abs/1711.10433
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               distribution,
               low=None,
               high=None,
               validate_args=False,
               name="QuantizedDistribution"):
    """Construct a Quantized Distribution representing `Y = ceiling(X)`.

    Some properties are inherited from the distribution defining `X`. Example:
    `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    the `distribution`.

    Args:
      distribution: The base distribution to transform. Typically an instance
        of `Distribution`.
      low: `Tensor` with same `dtype` as this distribution and a shape that
        broadcasts against samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `low`.
      high: `Tensor` with same `dtype` as this distribution and a shape that
        broadcasts against samples. Should be a whole number. Default `None`.
        If provided, base distribution's `prob` should be defined at
        `high - 1`.
        `high` must be strictly greater than `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: If `distribution` is not an instance of `Distribution` or is
          not continuous.
      NotImplementedError: If the base distribution does not implement `cdf`.
    """
    parameters = dict(locals())
    values = (
        list(distribution.parameters.values()) +
        [low, high])
    with ops.name_scope(name, values=values) as name:
      self._dist = distribution

      if low is not None:
        low = ops.convert_to_tensor(low, name="low")
      if high is not None:
        high = ops.convert_to_tensor(high, name="high")
      check_ops.assert_same_float_dtype(
          tensors=[self.distribution, low, high])

      # We let QuantizedDistribution access _graph_parents since this class is
      # more like a baseclass.
      graph_parents = self._dist._graph_parents  # pylint: disable=protected-access

      checks = []
      if validate_args and low is not None and high is not None:
        message = "low must be strictly less than high."
        checks.append(
            check_ops.assert_less(
                low, high, message=message))
      self._validate_args = validate_args  # self._check_integer uses this.
      with ops.control_dependencies(checks if validate_args else []):
        if low is not None:
          self._low = self._check_integer(low)
          graph_parents += [self._low]
        else:
          self._low = None
        if high is not None:
          self._high = self._check_integer(high)
          graph_parents += [self._high]
        else:
          self._high = None

    super(QuantizedDistribution, self).__init__(
        dtype=self._dist.dtype,
        reparameterization_type=distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=self._dist.allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)

  @property
  def distribution(self):
    """Base distribution, p(x)."""
    return self._dist

  @property
  def low(self):
    """Lowest value that quantization returns."""
    return self._low

  @property
  def high(self):
    """Highest value that quantization returns."""
    return self._high

  def _batch_shape_tensor(self):
    return self.distribution.batch_shape_tensor()

  def _batch_shape(self):
    return self.distribution.batch_shape

  def _event_shape_tensor(self):
    return self.distribution.event_shape_tensor()

  def _event_shape(self):
    return self.distribution.event_shape

  def _sample_n(self, n, seed=None):
    low = self._low
    high = self._high
    with ops.name_scope("transform"):
      n = ops.convert_to_tensor(n, name="n")
      x_samps = self.distribution.sample(n, seed=seed)
      ones = array_ops.ones_like(x_samps)

      # Snap values to the intervals (j - 1, j].
      result_so_far = math_ops.ceil(x_samps)

      if low is not None:
        result_so_far = array_ops.where(result_so_far < low,
                                        low * ones, result_so_far)

      if high is not None:
        result_so_far = array_ops.where(result_so_far > high,
                                        high * ones, result_so_far)

      return result_so_far

  @distribution_util.AppendDocstring(_log_prob_note)
  def _log_prob(self, y):
    if not hasattr(self.distribution, "_log_cdf"):
      raise NotImplementedError(
          "'log_prob' not implemented unless the base distribution implements "
          "'log_cdf'")
    y = self._check_integer(y)
    try:
      return self._log_prob_with_logsf_and_logcdf(y)
    except NotImplementedError:
      return self._log_prob_with_logcdf(y)

  def _log_prob_with_logcdf(self, y):
    return _logsum_expbig_minus_expsmall(self.log_cdf(y), self.log_cdf(y - 1))

  def _log_prob_with_logsf_and_logcdf(self, y):
    """Compute log_prob(y) using log survival_function and cdf together."""
    # There are two options that would be equal if we had infinite precision:
    # Log[ sf(y - 1) - sf(y) ]
    #   = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ]
    # Log[ cdf(y) - cdf(y - 1) ]
    #   = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ]
    logsf_y = self.log_survival_function(y)
    logsf_y_minus_1 = self.log_survival_function(y - 1)
    logcdf_y = self.log_cdf(y)
    logcdf_y_minus_1 = self.log_cdf(y - 1)

    # Important: here we choose the arguments of `where` so that no selected
    # input is inf. This prevents the troublesome case where the output of
    # `where` is finite, but the output of grad(`where`) is NaN.

    # In either case, we are doing Log[ exp{big} - exp{small} ].
    # We want to use the sf items precisely when we are on the right side of the
    # median, which occurs when logsf_y < logcdf_y.
    big = array_ops.where(logsf_y < logcdf_y, logsf_y_minus_1, logcdf_y)
    small = array_ops.where(logsf_y < logcdf_y, logsf_y, logcdf_y_minus_1)

    return _logsum_expbig_minus_expsmall(big, small)

  @distribution_util.AppendDocstring(_prob_note)
  def _prob(self, y):
    if not hasattr(self.distribution, "_cdf"):
      raise NotImplementedError(
          "'prob' not implemented unless the base distribution implements "
          "'cdf'")
    y = self._check_integer(y)
    try:
      return self._prob_with_sf_and_cdf(y)
    except NotImplementedError:
      return self._prob_with_cdf(y)

  def _prob_with_cdf(self, y):
    return self.cdf(y) - self.cdf(y - 1)

  def _prob_with_sf_and_cdf(self, y):
    # There are two options that would be equal if we had infinite precision:
    # sf(y - 1) - sf(y)
    # cdf(y) - cdf(y - 1)
    sf_y = self.survival_function(y)
    sf_y_minus_1 = self.survival_function(y - 1)
    cdf_y = self.cdf(y)
    cdf_y_minus_1 = self.cdf(y - 1)

    # sf_prob has greater precision iff we're on the right side of the median.
    return array_ops.where(
        sf_y < cdf_y,  # True iff we're on the right side of the median.
        sf_y_minus_1 - sf_y,
        cdf_y - cdf_y_minus_1)


  @distribution_util.AppendDocstring(_log_cdf_note)
  def _log_cdf(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # cdf(y) := P[Y <= y]
    #         = 1, if y >= high,
    #         = 0, if y < low,
    #         = P[X <= y], otherwise.

    # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in
    # between.
    j = math_ops.floor(y)

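    # P[X <= j], used when low < X < high.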
    result_so_far = self.distribution.log_cdf(j)

    # Broadcast `j` up to the shape of the result, since a single distribution
    # may be evaluated on a batch of samples.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      neg_inf = -np.inf * array_ops.ones_like(result_so_far)
      result_so_far = array_ops.where(j < low, neg_inf, result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_cdf_note)
  def _cdf(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # cdf(y) := P[Y <= y]
    #         = 1, if y >= high,
    #         = 0, if y < low,
    #         = P[X <= y], otherwise.

    # P[Y <= y] = P[Y <= floor(y)] since mass is only at integers, not in
    # between.
    j = math_ops.floor(y)

    # P[X <= j], used when low < X < high.
    result_so_far = self.distribution.cdf(j)

    # Broadcast `j` up to the shape of the result, since a single distribution
    # may be evaluated on a batch of samples.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.ones_like(result_so_far),
                                      result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_log_sf_note)
  def _log_survival_function(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # survival_function(y) := P[Y > y]
    #                       = 0, if y >= high,
    #                       = 1, if y < low,
    #                       = P[X > y], otherwise.

    # Mass is only at integers, so fractional `y` are mapped up to
    # j = ceiling(y) before applying the promise above.
    j = math_ops.ceil(y)

    # P[X > j], used when low < X < high.
    result_so_far = self.distribution.log_survival_function(j)

    # Broadcast `j` up to the shape of the result, since a single distribution
    # may be evaluated on a batch of samples.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)
    if high is not None:
      neg_inf = -np.inf * array_ops.ones_like(result_so_far)
      result_so_far = array_ops.where(j >= high, neg_inf, result_so_far)

    return result_so_far

  @distribution_util.AppendDocstring(_sf_note)
  def _survival_function(self, y):
    low = self._low
    high = self._high

    # Recall the promise:
    # survival_function(y) := P[Y > y]
    #                       = 0, if y >= high,
    #                       = 1, if y < low,
    #                       = P[X > y], otherwise.

    # Mass is only at integers, so fractional `y` are mapped up to
    # j = ceiling(y) before applying the promise above.
    j = math_ops.ceil(y)

    # P[X > j], used when low < X < high.
    result_so_far = self.distribution.survival_function(j)

    # Broadcast `j` up to the shape of the result, since a single distribution
    # may be evaluated on a batch of samples.
    j += array_ops.zeros_like(result_so_far)

    # Re-define values at the cutoffs.
    if low is not None:
      result_so_far = array_ops.where(j < low,
                                      array_ops.ones_like(result_so_far),
                                      result_so_far)
    if high is not None:
      result_so_far = array_ops.where(j >= high,
                                      array_ops.zeros_like(result_so_far),
                                      result_so_far)

    return result_so_far

  def _check_integer(self, value):
    with ops.name_scope("check_integer", values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      if not self.validate_args:
        return value
      dependencies = [distribution_util.assert_integer_form(
          value, message="value has non-integer components.")]
      return control_flow_ops.with_dependencies(dependencies, value)
    588