Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Poisson distribution class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import constant_op
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import tensor_shape
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import check_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import random_ops
     29 from tensorflow.python.ops.distributions import distribution
     30 from tensorflow.python.ops.distributions import util as distribution_util
     31 
     32 __all__ = [
     33     "Poisson",
     34 ]
     35 
     36 
     37 _poisson_sample_note = """
     38 Note that the input value must be a non-negative floating point tensor with
     39 dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only
     40 legal if it is non-negative and its components are equal to integer values.
     41 """
     42 
     43 
     44 class Poisson(distribution.Distribution):
     45   """Poisson distribution.
     46 
     47   The Poisson distribution is parameterized by an event `rate` parameter.
     48 
     49   #### Mathematical Details
     50 
     51   The probability mass function (pmf) is,
     52 
     53   ```none
     54   pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z
     55   Z = exp(lambda).
     56   ```
     57 
     58   where `rate = lambda` and `Z` is the normalizing constant.
     59 
     60   """
     61 
     62   def __init__(self,
     63                rate=None,
     64                log_rate=None,
     65                validate_args=False,
     66                allow_nan_stats=True,
     67                name="Poisson"):
     68     """Initialize a batch of Poisson distributions.
     69 
     70     Args:
     71       rate: Floating point tensor, the rate parameter. `rate` must be positive.
     72         Must specify exactly one of `rate` and `log_rate`.
     73       log_rate: Floating point tensor, the log of the rate parameter.
     74         Must specify exactly one of `rate` and `log_rate`.
     75       validate_args: Python `bool`, default `False`. When `True` distribution
     76         parameters are checked for validity despite possibly degrading runtime
     77         performance. When `False` invalid inputs may silently render incorrect
     78         outputs.
     79       allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
     80         (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
     81         result is undefined. When `False`, an exception is raised if one or
     82         more of the statistic's batch members are undefined.
     83       name: Python `str` name prefixed to Ops created by this class.
     84 
     85     Raises:
     86       ValueError: if none or both of `rate`, `log_rate` are specified.
     87       TypeError: if `rate` is not a float-type.
     88       TypeError: if `log_rate` is not a float-type.
     89     """
     90     parameters = locals()
     91     with ops.name_scope(name, values=[rate]):
     92       if (rate is None) == (log_rate is None):
     93         raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
     94       elif log_rate is None:
     95         rate = ops.convert_to_tensor(rate, name="rate")
     96         if not rate.dtype.is_floating:
     97           raise TypeError("rate.dtype ({}) is a not a float-type.".format(
     98               rate.dtype.name))
     99         with ops.control_dependencies([check_ops.assert_positive(rate)] if
    100                                       validate_args else []):
    101           self._rate = array_ops.identity(rate, name="rate")
    102           self._log_rate = math_ops.log(rate, name="log_rate")
    103       else:
    104         log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
    105         if not log_rate.dtype.is_floating:
    106           raise TypeError("log_rate.dtype ({}) is a not a float-type.".format(
    107               log_rate.dtype.name))
    108         self._rate = math_ops.exp(log_rate, name="rate")
    109         self._log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
    110     super(Poisson, self).__init__(
    111         dtype=self._rate.dtype,
    112         reparameterization_type=distribution.NOT_REPARAMETERIZED,
    113         validate_args=validate_args,
    114         allow_nan_stats=allow_nan_stats,
    115         parameters=parameters,
    116         graph_parents=[self._rate],
    117         name=name)
    118 
    119   @property
    120   def rate(self):
    121     """Rate parameter."""
    122     return self._rate
    123 
    124   @property
    125   def log_rate(self):
    126     """Log rate parameter."""
    127     return self._log_rate
    128 
    129   def _batch_shape_tensor(self):
    130     return array_ops.shape(self.rate)
    131 
    132   def _batch_shape(self):
    133     return self.rate.shape
    134 
    135   def _event_shape_tensor(self):
    136     return constant_op.constant([], dtype=dtypes.int32)
    137 
    138   def _event_shape(self):
    139     return tensor_shape.scalar()
    140 
    141   @distribution_util.AppendDocstring(_poisson_sample_note)
    142   def _log_prob(self, x):
    143     return self._log_unnormalized_prob(x) - self._log_normalization()
    144 
    145   @distribution_util.AppendDocstring(_poisson_sample_note)
    146   def _log_cdf(self, x):
    147     return math_ops.log(self.cdf(x))
    148 
    149   @distribution_util.AppendDocstring(_poisson_sample_note)
    150   def _cdf(self, x):
    151     if self.validate_args:
    152       x = distribution_util.embed_check_nonnegative_integer_form(x)
    153     else:
    154       # Whether or not x is integer-form, the following is well-defined.
    155       # However, scipy takes the floor, so we do too.
    156       x = math_ops.floor(x)
    157     return math_ops.igammac(1. + x, self.rate)
    158 
    159   def _log_normalization(self):
    160     return self.rate
    161 
    162   def _log_unnormalized_prob(self, x):
    163     if self.validate_args:
    164       x = distribution_util.embed_check_nonnegative_integer_form(x)
    165     else:
    166       # For consistency with cdf, we take the floor.
    167       x = math_ops.floor(x)
    168     return x * self.log_rate - math_ops.lgamma(1. + x)
    169 
    170   def _mean(self):
    171     return array_ops.identity(self.rate)
    172 
    173   def _variance(self):
    174     return array_ops.identity(self.rate)
    175 
    176   @distribution_util.AppendDocstring(
    177       """Note: when `rate` is an integer, there are actually two modes: `rate`
    178       and `rate - 1`. In this case we return the larger, i.e., `rate`.""")
    179   def _mode(self):
    180     return math_ops.floor(self.rate)
    181 
    182   def _sample_n(self, n, seed=None):
    183     return random_ops.random_poisson(
    184         self.rate, [n], dtype=self.dtype, seed=seed)
    185