1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """The Geometric distribution class.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.framework import constant_op 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import tensor_shape 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import check_ops 28 from tensorflow.python.ops import control_flow_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops import nn 31 from tensorflow.python.ops import random_ops 32 from tensorflow.python.ops.distributions import distribution 33 from tensorflow.python.ops.distributions import util as distribution_util 34 35 36 class Geometric(distribution.Distribution): 37 """Geometric distribution. 38 39 The Geometric distribution is parameterized by p, the probability of a 40 positive event. It represents the probability that in k + 1 Bernoulli trials, 41 the first k trials failed, before seeing a success. 42 43 The pmf of this distribution is: 44 45 #### Mathematical Details 46 47 ```none 48 pmf(k; p) = (1 - p)**k * p 49 ``` 50 51 where: 52 53 * `p` is the success probability, `0 < p <= 1`, and, 54 * `k` is a non-negative integer. 55 56 """ 57 58 def __init__(self, 59 logits=None, 60 probs=None, 61 validate_args=False, 62 allow_nan_stats=True, 63 name="Geometric"): 64 """Construct Geometric distributions. 65 66 Args: 67 logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where `b >= 0` 68 indicates the number of batch dimensions. Each entry represents logits 69 for the probability of success for independent Geometric distributions 70 and must be in the range `(-inf, inf]`. Only one of `logits` or `probs` 71 should be specified. 72 probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]` 73 where `b >= 0` indicates the number of batch dimensions. Each entry 74 represents the probability of success for independent Geometric 75 distributions and must be in the range `(0, 1]`. Only one of `logits` 76 or `probs` should be specified. 77 validate_args: Python `bool`, default `False`. When `True` distribution 78 parameters are checked for validity despite possibly degrading runtime 79 performance. When `False` invalid inputs may silently render incorrect 80 outputs. 81 allow_nan_stats: Python `bool`, default `True`. When `True`, statistics 82 (e.g., mean, mode, variance) use the value "`NaN`" to indicate the 83 result is undefined. When `False`, an exception is raised if one or 84 more of the statistic's batch members are undefined. 85 name: Python `str` name prefixed to Ops created by this class. 86 """ 87 88 parameters = locals() 89 with ops.name_scope(name, values=[logits, probs]): 90 self._logits, self._probs = distribution_util.get_logits_and_probs( 91 logits, probs, validate_args=validate_args, name=name) 92 93 with ops.control_dependencies( 94 [check_ops.assert_positive(self._probs)] if validate_args else []): 95 self._probs = array_ops.identity(self._probs, name="probs") 96 97 super(Geometric, self).__init__( 98 dtype=self._probs.dtype, 99 reparameterization_type=distribution.NOT_REPARAMETERIZED, 100 validate_args=validate_args, 101 allow_nan_stats=allow_nan_stats, 102 parameters=parameters, 103 graph_parents=[self._probs, self._logits], 104 name=name) 105 106 @property 107 def logits(self): 108 """Log-odds of a `1` outcome (vs `0`).""" 109 return self._logits 110 111 @property 112 def probs(self): 113 """Probability of a `1` outcome (vs `0`).""" 114 return self._probs 115 116 def _batch_shape_tensor(self): 117 return array_ops.shape(self._probs) 118 119 def _batch_shape(self): 120 return self.probs.get_shape() 121 122 def _event_shape_tensor(self): 123 return array_ops.constant([], dtype=dtypes.int32) 124 125 def _event_shape(self): 126 return tensor_shape.scalar() 127 128 def _sample_n(self, n, seed=None): 129 # Uniform variates must be sampled from the open-interval `(0, 1)` rather 130 # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny` 131 # because it is the smallest, positive, "normal" number. A "normal" number 132 # is such that the mantissa has an implicit leading 1. Normal, positive 133 # numbers x, y have the reasonable property that, `x + y >= max(x, y)`. In 134 # this case, a subnormal number (i.e., np.nextafter) can cause us to sample 135 # 0. 136 sampled = random_ops.random_uniform( 137 array_ops.concat([[n], array_ops.shape(self._probs)], 0), 138 minval=np.finfo(self.dtype.as_numpy_dtype).tiny, 139 maxval=1., 140 seed=seed, 141 dtype=self.dtype) 142 143 return math_ops.floor( 144 math_ops.log(sampled) / math_ops.log1p(-self.probs)) 145 146 def _cdf(self, x): 147 if self.validate_args: 148 x = distribution_util.embed_check_nonnegative_integer_form(x) 149 else: 150 # Whether or not x is integer-form, the following is well-defined. 151 # However, scipy takes the floor, so we do too. 152 x = math_ops.floor(x) 153 x *= array_ops.ones_like(self.probs) 154 return array_ops.where( 155 x < 0., 156 array_ops.zeros_like(x), 157 -math_ops.expm1((1. + x) * math_ops.log1p(-self.probs))) 158 159 def _log_prob(self, x): 160 if self.validate_args: 161 x = distribution_util.embed_check_nonnegative_integer_form(x) 162 else: 163 # For consistency with cdf, we take the floor. 164 x = math_ops.floor(x) 165 x *= array_ops.ones_like(self.probs) 166 probs = self.probs * array_ops.ones_like(x) 167 safe_domain = array_ops.where( 168 math_ops.equal(x, 0.), 169 array_ops.zeros_like(probs), 170 probs) 171 return x * math_ops.log1p(-safe_domain) + math_ops.log(probs) 172 173 def _entropy(self): 174 probs = self._probs 175 if self.validate_args: 176 probs = control_flow_ops.with_dependencies( 177 [check_ops.assert_less( 178 probs, 179 constant_op.constant(1., probs.dtype), 180 message="Entropy is undefined when logits = inf or probs = 1.")], 181 probs) 182 # Claim: entropy(p) = softplus(s)/p - s 183 # where s=logits and p=probs. 184 # 185 # Proof: 186 # 187 # entropy(p) 188 # := -[(1-p)log(1-p) + plog(p)]/p 189 # = -[log(1-p) + plog(p/(1-p))]/p 190 # = -[-softplus(s) + ps]/p 191 # = softplus(s)/p - s 192 # 193 # since, 194 # log[1-sigmoid(s)] 195 # = log[1/(1+exp(s)] 196 # = -log[1+exp(s)] 197 # = -softplus(s) 198 # 199 # using the fact that, 200 # 1-sigmoid(s) = sigmoid(-s) = 1/(1+exp(s)) 201 return nn.softplus(self.logits) / probs - self.logits 202 203 def _mean(self): 204 return math_ops.exp(-self.logits) 205 206 def _variance(self): 207 return self._mean() / self.probs 208 209 def _mode(self): 210 return array_ops.zeros(self.batch_shape_tensor(), dtype=self.dtype) 211