# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Geometric distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util


class Geometric(distribution.Distribution):
  """Geometric distribution.

  The Geometric distribution is parameterized by `p`, the probability of a
  positive event. It represents the probability that in `k + 1` Bernoulli
  trials, the first `k` trials fail before a success is observed.

  #### Mathematical Details

  The probability mass function (pmf) is:

  ```none
  pmf(k; p) = (1 - p)**k * p
  ```

  where:

  * `p` is the success probability, `0 < p <= 1`, and,
  * `k` is a non-negative integer.

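  #### Examples

  A minimal usage sketch; the commented values follow directly from the pmf
  and mean formulas above:

  ```python
  dist = Geometric(probs=0.6)
  dist.prob(2.)   # (1 - 0.6)**2 * 0.6 = 0.096
  dist.mean()     # (1 - 0.6) / 0.6 ~= 0.6667
  dist.sample(5)  # A `Tensor` of 5 draws.
  ```
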
  """

  def __init__(self,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Geometric"):
    """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where `b >= 0`
        indicates the number of batch dimensions. Each entry represents logits
        for the probability of success for independent Geometric distributions
        and must be in the range `(-inf, inf]`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of `logits`
        or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

    parameters = locals()
    with ops.name_scope(name, values=[logits, probs]):
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits, probs, validate_args=validate_args, name=name)

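      # The pmf requires `0 < p <= 1`, so additionally rule out `probs == 0`,
      # under which every trial fails and the pmf is identically zero.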
      with ops.control_dependencies(
          [check_ops.assert_positive(self._probs)] if validate_args else []):
        self._probs = array_ops.identity(self._probs, name="probs")

    super(Geometric, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        name=name)

  @property
  def logits(self):
    """Log-odds of a `1` outcome (vs `0`)."""
    return self._logits

  @property
  def probs(self):
    """Probability of a `1` outcome (vs `0`)."""
    return self._probs

  def _batch_shape_tensor(self):
    return array_ops.shape(self._probs)

  def _batch_shape(self):
    return self.probs.get_shape()

  def _event_shape_tensor(self):
    return array_ops.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Uniform variates must be sampled from the open interval `(0, 1)` rather
    # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny`
    # because it is the smallest, positive, "normal" number. A "normal" number
    # is such that the mantissa has an implicit leading 1. Normal, positive
    # numbers x, y have the reasonable property that `x + y >= max(x, y)`.
    # Using a subnormal lower bound (e.g., the result of `np.nextafter(0., 1.)`)
    # instead could cause us to sample 0.
    sampled = random_ops.random_uniform(
        array_ops.concat([[n], array_ops.shape(self._probs)], 0),
        minval=np.finfo(self.dtype.as_numpy_dtype).tiny,
        maxval=1.,
        seed=seed,
        dtype=self.dtype)

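    # Inverse-CDF sampling: if `U ~ Uniform(0, 1)`, then
    # `floor(log(U) / log(1 - p))` has pmf `(1 - p)**k * p`, since
    # `P[floor(log(U) / log(1 - p)) >= k] = P[U <= (1 - p)**k] = (1 - p)**k`.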
    return math_ops.floor(
        math_ops.log(sampled) / math_ops.log1p(-self.probs))

  def _cdf(self, x):
    if self.validate_args:
      x = distribution_util.embed_check_nonnegative_integer_form(x)
    else:
      # Whether or not x is integer-form, the following is well-defined.
      # However, scipy takes the floor, so we do too.
      x = math_ops.floor(x)
    x *= array_ops.ones_like(self.probs)
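    # CDF(k; p) = 1 - (1 - p)**(k + 1) for k >= 0; `expm1`/`log1p` keep the
    # computation accurate when `probs` is small.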
    return array_ops.where(
        x < 0.,
        array_ops.zeros_like(x),
        -math_ops.expm1((1. + x) * math_ops.log1p(-self.probs)))

  def _log_prob(self, x):
    if self.validate_args:
      x = distribution_util.embed_check_nonnegative_integer_form(x)
    else:
      # For consistency with cdf, we take the floor.
      x = math_ops.floor(x)
    x *= array_ops.ones_like(self.probs)
    probs = self.probs * array_ops.ones_like(x)
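    # When `probs == 1` (logits = +inf), `log1p(-probs) = -inf` and the term
    # `x * log1p(-probs)` would be `0 * -inf = NaN` at `x == 0`, even though
    # `log_prob(0) = log(p) = 0` is well-defined there. Substituting 0 for
    # `probs` at `x == 0` makes that first term exactly 0.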
    safe_domain = array_ops.where(
        math_ops.equal(x, 0.),
        array_ops.zeros_like(probs),
        probs)
    return x * math_ops.log1p(-safe_domain) + math_ops.log(probs)

  def _entropy(self):
    probs = self._probs
    if self.validate_args:
      probs = control_flow_ops.with_dependencies(
          [check_ops.assert_less(
              probs,
              constant_op.constant(1., probs.dtype),
              message="Entropy is undefined when logits = inf or probs = 1.")],
          probs)
    # Claim: entropy(p) = softplus(s)/p - s
    # where s=logits and p=probs.
    #
    # Proof:
    #
    # entropy(p)
    # := -[(1-p)log(1-p) + p log(p)]/p
    # = -[log(1-p) + p log(p/(1-p))]/p
    # = -[-softplus(s) + p*s]/p
    # = softplus(s)/p - s
    #
    # since,
    # log[1-sigmoid(s)]
    # = log[1/(1+exp(s))]
    # = -log[1+exp(s)]
    # = -softplus(s)
    #
    # using the fact that,
    # 1-sigmoid(s) = sigmoid(-s) = 1/(1+exp(s))
    return nn.softplus(self.logits) / probs - self.logits

  def _mean(self):
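    # mean = (1 - p) / p; since logits = log(p / (1 - p)), this equals
    # exp(-logits).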
    return math_ops.exp(-self.logits)

  def _variance(self):
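    # variance = (1 - p) / p**2 = mean / p.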
    return self._mean() / self.probs

  def _mode(self):
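    # The pmf `(1 - p)**k * p` is non-increasing in `k`, so the mode is 0.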
    return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype)
    211