Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """The Normal distribution: conjugate posterior closed form calculations."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.ops import math_ops
     22 from tensorflow.python.ops.distributions import normal
     23 
     24 
     25 def normal_conjugates_known_scale_posterior(prior, scale, s, n):
     26   """Posterior Normal distribution with conjugate prior on the mean.
     27 
     28   This model assumes that `n` observations (with sum `s`) come from a
     29   Normal with unknown mean `loc` (described by the Normal `prior`)
     30   and known variance `scale**2`. The "known scale posterior" is
     31   the distribution of the unknown `loc`.
     32 
     33   Accepts a prior Normal distribution object, having parameters
     34   `loc0` and `scale0`, as well as known `scale` values of the predictive
     35   distribution(s) (also assumed Normal),
     36   and statistical estimates `s` (the sum(s) of the observations) and
     37   `n` (the number(s) of observations).
     38 
     39   Returns a posterior (also Normal) distribution object, with parameters
     40   `(loc', scale'**2)`, where:
     41 
     42   ```
     43   mu ~ N(mu', sigma'**2)
     44   sigma'**2 = 1/(1/sigma0**2 + n/sigma**2),
     45   mu' = (mu0/sigma0**2 + s/sigma**2) * sigma'**2.
     46   ```
     47 
     48   Distribution parameters from `prior`, as well as `scale`, `s`, and `n`.
     49   will broadcast in the case of multidimensional sets of parameters.
     50 
     51   Args:
     52     prior: `Normal` object of type `dtype`:
     53       the prior distribution having parameters `(loc0, scale0)`.
     54     scale: tensor of type `dtype`, taking values `scale > 0`.
     55       The known stddev parameter(s).
     56     s: Tensor of type `dtype`. The sum(s) of observations.
     57     n: Tensor of type `int`. The number(s) of observations.
     58 
     59   Returns:
     60     A new Normal posterior distribution object for the unknown observation
     61     mean `loc`.
     62 
     63   Raises:
     64     TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
     65       Normal object.
     66   """
     67   if not isinstance(prior, normal.Normal):
     68     raise TypeError("Expected prior to be an instance of type Normal")
     69 
     70   if s.dtype != prior.dtype:
     71     raise TypeError(
     72         "Observation sum s.dtype does not match prior dtype: %s vs. %s"
     73         % (s.dtype, prior.dtype))
     74 
     75   n = math_ops.cast(n, prior.dtype)
     76   scale0_2 = math_ops.square(prior.scale)
     77   scale_2 = math_ops.square(scale)
     78   scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
     79   return normal.Normal(
     80       loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
     81       scale=math_ops.sqrt(scalep_2))
     82 
     83 
     84 def normal_conjugates_known_scale_predictive(prior, scale, s, n):
     85   """Posterior predictive Normal distribution w. conjugate prior on the mean.
     86 
     87   This model assumes that `n` observations (with sum `s`) come from a
     88   Normal with unknown mean `loc` (described by the Normal `prior`)
     89   and known variance `scale**2`. The "known scale predictive"
     90   is the distribution of new observations, conditioned on the existing
     91   observations and our prior.
     92 
     93   Accepts a prior Normal distribution object, having parameters
     94   `loc0` and `scale0`, as well as known `scale` values of the predictive
     95   distribution(s) (also assumed Normal),
     96   and statistical estimates `s` (the sum(s) of the observations) and
     97   `n` (the number(s) of observations).
     98 
     99   Calculates the Normal distribution(s) `p(x | sigma**2)`:
    100 
    101   ```
    102   p(x | sigma**2) = int N(x | mu, sigma**2)N(mu | prior.loc, prior.scale**2) dmu
    103                   = N(x | prior.loc, 1 / (sigma**2 + prior.scale**2))
    104   ```
    105 
    106   Returns the predictive posterior distribution object, with parameters
    107   `(loc', scale'**2)`, where:
    108 
    109   ```
    110   sigma_n**2 = 1/(1/sigma0**2 + n/sigma**2),
    111   mu' = (mu0/sigma0**2 + s/sigma**2) * sigma_n**2.
    112   sigma'**2 = sigma_n**2 + sigma**2,
    113   ```
    114 
    115   Distribution parameters from `prior`, as well as `scale`, `s`, and `n`.
    116   will broadcast in the case of multidimensional sets of parameters.
    117 
    118   Args:
    119     prior: `Normal` object of type `dtype`:
    120       the prior distribution having parameters `(loc0, scale0)`.
    121     scale: tensor of type `dtype`, taking values `scale > 0`.
    122       The known stddev parameter(s).
    123     s: Tensor of type `dtype`. The sum(s) of observations.
    124     n: Tensor of type `int`. The number(s) of observations.
    125 
    126   Returns:
    127     A new Normal predictive distribution object.
    128 
    129   Raises:
    130     TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
    131       Normal object.
    132   """
    133   if not isinstance(prior, normal.Normal):
    134     raise TypeError("Expected prior to be an instance of type Normal")
    135 
    136   if s.dtype != prior.dtype:
    137     raise TypeError(
    138         "Observation sum s.dtype does not match prior dtype: %s vs. %s"
    139         % (s.dtype, prior.dtype))
    140 
    141   n = math_ops.cast(n, prior.dtype)
    142   scale0_2 = math_ops.square(prior.scale)
    143   scale_2 = math_ops.square(scale)
    144   scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
    145   return normal.Normal(
    146       loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
    147       scale=math_ops.sqrt(scalep_2 + scale_2))
    148