Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Half Normal distribution class."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import tensor_shape
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import check_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import nn
     31 from tensorflow.python.ops import random_ops
     32 from tensorflow.python.ops.distributions import distribution
     33 from tensorflow.python.ops.distributions import special_math
     34 
     35 
     36 __all__ = [
     37     "HalfNormal",
     38 ]
     39 
     40 
     41 class HalfNormal(distribution.Distribution):
     42   """The Half Normal distribution with scale `scale`.
     43 
     44   #### Mathematical details
     45 
     46   The half normal is a transformation of a centered normal distribution.
     47   If some random variable `X` has normal distribution,
     48   ```none
     49   X ~ Normal(0.0, scale)
     50   Y = |X|
     51   ```
     52   Then `Y` will have half normal distribution. The probability density
     53   function (pdf) is:
     54 
     55   ```none
     56   pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) *
     57     exp(- 1/2 * (x / scale) ** 2)
     58   )
     59   ```
     60   Where `scale = sigma` is the standard deviation of the underlying normal
     61   distribution.
     62 
     63   #### Examples
     64 
     65   Examples of initialization of one or a batch of distributions.
     66 
     67   ```python
     68   # Define a single scalar HalfNormal distribution.
     69   dist = tf.contrib.distributions.HalfNormal(scale=3.0)
     70 
     71   # Evaluate the cdf at 1, returning a scalar.
     72   dist.cdf(1.)
     73 
     74   # Define a batch of two scalar valued HalfNormals.
     75   # The first has scale 11.0, the second 22.0
     76   dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
     77 
     78   # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
     79   # returning a length two tensor.
     80   dist.prob([1.0, 1.5])
     81 
     82   # Get 3 samples, returning a 3 x 2 tensor.
     83   dist.sample([3])
     84   ```
     85 
     86   """
     87 
     88   def __init__(self,
     89                scale,
     90                validate_args=False,
     91                allow_nan_stats=True,
     92                name="HalfNormal"):
     93     """Construct HalfNormals with scale `scale`.
     94 
     95     Args:
     96       scale: Floating point tensor; the scales of the distribution(s).
     97         Must contain only positive values.
     98       validate_args: Python `bool`, default `False`. When `True` distribution
     99         parameters are checked for validity despite possibly degrading runtime
    100         performance. When `False` invalid inputs may silently render incorrect
    101         outputs.
    102       allow_nan_stats: Python `bool`, default `True`. When `True`,
    103         statistics (e.g., mean, mode, variance) use the value "`NaN`" to
    104         indicate the result is undefined. When `False`, an exception is raised
    105         if one or more of the statistic's batch members are undefined.
    106       name: Python `str` name prefixed to Ops created by this class.
    107     """
    108     parameters = locals()
    109     with ops.name_scope(name, values=[scale]):
    110       with ops.control_dependencies([check_ops.assert_positive(scale)] if
    111                                     validate_args else []):
    112         self._scale = array_ops.identity(scale, name="scale")
    113     super(HalfNormal, self).__init__(
    114         dtype=self._scale.dtype,
    115         reparameterization_type=distribution.FULLY_REPARAMETERIZED,
    116         validate_args=validate_args,
    117         allow_nan_stats=allow_nan_stats,
    118         parameters=parameters,
    119         graph_parents=[self._scale],
    120         name=name)
    121 
    122   @staticmethod
    123   def _param_shapes(sample_shape):
    124     return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)}
    125 
    126   @property
    127   def scale(self):
    128     """Distribution parameter for the scale."""
    129     return self._scale
    130 
    131   def _batch_shape_tensor(self):
    132     return array_ops.shape(self.scale)
    133 
    134   def _batch_shape(self):
    135     return self.scale.shape
    136 
    137   def _event_shape_tensor(self):
    138     return constant_op.constant([], dtype=dtypes.int32)
    139 
    140   def _event_shape(self):
    141     return tensor_shape.scalar()
    142 
    143   def _sample_n(self, n, seed=None):
    144     shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
    145     sampled = random_ops.random_normal(
    146         shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
    147     return math_ops.abs(sampled * self.scale)
    148 
    149   def _prob(self, x):
    150     coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi)
    151     pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2)
    152     return pdf * math_ops.cast(x >= 0, self.dtype)
    153 
    154   def _cdf(self, x):
    155     truncated_x = nn.relu(x)
    156     return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0))
    157 
    158   def _entropy(self):
    159     return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5
    160 
    161   def _mean(self):
    162     return self.scale * np.sqrt(2.0) / np.sqrt(np.pi)
    163 
    164   def _quantile(self, p):
    165     return np.sqrt(2.0) * self.scale * special_math.erfinv(p)
    166 
    167   def _mode(self):
    168     return array_ops.zeros(self.batch_shape_tensor())
    169 
    170   def _variance(self):
    171     return self.scale ** 2.0 * (1.0 - 2.0 / np.pi)
    172