# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import contextlib
import types

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "ReparameterizationType",
    "FULLY_REPARAMETERIZED",
    "NOT_REPARAMETERIZED",
    "Distribution",
]

_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
    "batch_shape",
    "batch_shape_tensor",
    "cdf",
    "covariance",
    "cross_entropy",
    "entropy",
    "event_shape",
    "event_shape_tensor",
    "kl_divergence",
    "log_cdf",
    "log_prob",
    "log_survival_function",
    "mean",
    "mode",
    "prob",
    "sample",
    "stddev",
    "survival_function",
    "variance",
]


@six.add_metaclass(abc.ABCMeta)
class _BaseDistribution(object):
  """Abstract base class needed for resolving subclass hierarchy."""
  pass


def _copy_fn(fn):
  """Create a deep copy of fn.

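  For example, an illustrative sketch (`f` here is a hypothetical callable):

  ```python
  def f(x):
    return x + 1

  g = _copy_fn(f)
  assert g(1) == 2 and g is not f  # An independent copy of `f`.
  ```
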
  Args:
    fn: a callable

  Returns:
    A `FunctionType`: a deep copy of fn.

  Raises:
    TypeError: if `fn` is not a callable.
  """
  if not callable(fn):
    raise TypeError("fn is not callable: %s" % fn)
  # The blessed way to copy a function. copy.deepcopy fails to create a
  # non-reference copy. Since:
  #   types.FunctionType == type(lambda: None),
  # and the docstring for the function type states:
  #
  #   function(code, globals[, name[, argdefs[, closure]]])
  #
  #   Create a function object from a code object and a dictionary.
  #   ...
  #
  # Here we can use this to create a new function with the old function's
  # code, globals, closure, etc.
  return types.FunctionType(
      code=fn.__code__, globals=fn.__globals__,
      name=fn.__name__, argdefs=fn.__defaults__,
      closure=fn.__closure__)


def _update_docstring(old_str, append_str):
  """Update old_str by inserting append_str just before the "Args:" section."""
  old_str = old_str or ""
  old_str_lines = old_str.split("\n")

  # Step 0: Prepend spaces to all lines of append_str. This is
  # necessary for correct markdown generation.
  append_str = "\n".join("    %s" % line for line in append_str.split("\n"))

  # Step 1: Find mention of "Args":
  has_args_ix = [
      ix for ix, line in enumerate(old_str_lines)
      if line.strip().lower() == "args:"]
  if has_args_ix:
    final_args_ix = has_args_ix[-1]
    return ("\n".join(old_str_lines[:final_args_ix])
            + "\n\n" + append_str + "\n\n"
            + "\n".join(old_str_lines[final_args_ix:]))
  else:
    return old_str + "\n\n" + append_str


class _DistributionMeta(abc.ABCMeta):

  def __new__(mcs, classname, baseclasses, attrs):
    """Control the creation of subclasses of the Distribution class.

    The main purpose of this method is to properly propagate docstrings
    from private Distribution methods, like `_log_prob`, into their
    public wrappers as inherited by the Distribution base class
    (e.g. `log_prob`).

    Args:
      classname: The name of the subclass being created.
      baseclasses: A tuple of parent classes.
      attrs: A dict mapping new attributes to their values.

    Returns:
      The class object.

    Raises:
      TypeError: If `Distribution` is not a subclass of `BaseDistribution`, or
        the new class is derived via multiple inheritance and the first
        parent class is not a subclass of `BaseDistribution`.
      AttributeError:  If `Distribution` does not implement e.g. `log_prob`.
      ValueError:  If a `Distribution` public method lacks a docstring.
    """
    if not baseclasses:  # Nothing to be done for Distribution
      raise TypeError("Expected non-empty baseclass. Does Distribution "
                      "not subclass _BaseDistribution?")
    which_base = [
        base for base in baseclasses
        if base == _BaseDistribution or issubclass(base, Distribution)]
    base = which_base[0]
    if base == _BaseDistribution:  # Nothing to be done for Distribution
      return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)
    if not issubclass(base, Distribution):
      raise TypeError("First parent class declared for %s must be "
                      "Distribution, but saw '%s'" % (classname, base.__name__))
    for attr in _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS:
      special_attr = "_%s" % attr
      class_attr_value = attrs.get(attr, None)
      if attr in attrs:
        # The method is being overridden, do not update its docstring
        continue
      base_attr_value = getattr(base, attr, None)
      if not base_attr_value:
        raise AttributeError(
            "Internal error: expected base class '%s' to implement method '%s'"
            % (base.__name__, attr))
      class_special_attr_value = attrs.get(special_attr, None)
      if class_special_attr_value is None:
        # No _special method available, no need to update the docstring.
        continue
      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
      if not class_special_attr_docstring:
        # No docstring to append.
        continue
      class_attr_value = _copy_fn(base_attr_value)
      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
      if class_attr_docstring is None:
        raise ValueError(
            "Expected base class fn to contain a docstring: %s.%s"
            % (base.__name__, attr))
      class_attr_value.__doc__ = _update_docstring(
          class_attr_value.__doc__,
          ("Additional documentation from `%s`:\n\n%s"
           % (classname, class_special_attr_docstring)))
      attrs[attr] = class_attr_value

    return abc.ABCMeta.__new__(mcs, classname, baseclasses, attrs)


@tf_export("distributions.ReparameterizationType")
class ReparameterizationType(object):
  """Instances of this class represent how sampling is reparameterized.

  Two static instances exist in the distributions library, signifying
  one of two possible properties for samples from a distribution:

  `FULLY_REPARAMETERIZED`: Samples from the distribution are fully
    reparameterized, and straight-through gradients are supported.

  `NOT_REPARAMETERIZED`: Samples from the distribution are not fully
    reparameterized, and straight-through gradients are either partially
    unsupported or not supported at all. In this case, for purposes of
    e.g. RL or variational inference, it is generally safest to wrap the
    sample results in a `stop_gradient` call and use policy gradients or a
    surrogate loss instead.
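
  For example, a sketch of gating on this property (assuming a `Normal`
  distribution class from this library, which is fully reparameterized):

  ```python
  dist = Normal(loc=0., scale=1.)
  if dist.reparameterization_type == FULLY_REPARAMETERIZED:
    # Pathwise gradients w.r.t. `loc` and `scale` flow through `sample()`.
    loss = tf.reduce_mean(tf.square(dist.sample(10)))
  else:
    # Block gradients through the samples and use a surrogate loss instead.
    samples = tf.stop_gradient(dist.sample(10))
  ```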
  """

  def __init__(self, rep_type):
    self._rep_type = rep_type

  def __repr__(self):
    return "<Reparameterization Type: %s>" % self._rep_type

  def __eq__(self, other):
    """Determine if this `ReparameterizationType` is equal to another.

    Since ReparameterizationType instances are constant static global
    instances, equality checks if two instances' id() values are equal.

    Args:
      other: Object to compare against.

    Returns:
      `self is other`.
    """
    return self is other


# Fully reparameterized distribution: samples from a fully
# reparameterized distribution support straight-through gradients with
# respect to all parameters.
FULLY_REPARAMETERIZED = ReparameterizationType("FULLY_REPARAMETERIZED")
tf_export("distributions.FULLY_REPARAMETERIZED").export_constant(
    __name__, "FULLY_REPARAMETERIZED")


# Not reparameterized distribution: samples from a non-
# reparameterized distribution do not support straight-through gradients for
# at least some of the parameters.
NOT_REPARAMETERIZED = ReparameterizationType("NOT_REPARAMETERIZED")
tf_export("distributions.NOT_REPARAMETERIZED").export_constant(
    __name__, "NOT_REPARAMETERIZED")


@six.add_metaclass(_DistributionMeta)
@tf_export("distributions.Distribution")
class Distribution(_BaseDistribution):
  """A generic probability distribution base class.

  `Distribution` is a base class for constructing and organizing properties
  (e.g., mean, variance) of random variables (e.g., Bernoulli, Gaussian).

  #### Subclassing

  Subclasses are expected to implement a leading-underscore version of the
  same-named function. The argument signature should be identical except for
  the omission of `name="..."`. For example, to enable `log_prob(value,
  name="log_prob")` a subclass should implement `_log_prob(value)`.

  Subclasses can append to public-level docstrings by providing
  docstrings for their method specializations. For example:

  ```python
  @util.AppendDocstring("Some other details.")
  def _log_prob(self, value):
    ...
  ```

  would add the string "Some other details." to the `log_prob` function
  docstring. This is implemented as a simple decorator to avoid the Python
  linter complaining about missing Args/Returns/Raises sections in the
  partial docstrings.

  #### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shapes of arguments to `__init__`, `cdf`, `log_cdf`, `prob`, and
  `log_prob` reflect this broadcasting, as do the return values of `sample` and
  `sample_n`.

  `sample_n_shape = [n] + batch_shape + event_shape`, where `sample_n_shape` is
  the shape of the `Tensor` returned from `sample_n`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those independent
  distributions. Samples are independent along the `batch_shape` dimensions, but
  not necessarily so along the `event_shape` dimensions (depending on the
  particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound at
  # 3.0 as the `minval` parameter was broadcasted to match `maxval`'s shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.event_shape
  # `event_shape_t` is a `Tensor` which will evaluate to [].
  event_shape_t = u.event_shape_tensor()

  # Sampling returns a sample per distribution. `samples` has shape
  # [5, 2, 2], which is [n] + batch_shape + event_shape, where n=5,
  # batch_shape=[2, 2], and event_shape=[].
  samples = u.sample_n(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example. The
  # same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape [2, 2] as the `value` argument was
  # broadcasted to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist` has shape [2, 2], one entry per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the distribution's
  # shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```

  #### Parameter values leading to undefined statistics or distributions

  Some distributions do not have well-defined statistics for all initialization
  parameter values. For example, the beta distribution is parameterized by
  positive real numbers `concentration1` and `concentration0`, and does not have
  a well-defined mode if `concentration1 < 1` or `concentration0 < 1`.

  The user is given the option of raising an exception or returning `NaN`.

  ```python
  a = tf.exp(tf.matmul(logits, weights_a))
  b = tf.exp(tf.matmul(logits, weights_b))

  # Will raise an exception if ANY batch member has a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=False)
  mode = dist.mode().eval()

  # Will return NaN for batch members with either a < 1 or b < 1.
  dist = distributions.Beta(a, b, allow_nan_stats=True)  # Default behavior
  mode = dist.mode().eval()
  ```

  In all cases, an exception is raised if *invalid* parameters are passed, e.g.

  ```python
  # Will raise an exception if any Op is run.
  negative_a = -1.0 * a  # beta distribution by definition has a > 0.
  dist = distributions.Beta(negative_a, b, allow_nan_stats=True)
  dist.mean().eval()
  ```

  """

  def __init__(self,
               dtype,
               reparameterization_type,
               validate_args,
               allow_nan_stats,
               parameters=None,
               graph_parents=None,
               name=None):
    """Constructs the `Distribution`.

    **This is a private method for subclass use.**

    Args:
      dtype: The type of the event samples. `None` implies no type-enforcement.
      reparameterization_type: Instance of `ReparameterizationType`.
        If `distributions.FULLY_REPARAMETERIZED`, this
        `Distribution` can be reparameterized in terms of some standard
        distribution with a function whose Jacobian is constant for the support
        of the standard distribution. If `distributions.NOT_REPARAMETERIZED`,
        then no such reparameterization is available.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      parameters: Python `dict` of parameters used to instantiate this
        `Distribution`.
      graph_parents: Python `list` of graph prerequisites of this
        `Distribution`.
      name: Python `str` name prefixed to Ops created by this class. Default:
        subclass name.

    Raises:
      ValueError: if any member of graph_parents is `None` or not a `Tensor`.
    """
    graph_parents = [] if graph_parents is None else graph_parents
    for i, t in enumerate(graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
    self._dtype = dtype
    self._reparameterization_type = reparameterization_type
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    self._parameters = parameters or {}
    self._graph_parents = graph_parents
    self._name = name or type(self).__name__

  @classmethod
  def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
    """Shapes of parameters given the desired shape of a call to `sample()`.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`.

    Subclasses should override class method `_param_shapes`.

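    For example, a sketch (assuming a `Normal` subclass parameterized by
    `loc` and `scale`):

    ```python
    shapes = Normal.param_shapes([64])
    # ==> {"loc": <shape [64] Tensor>, "scale": <shape [64] Tensor>}
    ```
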
    Args:
      sample_shape: `Tensor` or python list/tuple. Desired shape of a call to
        `sample()`.
      name: name to prepend ops with.

    Returns:
      `dict` of parameter name to `Tensor` shapes.
    """
    with ops.name_scope(name, values=[sample_shape]):
      return cls._param_shapes(sample_shape)

  @classmethod
  def param_static_shapes(cls, sample_shape):
    """param_shapes with static (i.e. `TensorShape`) shapes.

    This is a class method that describes what key/value arguments are required
    to instantiate the given `Distribution` so that a particular shape is
    returned for that instance's call to `sample()`. Assumes that the sample's
    shape is known statically.

    Subclasses should override class method `_param_shapes` to return
    constant-valued tensors when constant values are fed.

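    For example, a sketch (again assuming a `Normal` subclass):

    ```python
    shapes = Normal.param_static_shapes([64])
    # ==> {"loc": TensorShape([64]), "scale": TensorShape([64])}
    ```
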
    Args:
      sample_shape: `TensorShape` or python list/tuple. Desired shape of a call
        to `sample()`.

    Returns:
      `dict` of parameter name to `TensorShape`.

    Raises:
      ValueError: if `sample_shape` is a `TensorShape` and is not fully defined.
    """
    if isinstance(sample_shape, tensor_shape.TensorShape):
      if not sample_shape.is_fully_defined():
        raise ValueError("TensorShape sample_shape must be fully defined")
      sample_shape = sample_shape.as_list()

    params = cls.param_shapes(sample_shape)

    static_params = {}
    for name, shape in params.items():
      static_shape = tensor_util.constant_value(shape)
      if static_shape is None:
        raise ValueError(
            "sample_shape must be a fully-defined TensorShape or list/tuple")
      static_params[name] = tensor_shape.TensorShape(static_shape)

    return static_params

  @staticmethod
  def _param_shapes(sample_shape):
    raise NotImplementedError("_param_shapes not implemented")

  @property
  def name(self):
    """Name prepended to all ops created by this `Distribution`."""
    return self._name

  @property
  def dtype(self):
    """The `DType` of `Tensor`s handled by this `Distribution`."""
    return self._dtype

  @property
  def parameters(self):
    """Dictionary of parameters used to instantiate this `Distribution`."""
    # Remove "self", "__class__", or other special variables. These can appear
    # if the subclass used `parameters = locals()`.
    return dict((k, v) for k, v in self._parameters.items()
                if not k.startswith("__") and k != "self")

  @property
  def reparameterization_type(self):
    """Describes how samples from the distribution are reparameterized.

    Currently this is one of the static instances
    `distributions.FULLY_REPARAMETERIZED`
    or `distributions.NOT_REPARAMETERIZED`.

    Returns:
      An instance of `ReparameterizationType`.
    """
    return self._reparameterization_type

  @property
  def allow_nan_stats(self):
    """Python `bool` describing behavior when a stat is undefined.

    Stats return +/- infinity when it makes sense. E.g., the variance of a
    Cauchy distribution is infinity. However, sometimes the statistic is
    undefined, e.g., if a distribution's pdf does not achieve a maximum within
    the support of the distribution, then the mode is undefined. If the mean is
    undefined, then by definition the variance is undefined. E.g., the mean of
    Student's t with df = 1 is undefined (there is no clear way to say it is
    either + or - infinity), so the variance = E[(X - mean)**2] is also
    undefined.

    Returns:
      allow_nan_stats: Python `bool`.
    """
    return self._allow_nan_stats

  @property
  def validate_args(self):
    """Python `bool` indicating possibly expensive checks are enabled."""
    return self._validate_args

  def copy(self, **override_parameters_kwargs):
    """Creates a deep copy of the distribution.

    Note: the copied distribution may continue to depend on the original
    initialization arguments.

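    For example, a sketch (assuming a `Normal` distribution class):

    ```python
    n1 = Normal(loc=0., scale=1.)
    n2 = n1.copy(scale=2.)  # Same `loc`; `scale` overridden.
    ```
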
    Args:
      **override_parameters_kwargs: String/value dictionary of initialization
        arguments to override with new values.

    Returns:
      distribution: A new instance of `type(self)` initialized from the union
        of self.parameters and override_parameters_kwargs, i.e.,
        `dict(self.parameters, **override_parameters_kwargs)`.
    """
    parameters = dict(self.parameters, **override_parameters_kwargs)
    return type(self)(**parameters)

  def _batch_shape_tensor(self):
    raise NotImplementedError("batch_shape_tensor is not implemented")

  def batch_shape_tensor(self, name="batch_shape_tensor"):
    """Shape of a single sample from a single event index as a 1-D `Tensor`.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Args:
      name: name to give to the op.

    Returns:
      batch_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.batch_shape.is_fully_defined():
        return ops.convert_to_tensor(self.batch_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="batch_shape")
      return self._batch_shape_tensor()

  def _batch_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def batch_shape(self):
    """Shape of a single sample from a single event index as a `TensorShape`.

    May be partially defined or unknown.

    The batch dimensions are indexes into independent, non-identical
    parameterizations of this distribution.

    Returns:
      batch_shape: `TensorShape`, possibly unknown.
    """
    return self._batch_shape()

  def _event_shape_tensor(self):
    raise NotImplementedError("event_shape_tensor is not implemented")

  def event_shape_tensor(self, name="event_shape_tensor"):
    """Shape of a single sample from a single batch as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op.

    Returns:
      event_shape: `Tensor`.
    """
    with self._name_scope(name):
      if self.event_shape.is_fully_defined():
        return ops.convert_to_tensor(self.event_shape.as_list(),
                                     dtype=dtypes.int32,
                                     name="event_shape")
      return self._event_shape_tensor()

  def _event_shape(self):
    return tensor_shape.TensorShape(None)

  @property
  def event_shape(self):
    """Shape of a single sample from a single batch as a `TensorShape`.

    May be partially defined or unknown.

    Returns:
      event_shape: `TensorShape`, possibly unknown.
    """
    return self._event_shape()

  def is_scalar_event(self, name="is_scalar_event"):
    """Indicates whether `event_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_event: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.event_shape, self.event_shape_tensor),
          name="is_scalar_event")

  def is_scalar_batch(self, name="is_scalar_batch"):
    """Indicates whether `batch_shape == []`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      is_scalar_batch: `bool` scalar `Tensor`.
    """
    with self._name_scope(name):
      return ops.convert_to_tensor(
          self._is_scalar_helper(self.batch_shape, self.batch_shape_tensor),
          name="is_scalar_batch")

  def _sample_n(self, n, seed=None):
    raise NotImplementedError("sample_n is not implemented")

  def _call_sample_n(self, sample_shape, seed, name, **kwargs):
    with self._name_scope(name, values=[sample_shape]):
      sample_shape = ops.convert_to_tensor(
          sample_shape, dtype=dtypes.int32, name="sample_shape")
      sample_shape, n = self._expand_sample_shape_to_vector(
          sample_shape, "sample_shape")
      samples = self._sample_n(n, seed, **kwargs)
      batch_event_shape = array_ops.shape(samples)[1:]
      final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
      samples = array_ops.reshape(samples, final_shape)
      samples = self._set_sample_static_shape(samples, sample_shape)
      return samples

  def sample(self, sample_shape=(), seed=None, name="sample"):
    """Generate samples of the specified shape.

    Note that a call to `sample()` without arguments will generate a single
    sample.

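    For example, a sketch (assuming a scalar `Normal` distribution class):

    ```python
    dist = Normal(loc=0., scale=1.)
    x = dist.sample([5, 2])  # `Tensor` with shape [5, 2].
    ```
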
    Args:
      sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
      seed: Python integer seed for RNG.
      name: name to give to the op.

    Returns:
      samples: a `Tensor` with prepended dimensions `sample_shape`.
    """
    return self._call_sample_n(sample_shape, seed, name)

  def _log_prob(self, value):
    raise NotImplementedError("log_prob is not implemented")

  def _call_log_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_prob(self, value, name="log_prob"):
    """Log probability density/mass function.

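    For example, a sketch (assuming a batch of two `Normal` distributions):

    ```python
    dist = Normal(loc=[0., 1.], scale=1.)
    lp = dist.log_prob([[0., 1.], [2., 3.]])  # Shape [2, 2]: value broadcast
                                              # against the batch shape.
    ```
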
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      log_prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_prob(value, name)

  def _prob(self, value):
    raise NotImplementedError("prob is not implemented")

  def _call_prob(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._prob(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_prob(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def prob(self, value, name="prob"):
    """Probability density/mass function.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      prob: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_prob(value, name)

  def _log_cdf(self, value):
    raise NotImplementedError("log_cdf is not implemented")

  def _call_log_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log(self._cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_cdf(self, value, name="log_cdf"):
    """Log cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    log_cdf(x) := Log[ P[X <= x] ]
    ```

    Often, a numerical approximation can be used for `log_cdf(x)` that yields
    a more accurate answer than simply taking the logarithm of the `cdf` when
    `x << -1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      logcdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_log_cdf(value, name)

  def _cdf(self, value):
    raise NotImplementedError("cdf is not implemented")

  def _call_cdf(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._cdf(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.exp(self._log_cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function.

    Given random variable `X`, the cumulative distribution function `cdf` is:

    ```none
    cdf(x) := P[X <= x]
    ```

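    For example, a sketch (assuming a standard `Normal` distribution):

    ```python
    dist = Normal(loc=0., scale=1.)
    p = dist.cdf(0.)  # ==> 0.5, i.e., P[X <= 0].
    ```
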
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cdf: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_cdf(value, name)

  def _log_survival_function(self, value):
    raise NotImplementedError("log_survival_function is not implemented")

  def _call_log_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._log_survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return math_ops.log1p(-self.cdf(value, **kwargs))
        except NotImplementedError:
          raise original_exception

  def log_survival_function(self, value, name="log_survival_function"):
    """Log survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    log_survival_function(x) = Log[ P[X > x] ]
                             = Log[ 1 - P[X <= x] ]
                             = Log[ 1 - cdf(x) ]
    ```

    Typically, different numerical approximations can be used for the log
    survival function, which are more accurate than computing
    `log(1 - cdf(x))` directly when `x >> 1`.

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_log_survival_function(value, name)

  def _survival_function(self, value):
    raise NotImplementedError("survival_function is not implemented")

  def _call_survival_function(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._survival_function(value, **kwargs)
      except NotImplementedError as original_exception:
        try:
          return 1. - self.cdf(value, **kwargs)
        except NotImplementedError:
          raise original_exception

  def survival_function(self, value, name="survival_function"):
    """Survival function.

    Given random variable `X`, the survival function is defined:

    ```none
    survival_function(x) = P[X > x]
                         = 1 - P[X <= x]
                         = 1 - cdf(x).
    ```

    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      `Tensor` of shape `sample_shape(x) + self.batch_shape` with values of type
        `self.dtype`.
    """
    return self._call_survival_function(value, name)

  def _entropy(self):
    raise NotImplementedError("entropy is not implemented")

  def entropy(self, name="entropy"):
    """Shannon entropy in nats."""
    with self._name_scope(name):
      return self._entropy()

  def _mean(self):
    raise NotImplementedError("mean is not implemented")

  def mean(self, name="mean"):
    """Mean."""
    with self._name_scope(name):
      return self._mean()

  def _quantile(self, value):
    raise NotImplementedError("quantile is not implemented")

  def _call_quantile(self, value, name, **kwargs):
    with self._name_scope(name, values=[value]):
      value = ops.convert_to_tensor(value, name="value")
      try:
        return self._quantile(value, **kwargs)
      except NotImplementedError as original_exception:
        raise original_exception

  def quantile(self, value, name="quantile"):
    """Quantile function. Aka "inverse cdf" or "percent point function".

    Given random variable `X` and `p in [0, 1]`, the `quantile` is:

    ```none
    quantile(p) := x such that P[X <= x] == p
    ```

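    For example, a sketch (assuming a standard `Normal` distribution):

    ```python
    dist = Normal(loc=0., scale=1.)
    x = dist.quantile(0.5)  # ==> 0., the median.
    ```
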
    Args:
      value: `float` or `double` `Tensor`.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      quantile: a `Tensor` of shape `sample_shape(x) + self.batch_shape` with
        values of type `self.dtype`.
    """
    return self._call_quantile(value, name)

  def _variance(self):
    raise NotImplementedError("variance is not implemented")

  def variance(self, name="variance"):
    """Variance.

    Variance is defined as,

    ```none
    Var = E[(X - E[X])**2]
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `Var.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      variance: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._variance()
      except NotImplementedError as original_exception:
        try:
          return math_ops.square(self._stddev())
        except NotImplementedError:
          raise original_exception

  def _stddev(self):
    raise NotImplementedError("stddev is not implemented")

  def stddev(self, name="stddev"):
    """Standard deviation.

    Standard deviation is defined as,

    ```none
    stddev = E[(X - E[X])**2]**0.5
    ```

    where `X` is the random variable associated with this distribution, `E`
    denotes expectation, and `stddev.shape = batch_shape + event_shape`.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      stddev: Floating-point `Tensor` with shape identical to
        `batch_shape + event_shape`, i.e., the same shape as `self.mean()`.
    """
    with self._name_scope(name):
      try:
        return self._stddev()
      except NotImplementedError as original_exception:
        try:
          return math_ops.sqrt(self._variance())
        except NotImplementedError:
          raise original_exception

  def _covariance(self):
    raise NotImplementedError("covariance is not implemented")

  def covariance(self, name="covariance"):
    """Covariance.

    Covariance is (possibly) defined only for non-scalar-event distributions.

    For example, for a length-`k`, vector-valued distribution, it is calculated
    as,

    ```none
    Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]
    ```

    where `Cov` is a (batch of) `k x k` matrix, `0 <= (i, j) < k`, and `E`
    denotes expectation.

    Alternatively, for non-vector, multivariate distributions (e.g.,
    matrix-valued, Wishart), `Covariance` shall return a (batch of) matrices
    under some vectorization of the events, i.e.,

    ```none
    Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]
    ```

    where `Cov` is a (batch of) `k' x k'` matrix,
    `0 <= (i, j) < k' = reduce_prod(event_shape)`, and `Vec` is some function
    mapping indices of this distribution's event dimensions to indices of a
    length-`k'` vector.

    Args:
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      covariance: Floating-point `Tensor` with shape `[B1, ..., Bn, k', k']`
        where the first `n` dimensions are batch coordinates and
        `k' = reduce_prod(self.event_shape)`.
    """
    with self._name_scope(name):
      return self._covariance()

  def _mode(self):
    raise NotImplementedError("mode is not implemented")

  def mode(self, name="mode"):
    """Mode."""
    with self._name_scope(name):
      return self._mode()

  def _cross_entropy(self, other):
    return kullback_leibler.cross_entropy(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def cross_entropy(self, other, name="cross_entropy"):
    """Computes the (Shannon) cross entropy.

    Denote this distribution (`self`) by `P` and the `other` distribution by
    `Q`. Assuming `P, Q` are absolutely continuous with respect to
    one another and permit densities `p(x) dr(x)` and `q(x) dr(x)`, the
    (Shannon) cross entropy is defined as:

    ```none
    H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)
    ```

    where `F` denotes the support of the random variable `X ~ P`.

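    For example, a sketch (assuming two `Normal` distributions, for which a
    KL-divergence routine is registered, so `H[P, Q] = H[P] + KL[P || Q]`):

    ```python
    p = Normal(loc=0., scale=1.)
    q = Normal(loc=1., scale=2.)
    h = p.cross_entropy(q)  # Scalar `Tensor`: E_p[-log q(X)].
    ```
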
    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      cross_entropy: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of (Shannon) cross entropy.
    """
    with self._name_scope(name):
      return self._cross_entropy(other)

  def _kl_divergence(self, other):
    return kullback_leibler.kl_divergence(
        self, other, allow_nan_stats=self.allow_nan_stats)

  def kl_divergence(self, other, name="kl_divergence"):
    """Computes the Kullback--Leibler divergence.

    Denote this distribution (`self`) by `p` and the `other` distribution by
    `q`. Assuming `p, q` are absolutely continuous with respect to reference
    measure `r`, the KL divergence is defined as:

    ```none
    KL[p, q] = E_p[log(p(X)/q(X))]
             = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
             = H[p, q] - H[p]
    ```

    where `F` denotes the support of the random variable `X ~ p`, `H[., .]`
    denotes (Shannon) cross entropy, and `H[.]` denotes (Shannon) entropy.

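    For example, a sketch (assuming two `Normal` distributions, for which a
    KL-divergence routine is registered):

    ```python
    p = Normal(loc=0., scale=1.)
    q = Normal(loc=1., scale=2.)
    kl = p.kl_divergence(q)  # Scalar, non-negative `Tensor`.
    ```
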
    Args:
      other: `tf.distributions.Distribution` instance.
      name: Python `str` prepended to names of ops created by this function.

    Returns:
      kl_divergence: `self.dtype` `Tensor` with shape `[B1, ..., Bn]`
        representing `n` different calculations of the Kullback-Leibler
        divergence.
    """
    with self._name_scope(name):
      return self._kl_divergence(other)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=(
          ([] if values is None else values) + self._graph_parents)) as scope:
        yield scope

  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D."""
    x_static_val = tensor_util.constant_value(x)
    if x_static_val is None:
      prod = math_ops.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.get_shape().ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = array_ops.rank(x)
      expanded_shape = util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array([1], dtype=np.int32), array_ops.shape(x))
      x = array_ops.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = ops.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = array_ops.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod

  def _set_sample_static_shape(self, x, sample_shape):
    """Helper to `sample`; sets static shape info."""
    # Set shape hints.
    sample_shape = tensor_shape.TensorShape(
        tensor_util.constant_value(sample_shape))

    ndims = x.get_shape().ndims
    sample_ndims = sample_shape.ndims
    batch_ndims = self.batch_shape.ndims
    event_ndims = self.event_shape.ndims

    # Infer rank(x).
    if (ndims is None and
        sample_ndims is not None and
        batch_ndims is not None and
        event_ndims is not None):
      ndims = sample_ndims + batch_ndims + event_ndims
      x.set_shape([None] * ndims)

    # Infer sample shape.
    if ndims is not None and sample_ndims is not None:
      shape = sample_shape.concatenate([None] * (ndims - sample_ndims))
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer event shape.
    if ndims is not None and event_ndims is not None:
      shape = tensor_shape.TensorShape(
          [None] * (ndims - event_ndims)).concatenate(self.event_shape)
      x.set_shape(x.get_shape().merge_with(shape))

    # Infer batch shape.
    if batch_ndims is not None:
      if ndims is not None:
        if sample_ndims is None and event_ndims is not None:
          sample_ndims = ndims - batch_ndims - event_ndims
        elif event_ndims is None and sample_ndims is not None:
          event_ndims = ndims - batch_ndims - sample_ndims
      if sample_ndims is not None and event_ndims is not None:
        shape = tensor_shape.TensorShape([None] * sample_ndims).concatenate(
            self.batch_shape).concatenate([None] * event_ndims)
        x.set_shape(x.get_shape().merge_with(shape))

    return x

  def _is_scalar_helper(self, static_shape, dynamic_shape_fn):
    """Implementation for `is_scalar_batch` and `is_scalar_event`."""
    if static_shape.ndims is not None:
      return static_shape.ndims == 0
    shape = dynamic_shape_fn()
    if (shape.get_shape().ndims is not None and
        shape.get_shape()[0].value is not None):
      # If the static_shape is correctly written then we should never execute
      # this branch. We keep it just in case there's some unimagined corner
      # case.
      return shape.get_shape().as_list() == [0]
    return math_ops.equal(array_ops.shape(shape)[0], 0)
   1201