Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Poisson distribution class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import constant_op
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import tensor_shape
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import check_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import random_ops
     29 from tensorflow.python.ops.distributions import distribution
     30 from tensorflow.python.ops.distributions import util as distribution_util
     31 from tensorflow.python.util import deprecation
     32 
     33 __all__ = [
     34     "Poisson",
     35 ]
     36 
     37 
     38 _poisson_sample_note = """
     39 The Poisson distribution is technically only defined for non-negative integer
     40 values. When `validate_args=False`, non-integral inputs trigger an assertion.
     41 
     42 When `validate_args=False` calculations are otherwise unchanged despite
     43 integral or non-integral inputs.
     44 
     45 When `validate_args=False`, evaluating the pmf at non-integral values,
     46 corresponds to evaluations of an unnormalized distribution, that does not
     47 correspond to evaluations of the cdf.
     48 """
     49 
     50 
     51 class Poisson(distribution.Distribution):
     52   """Poisson distribution.
     53 
     54   The Poisson distribution is parameterized by an event `rate` parameter.
     55 
     56   #### Mathematical Details
     57 
     58   The probability mass function (pmf) is,
     59 
     60   ```none
     61   pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z
     62   Z = exp(lambda).
     63   ```
     64 
     65   where `rate = lambda` and `Z` is the normalizing constant.
     66 
     67   """
     68 
     69   @deprecation.deprecated(
     70       "2018-10-01",
     71       "The TensorFlow Distributions library has moved to "
     72       "TensorFlow Probability "
     73       "(https://github.com/tensorflow/probability). You "
     74       "should update all references to use `tfp.distributions` "
     75       "instead of `tf.contrib.distributions`.",
     76       warn_once=True)
     77   def __init__(self,
     78                rate=None,
     79                log_rate=None,
     80                validate_args=False,
     81                allow_nan_stats=True,
     82                name="Poisson"):
     83     """Initialize a batch of Poisson distributions.
     84 
     85     Args:
     86       rate: Floating point tensor, the rate parameter. `rate` must be positive.
     87         Must specify exactly one of `rate` and `log_rate`.
     88       log_rate: Floating point tensor, the log of the rate parameter.
     89         Must specify exactly one of `rate` and `log_rate`.
     90       validate_args: Python `bool`, default `False`. When `True` distribution
     91         parameters are checked for validity despite possibly degrading runtime
     92         performance. When `False` invalid inputs may silently render incorrect
     93         outputs.
     94       allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
     95         (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
     96         result is undefined. When `False`, an exception is raised if one or
     97         more of the statistic's batch members are undefined.
     98       name: Python `str` name prefixed to Ops created by this class.
     99 
    100     Raises:
    101       ValueError: if none or both of `rate`, `log_rate` are specified.
    102       TypeError: if `rate` is not a float-type.
    103       TypeError: if `log_rate` is not a float-type.
    104     """
    105     parameters = dict(locals())
    106     with ops.name_scope(name, values=[rate]) as name:
    107       if (rate is None) == (log_rate is None):
    108         raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
    109       elif log_rate is None:
    110         rate = ops.convert_to_tensor(rate, name="rate")
    111         if not rate.dtype.is_floating:
    112           raise TypeError("rate.dtype ({}) is a not a float-type.".format(
    113               rate.dtype.name))
    114         with ops.control_dependencies([check_ops.assert_positive(rate)] if
    115                                       validate_args else []):
    116           self._rate = array_ops.identity(rate, name="rate")
    117           self._log_rate = math_ops.log(rate, name="log_rate")
    118       else:
    119         log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
    120         if not log_rate.dtype.is_floating:
    121           raise TypeError("log_rate.dtype ({}) is a not a float-type.".format(
    122               log_rate.dtype.name))
    123         self._rate = math_ops.exp(log_rate, name="rate")
    124         self._log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
    125     super(Poisson, self).__init__(
    126         dtype=self._rate.dtype,
    127         reparameterization_type=distribution.NOT_REPARAMETERIZED,
    128         validate_args=validate_args,
    129         allow_nan_stats=allow_nan_stats,
    130         parameters=parameters,
    131         graph_parents=[self._rate],
    132         name=name)
    133 
    134   @property
    135   def rate(self):
    136     """Rate parameter."""
    137     return self._rate
    138 
    139   @property
    140   def log_rate(self):
    141     """Log rate parameter."""
    142     return self._log_rate
    143 
    144   def _batch_shape_tensor(self):
    145     return array_ops.shape(self.rate)
    146 
    147   def _batch_shape(self):
    148     return self.rate.shape
    149 
    150   def _event_shape_tensor(self):
    151     return constant_op.constant([], dtype=dtypes.int32)
    152 
    153   def _event_shape(self):
    154     return tensor_shape.scalar()
    155 
    156   @distribution_util.AppendDocstring(_poisson_sample_note)
    157   def _log_prob(self, x):
    158     return self._log_unnormalized_prob(x) - self._log_normalization()
    159 
    160   @distribution_util.AppendDocstring(_poisson_sample_note)
    161   def _log_cdf(self, x):
    162     return math_ops.log(self.cdf(x))
    163 
    164   @distribution_util.AppendDocstring(_poisson_sample_note)
    165   def _cdf(self, x):
    166     if self.validate_args:
    167       x = distribution_util.embed_check_nonnegative_integer_form(x)
    168     return math_ops.igammac(1. + x, self.rate)
    169 
    170   def _log_normalization(self):
    171     return self.rate
    172 
    173   def _log_unnormalized_prob(self, x):
    174     if self.validate_args:
    175       x = distribution_util.embed_check_nonnegative_integer_form(x)
    176     return x * self.log_rate - math_ops.lgamma(1. + x)
    177 
    178   def _mean(self):
    179     return array_ops.identity(self.rate)
    180 
    181   def _variance(self):
    182     return array_ops.identity(self.rate)
    183 
    184   @distribution_util.AppendDocstring(
    185       """Note: when `rate` is an integer, there are actually two modes: `rate`
    186       and `rate - 1`. In this case we return the larger, i.e., `rate`.""")
    187   def _mode(self):
    188     return math_ops.floor(self.rate)
    189 
    190   def _sample_n(self, n, seed=None):
    191     return random_ops.random_poisson(
    192         self.rate, [n], dtype=self.dtype, seed=seed)
    193