Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Multivariate Normal distribution classes."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib import linalg
     22 from tensorflow.contrib.distributions.python.ops import distribution_util
     23 from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
     24 from tensorflow.python.framework import ops
     25 
     26 
     27 __all__ = [
     28     "MultivariateNormalDiagPlusLowRank",
     29 ]
     30 
     31 
     32 class MultivariateNormalDiagPlusLowRank(
     33     mvn_linop.MultivariateNormalLinearOperator):
     34   """The multivariate normal distribution on `R^k`.
     35 
     36   The Multivariate Normal distribution is defined over `R^k` and parameterized
     37   by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
     38   `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
     39   matrix-multiplication.
     40 
     41   #### Mathematical Details
     42 
     43   The probability density function (pdf) is,
     44 
     45   ```none
     46   pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
     47   y = inv(scale) @ (x - loc),
     48   Z = (2 pi)**(0.5 k) |det(scale)|,
     49   ```
     50 
     51   where:
     52 
     53   * `loc` is a vector in `R^k`,
     54   * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
     55   * `Z` denotes the normalization constant, and,
     56   * `||y||**2` denotes the squared Euclidean norm of `y`.
     57 
     58   A (non-batch) `scale` matrix is:
     59 
     60   ```none
     61   scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
     62         scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
     63   ```
     64 
     65   where:
     66 
     67   * `scale_diag.shape = [k]`,
     68   * `scale_identity_multiplier.shape = []`,
     69   * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
     70   * `scale_perturb_diag.shape = [r]`.
     71 
     72   Additional leading dimensions (if any) will index batches.
     73 
     74   If both `scale_diag` and `scale_identity_multiplier` are `None`, then
     75   `scale` is the Identity matrix.
     76 
     77   The MultivariateNormal distribution is a member of the [location-scale
     78   family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
     79   constructed as,
     80 
     81   ```none
     82   X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
     83   Y = scale @ X + loc
     84   ```
     85 
     86   #### Examples
     87 
     88   ```python
     89   tfd = tf.contrib.distributions
     90 
     91   # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
     92   # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
     93   # a rank-2 update.
     94   mu = [-0.5., 0, 0.5]   # shape: [3]
     95   d = [1.5, 0.5, 2]      # shape: [3]
     96   U = [[1., 2],
     97        [-1, 1],
     98        [2, -0.5]]        # shape: [3, 2]
     99   m = [4., 5]            # shape: [2]
    100   mvn = tfd.MultivariateNormalDiagPlusLowRank(
    101       loc=mu
    102       scale_diag=d
    103       scale_perturb_factor=U,
    104       scale_perturb_diag=m)
    105 
    106   # Evaluate this on an observation in `R^3`, returning a scalar.
    107   mvn.prob([-1, 0, 1]).eval()  # shape: []
    108 
    109   # Initialize a 2-batch of 3-variate Gaussians; `S = diag(d) + U @ U.T`.
    110   mu = [[1.,  2,  3],
    111         [11, 22, 33]]      # shape: [b, k] = [2, 3]
    112   U = [[[1., 2],
    113         [3,  4],
    114         [5,  6]],
    115        [[0.5, 0.75],
    116         [1,0, 0.25],
    117         [1.5, 1.25]]]      # shape: [b, k, r] = [2, 3, 2]
    118   m = [[0.1, 0.2],
    119        [0.4, 0.5]]         # shape: [b, r] = [2, 2]
    120 
    121   mvn = tfd.MultivariateNormalDiagPlusLowRank(
    122       loc=mu,
    123       scale_perturb_factor=U,
    124       scale_perturb_diag=m)
    125 
    126   mvn.covariance().eval()   # shape: [2, 3, 3]
    127   # ==> [[[  15.63   31.57    48.51]
    128   #       [  31.57   69.31   105.05]
    129   #       [  48.51  105.05   162.59]]
    130   #
    131   #      [[   2.59    1.41    3.35]
    132   #       [   1.41    2.71    3.34]
    133   #       [   3.35    3.34    8.35]]]
    134 
    135   # Compute the pdf of two `R^3` observations (one from each batch);
    136   # return a length-2 vector.
    137   x = [[-0.9, 0, 0.1],
    138        [-10, 0, 9]]     # shape: [2, 3]
    139   mvn.prob(x).eval()    # shape: [2]
    140   ```
    141 
    142   """
    143 
    144   def __init__(self,
    145                loc=None,
    146                scale_diag=None,
    147                scale_identity_multiplier=None,
    148                scale_perturb_factor=None,
    149                scale_perturb_diag=None,
    150                validate_args=False,
    151                allow_nan_stats=True,
    152                name="MultivariateNormalDiagPlusLowRank"):
    153     """Construct Multivariate Normal distribution on `R^k`.
    154 
    155     The `batch_shape` is the broadcast shape between `loc` and `scale`
    156     arguments.
    157 
    158     The `event_shape` is given by last dimension of the matrix implied by
    159     `scale`. The last dimension of `loc` (if provided) must broadcast with this.
    160 
    161     Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:
    162 
    163     ```none
    164     scale = diag(scale_diag + scale_identity_multiplier ones(k)) +
    165         scale_perturb_factor @ diag(scale_perturb_diag) @ scale_perturb_factor.T
    166     ```
    167 
    168     where:
    169 
    170     * `scale_diag.shape = [k]`,
    171     * `scale_identity_multiplier.shape = []`,
    172     * `scale_perturb_factor.shape = [k, r]`, typically `k >> r`, and,
    173     * `scale_perturb_diag.shape = [r]`.
    174 
    175     Additional leading dimensions (if any) will index batches.
    176 
    177     If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    178     `scale` is the Identity matrix.
    179 
    180     Args:
    181       loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
    182         implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
    183         `b >= 0` and `k` is the event size.
    184       scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
    185         matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
    186         and characterizes `b`-batches of `k x k` diagonal matrices added to
    187         `scale`. When both `scale_identity_multiplier` and `scale_diag` are
    188         `None` then `scale` is the `Identity`.
    189       scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
    190         a scaled-identity-matrix added to `scale`. May have shape
    191         `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
    192         `k x k` identity matrices added to `scale`. When both
    193         `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
    194         the `Identity`.
    195       scale_perturb_factor: Floating-point `Tensor` representing a rank-`r`
    196         perturbation added to `scale`. May have shape `[B1, ..., Bb, k, r]`,
    197         `b >= 0`, and characterizes `b`-batches of rank-`r` updates to `scale`.
    198         When `None`, no rank-`r` update is added to `scale`.
    199       scale_perturb_diag: Floating-point `Tensor` representing a diagonal matrix
    200         inside the rank-`r` perturbation added to `scale`. May have shape
    201         `[B1, ..., Bb, r]`, `b >= 0`, and characterizes `b`-batches of `r x r`
    202         diagonal matrices inside the perturbation added to `scale`. When
    203         `None`, an identity matrix is used inside the perturbation. Can only be
    204         specified if `scale_perturb_factor` is also specified.
    205       validate_args: Python `bool`, default `False`. When `True` distribution
    206         parameters are checked for validity despite possibly degrading runtime
    207         performance. When `False` invalid inputs may silently render incorrect
    208         outputs.
    209       allow_nan_stats: Python `bool`, default `True`. When `True`,
    210         statistics (e.g., mean, mode, variance) use the value "`NaN`" to
    211         indicate the result is undefined. When `False`, an exception is raised
    212         if one or more of the statistic's batch members are undefined.
    213       name: Python `str` name prefixed to Ops created by this class.
    214 
    215     Raises:
    216       ValueError: if at most `scale_identity_multiplier` is specified.
    217     """
    218     parameters = locals()
    219     def _convert_to_tensor(x, name):
    220       return None if x is None else ops.convert_to_tensor(x, name=name)
    221     with ops.name_scope(name):
    222       with ops.name_scope("init", values=[
    223           loc, scale_diag, scale_identity_multiplier, scale_perturb_factor,
    224           scale_perturb_diag]):
    225         has_low_rank = (scale_perturb_factor is not None or
    226                         scale_perturb_diag is not None)
    227         scale = distribution_util.make_diag_scale(
    228             loc=loc,
    229             scale_diag=scale_diag,
    230             scale_identity_multiplier=scale_identity_multiplier,
    231             validate_args=validate_args,
    232             assert_positive=has_low_rank)
    233         scale_perturb_factor = _convert_to_tensor(
    234             scale_perturb_factor,
    235             name="scale_perturb_factor")
    236         scale_perturb_diag = _convert_to_tensor(
    237             scale_perturb_diag,
    238             name="scale_perturb_diag")
    239         if has_low_rank:
    240           scale = linalg.LinearOperatorLowRankUpdate(
    241               scale,
    242               u=scale_perturb_factor,
    243               diag_update=scale_perturb_diag,
    244               is_diag_update_positive=scale_perturb_diag is None,
    245               is_non_singular=True,  # Implied by is_positive_definite=True.
    246               is_self_adjoint=True,
    247               is_positive_definite=True,
    248               is_square=True)
    249     super(MultivariateNormalDiagPlusLowRank, self).__init__(
    250         loc=loc,
    251         scale=scale,
    252         validate_args=validate_args,
    253         allow_nan_stats=allow_nan_stats,
    254         name=name)
    255     self._parameters = parameters
    256