# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Poisson distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation

__all__ = [
    "Poisson",
]


# Appended (via `AppendDocstring`) to the docs of `log_prob`, `log_cdf` and
# `cdf` below.
_poisson_sample_note = """
The Poisson distribution is technically only defined for non-negative integer
values. When `validate_args=True`, non-integral inputs trigger an assertion.

When `validate_args=False`, the same formulas are applied to integral and
non-integral inputs alike, and the cdf evaluates to zero for negative inputs.

When `validate_args=False`, evaluating the pmf at non-integral values
corresponds to evaluating an unnormalized distribution, which does not
correspond to evaluations of the cdf.
"""


class Poisson(distribution.Distribution):
  """Poisson distribution.

  The Poisson distribution is parameterized by an event `rate` parameter,
  which may alternatively be supplied as its logarithm, `log_rate`.

  #### Mathematical Details

  The probability mass function (pmf) is,

  ```none
  pmf(k; lambda, k >= 0) = (lambda^k / k!) / Z
  Z = exp(lambda).
  ```

  where `rate = lambda` and `Z` is the normalizing constant.

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               rate=None,
               log_rate=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Poisson"):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be positive.
        Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter. Any real
        value is allowed. Must specify exactly one of `rate` and `log_rate`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    # Must stay the first statement: captures exactly the constructor args.
    parameters = dict(locals())
    # Include both possible inputs so the scope picks the graph of whichever
    # parameter was actually supplied (the other one is `None`).
    with ops.name_scope(name, values=[rate, log_rate]) as name:
      if (rate is None) == (log_rate is None):
        raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
      elif log_rate is None:
        rate = ops.convert_to_tensor(rate, name="rate")
        if not rate.dtype.is_floating:
          raise TypeError("rate.dtype ({}) is a not a float-type.".format(
              rate.dtype.name))
        # Only `rate` has a positivity constraint; the assertion is attached
        # as a control dependency so it runs whenever either tensor is used.
        with ops.control_dependencies([check_ops.assert_positive(rate)] if
                                      validate_args else []):
          self._rate = array_ops.identity(rate, name="rate")
          self._log_rate = math_ops.log(rate, name="log_rate")
      else:
        log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
        if not log_rate.dtype.is_floating:
          raise TypeError("log_rate.dtype ({}) is a not a float-type.".format(
              log_rate.dtype.name))
        self._rate = math_ops.exp(log_rate, name="rate")
        # `log_rate` is already a tensor; `identity` actually applies the name
        # (mirrors the `rate` branch above).
        self._log_rate = array_ops.identity(log_rate, name="log_rate")
    super(Poisson, self).__init__(
        dtype=self._rate.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._rate],
        name=name)

  @property
  def rate(self):
    """Rate parameter."""
    return self._rate

  @property
  def log_rate(self):
    """Log rate parameter."""
    return self._log_rate

  def _batch_shape_tensor(self):
    return array_ops.shape(self.rate)

  def _batch_shape(self):
    return self.rate.shape

  def _event_shape_tensor(self):
    # Scalar events: the event shape is the empty vector.
    return constant_op.constant([], dtype=dtypes.int32)

  def _event_shape(self):
    return tensor_shape.scalar()

  @distribution_util.AppendDocstring(_poisson_sample_note)
  def _log_prob(self, x):
    return self._log_unnormalized_prob(x) - self._log_normalization()

  @distribution_util.AppendDocstring(_poisson_sample_note)
  def _log_cdf(self, x):
    # `x` was already converted by the public `log_cdf`; call the private
    # implementation directly instead of re-entering the public API.
    return math_ops.log(self._cdf(x))

  @distribution_util.AppendDocstring(_poisson_sample_note)
  def _cdf(self, x):
    if self.validate_args:
      x = distribution_util.embed_check_nonnegative_integer_form(x)
    # P(X <= x) = Q(1 + x, rate), the regularized upper incomplete gamma
    # function. The support is {0, 1, 2, ...}, so the cdf is zero for x < 0,
    # but `igammac(1 + x, rate)` is nonzero on (-1, 0) and NaN once
    # `1 + x <= 0`. Evaluate at a clamped argument (keeps values and
    # gradients finite), then zero out the negative inputs. The multiply
    # broadcasts the mask (shape of `x`) against the `x`/`rate` batch shape.
    safe_x = math_ops.maximum(x, 0.)
    cdf = math_ops.igammac(1. + safe_x, self.rate)
    return cdf * math_ops.cast(x >= 0., cdf.dtype)

  def _log_normalization(self):
    # log Z = log(exp(rate)) = rate.
    return self.rate

  def _log_unnormalized_prob(self, x):
    if self.validate_args:
      x = distribution_util.embed_check_nonnegative_integer_form(x)
    # log(rate^x / x!) with x! = Gamma(1 + x).
    return x * self.log_rate - math_ops.lgamma(1. + x)

  def _mean(self):
    return array_ops.identity(self.rate)

  def _variance(self):
    return array_ops.identity(self.rate)

  @distribution_util.AppendDocstring(
      """Note: when `rate` is an integer, there are actually two modes: `rate`
      and `rate - 1`. In this case we return the larger, i.e., `rate`.""")
  def _mode(self):
    return math_ops.floor(self.rate)

  def _sample_n(self, n, seed=None):
    # Samples have shape [n] + batch_shape.
    return random_ops.random_poisson(
        self.rate, [n], dtype=self.dtype, seed=seed)