1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """The Half Normal distribution class.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import ops 26 from tensorflow.python.framework import tensor_shape 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.ops import check_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops import nn 31 from tensorflow.python.ops import random_ops 32 from tensorflow.python.ops.distributions import distribution 33 from tensorflow.python.ops.distributions import special_math 34 35 36 __all__ = [ 37 "HalfNormal", 38 ] 39 40 41 class HalfNormal(distribution.Distribution): 42 """The Half Normal distribution with scale `scale`. 43 44 #### Mathematical details 45 46 The half normal is a transformation of a centered normal distribution. 47 If some random variable `X` has normal distribution, 48 ```none 49 X ~ Normal(0.0, scale) 50 Y = |X| 51 ``` 52 Then `Y` will have half normal distribution. The probability density 53 function (pdf) is: 54 55 ```none 56 pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) * 57 exp(- 1/2 * (x / scale) ** 2) 58 ) 59 ``` 60 Where `scale = sigma` is the standard deviation of the underlying normal 61 distribution. 62 63 #### Examples 64 65 Examples of initialization of one or a batch of distributions. 66 67 ```python 68 # Define a single scalar HalfNormal distribution. 69 dist = tf.contrib.distributions.HalfNormal(scale=3.0) 70 71 # Evaluate the cdf at 1, returning a scalar. 72 dist.cdf(1.) 73 74 # Define a batch of two scalar valued HalfNormals. 75 # The first has scale 11.0, the second 22.0 76 dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) 77 78 # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, 79 # returning a length two tensor. 80 dist.prob([1.0, 1.5]) 81 82 # Get 3 samples, returning a 3 x 2 tensor. 83 dist.sample([3]) 84 ``` 85 86 """ 87 88 def __init__(self, 89 scale, 90 validate_args=False, 91 allow_nan_stats=True, 92 name="HalfNormal"): 93 """Construct HalfNormals with scale `scale`. 94 95 Args: 96 scale: Floating point tensor; the scales of the distribution(s). 97 Must contain only positive values. 98 validate_args: Python `bool`, default `False`. When `True` distribution 99 parameters are checked for validity despite possibly degrading runtime 100 performance. When `False` invalid inputs may silently render incorrect 101 outputs. 102 allow_nan_stats: Python `bool`, default `True`. When `True`, 103 statistics (e.g., mean, mode, variance) use the value "`NaN`" to 104 indicate the result is undefined. When `False`, an exception is raised 105 if one or more of the statistic's batch members are undefined. 106 name: Python `str` name prefixed to Ops created by this class. 107 """ 108 parameters = locals() 109 with ops.name_scope(name, values=[scale]): 110 with ops.control_dependencies([check_ops.assert_positive(scale)] if 111 validate_args else []): 112 self._scale = array_ops.identity(scale, name="scale") 113 super(HalfNormal, self).__init__( 114 dtype=self._scale.dtype, 115 reparameterization_type=distribution.FULLY_REPARAMETERIZED, 116 validate_args=validate_args, 117 allow_nan_stats=allow_nan_stats, 118 parameters=parameters, 119 graph_parents=[self._scale], 120 name=name) 121 122 @staticmethod 123 def _param_shapes(sample_shape): 124 return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} 125 126 @property 127 def scale(self): 128 """Distribution parameter for the scale.""" 129 return self._scale 130 131 def _batch_shape_tensor(self): 132 return array_ops.shape(self.scale) 133 134 def _batch_shape(self): 135 return self.scale.shape 136 137 def _event_shape_tensor(self): 138 return constant_op.constant([], dtype=dtypes.int32) 139 140 def _event_shape(self): 141 return tensor_shape.scalar() 142 143 def _sample_n(self, n, seed=None): 144 shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) 145 sampled = random_ops.random_normal( 146 shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) 147 return math_ops.abs(sampled * self.scale) 148 149 def _prob(self, x): 150 coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi) 151 pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2) 152 return pdf * math_ops.cast(x >= 0, self.dtype) 153 154 def _cdf(self, x): 155 truncated_x = nn.relu(x) 156 return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0)) 157 158 def _entropy(self): 159 return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5 160 161 def _mean(self): 162 return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) 163 164 def _quantile(self, p): 165 return np.sqrt(2.0) * self.scale * special_math.erfinv(p) 166 167 def _mode(self): 168 return array_ops.zeros(self.batch_shape_tensor()) 169 170 def _variance(self): 171 return self.scale ** 2.0 * (1.0 - 2.0 / np.pi) 172