# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import contextlib
import re

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Bijector",
]


class _Mapping(collections.namedtuple(
    "_Mapping", ["x", "y", "ildj", "kwargs"])):
  """Helper class to make it easier to manage caching in `Bijector`."""

  def __new__(cls, x=None, y=None, ildj=None, kwargs=None):
    """Custom __new__ so namedtuple items have defaults.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.

    Returns:
      mapping: New instance of _Mapping.
    """
    return super(_Mapping, cls).__new__(cls, x, y, ildj, kwargs)

  @property
  def x_key(self):
    """Returns key used for caching Y=g(X)."""
    return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  @property
  def y_key(self):
    """Returns key used for caching X=g^{-1}(Y)."""
    return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  def merge(self, x=None, y=None, ildj=None, kwargs=None, mapping=None):
    """Returns new _Mapping with args merged with self.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.
      mapping: Instance of _Mapping to merge. Can only be specified if no other
        arg is specified.

    Returns:
      mapping: New instance of `_Mapping` which has inputs merged with self.

    Raises:
      ValueError: if mapping and any other arg is not `None`.
    """
    if mapping is None:
      mapping = _Mapping(x=x, y=y, ildj=ildj, kwargs=kwargs)
    elif not all(arg is None for arg in [x, y, ildj, kwargs]):
      raise ValueError("Cannot specify mapping and individual args.")
    return _Mapping(
        x=self._merge(self.x, mapping.x),
        y=self._merge(self.y, mapping.y),
        ildj=self._merge(self.ildj, mapping.ildj),
        kwargs=self._merge(self.kwargs, mapping.kwargs))

  def _merge(self, old, new):
    """Helper to merge which handles merging one value."""
    if old is None:
      return new
    elif new is not None and old != new:
      raise ValueError("Incompatible values: %s != %s" % (old, new))
    return old

  def _deep_tuple(self, x):
    """Converts lists of lists to tuples of tuples."""
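    # For example, _deep_tuple([[1, 2], [3]]) == ((1, 2), (3,)).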
    return (tuple(map(self._deep_tuple, x))
            if isinstance(x, (list, tuple)) else x)


@six.add_metaclass(abc.ABCMeta)
@tf_export("distributions.bijectors.Bijector")
class Bijector(object):
  r"""Interface for transformations of a `Distribution` sample.

  Bijectors can be used to represent any differentiable and injective
  (one to one) function defined on an open subset of `R^n`.  Some non-injective
  transformations are also supported (see "Non Injective Transforms" below).

  #### Mathematical Details

  A `Bijector` implements a [smooth covering map](
  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
  diffeomorphism such that every point in the target has a neighborhood evenly
  covered by a map ([see also](
  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
  A `Bijector` is used by `TransformedDistribution` but can be generally used
  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
  characterized by three operations:

  1. Forward\
     Useful for turning one random outcome into another random outcome from a
     different distribution.
  2. Inverse\
     Useful for "reversing" a transformation to compute one probability in
     terms of another.
  3. `(log o det o Jacobian o inverse)(x)`\
     "The log of the determinant of the matrix of all first-order partial
     derivatives of the inverse function."\
     Useful for inverting a transformation to compute one probability in terms
     of another. Geometrically, the det(Jacobian) is the volume of the
     transformation and is used to scale the probability.
  By convention, transformations of random variables are named in terms of the
  forward transformation: the forward transformation creates samples, while the
  inverse is useful for computing probabilities.

  #### Example Uses

  - Basic properties:

  ```python
  x = ...  # A tensor.
  # Evaluate forward transformation.
  fwd_x = my_bijector.forward(x)
  x == my_bijector.inverse(fwd_x)
  x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
  ```

  - Computing a log-likelihood:

  ```python
  def transformed_log_prob(bijector, log_prob, x):
    return (bijector.inverse_log_det_jacobian(x) +
            log_prob(bijector.inverse(x)))
  ```

  - Transforming a random outcome:

  ```python
  def transformed_sample(bijector, x):
    return bijector.forward(x)
  ```

  #### Example Bijectors

  - "Exponential"

    ```none
    Y = g(X) = exp(X)
    X ~ Normal(0, 1)  # Univariate.
    ```

    Implies:

    ```none
      g^{-1}(Y) = log(Y)
      |Jacobian(g^{-1})(y)| = 1 / y
      Y ~ LogNormal(0, 1), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = (1 / y) Normal(log(y); 0, 1)
    ```

    Here is an example of how one might implement the `Exp` bijector:

    ```python
      class Exp(Bijector):

        def __init__(self, event_ndims=0, validate_args=False, name="exp"):
          super(Exp, self).__init__(
              event_ndims=event_ndims, validate_args=validate_args, name=name)

        def _forward(self, x):
          return math_ops.exp(x)

        def _inverse(self, y):
          return math_ops.log(y)

        def _inverse_log_det_jacobian(self, y):
          return -self._forward_log_det_jacobian(self._inverse(y))

        def _forward_log_det_jacobian(self, x):
          if self.event_ndims is None:
            raise ValueError("Jacobian requires known event_ndims.")
          # Reduce over the event dimensions, i.e., the last `event_ndims`
          # axes of `x`.
          return math_ops.reduce_sum(x, axis=self._event_dims_tensor(x))
    ```

  - "Affine"

    ```none
    Y = g(X) = sqrtSigma * X + mu
    X ~ MultivariateNormal(0, I_d)
    ```

    Implies:

    ```none
      g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
      |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
      Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = det(sqrtSigma)^(-d) *
                  MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
    ```

  #### Jacobian

  The Jacobian is a reduction over event dims. To see this, consider the `Exp`
  `Bijector` applied to a `Tensor` which has sample, batch, and event (S, B, E)
  shape semantics. Suppose the `Tensor`'s partitioned-shape is `(S=[4], B=[2],
  E=[3, 3])`. The shape of the `Tensor` returned by `forward` and `inverse` is
  unchanged, i.e., `[4, 2, 3, 3]`. However, the shape returned by
  `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction
  over the event dimensions.
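
  As a concrete check, here is a short sketch (assuming the contrib `Exp`
  bijector, which takes an `event_ndims` constructor argument):

  ```python
  import tensorflow as tf

  exp = tf.contrib.distributions.bijectors.Exp(event_ndims=2)
  x = tf.zeros([4, 2, 3, 3])              # S=[4], B=[2], E=[3, 3].
  y = exp.forward(x)                      # Shape [4, 2, 3, 3]; unchanged.
  ildj = exp.inverse_log_det_jacobian(y)  # Shape [4, 2]; E dims reduced away.
  ```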

  It is sometimes useful to implement the inverse Jacobian as the negative
  forward Jacobian. For example,

  ```python
  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
  ```

  The correctness of this approach can be seen from the following claim.

  - Claim:

      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

      ```none
      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
      ```

  - Proof:

      From the bijective, nonzero differentiability of `g`, the
      [inverse function theorem](
          https://en.wikipedia.org/wiki/Inverse_function_theorem)
      implies `g^{-1}` is differentiable in the image of `g`.
      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
      The same theorem also implies `g^{-1}'` is non-singular; therefore:
      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
      The claim follows from [properties of determinant](
  https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
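
  A quick scalar sanity check of this claim for `g = exp` (a sketch in plain
  NumPy, independent of this module):

  ```python
  import numpy as np

  y = 2.5
  x = np.log(y)       # x = g^{-1}(y).
  fldj = x            # log|d/dx exp(x)| evaluated at x.
  ildj = -np.log(y)   # log|d/dy log(y)| = -log(y).
  np.testing.assert_allclose(ildj, -fldj)  # ildj(y) == -fldj(g^{-1}(y)).
  ```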

  Generally it's preferable to directly implement the inverse Jacobian. This
  should have superior numerical stability and will often share subgraphs with
  the `_inverse` implementation.

  #### Subclass Requirements

  - Subclasses typically implement:

      - `_forward`,
      - `_inverse`,
      - `_inverse_log_det_jacobian`,
      - `_forward_log_det_jacobian` (optional).

    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient
    calculation, `-1 * _inverse_log_det_jacobian`, is used.

    If the bijector changes the shape of the input, you must also implement
    (see the sketch after this list):

      - `_forward_event_shape_tensor`,
      - `_forward_event_shape` (optional),
      - `_inverse_event_shape_tensor`,
      - `_inverse_event_shape` (optional).

    By default the event-shape is assumed unchanged from input.
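
    A minimal sketch of a hypothetical shape-changing bijector (one that
    flattens a `[d, d]` matrix event into a `[d*d]` vector event), showing
    only the event-shape overrides:

    ```python
    class FlattenSquare(Bijector):  # Hypothetical; not part of this module.

      def __init__(self, validate_args=False, name="flatten_square"):
        super(FlattenSquare, self).__init__(
            event_ndims=2, validate_args=validate_args, name=name)

      def _forward_event_shape_tensor(self, input_shape):
        # input_shape is [d, d]; the forward event shape is [d*d].
        return input_shape[:1] * input_shape[1:]

      def _inverse_event_shape_tensor(self, output_shape):
        # output_shape is [d*d]; recover [d, d].
        d = math_ops.cast(
            math_ops.sqrt(math_ops.cast(output_shape[0], dtypes.float32)),
            dtypes.int32)
        return array_ops.stack([d, d])

      # `_forward`/`_inverse` would `array_ops.reshape` accordingly.
    ```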

  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
    like `QuantizedDistribution`) then, depending on your use, you may not need
    to implement all of the `_forward` and `_inverse` functions.

    Examples:

      1. Sampling (e.g., `sample`) only requires `_forward`.
      2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
         `_inverse` (and related).
      3. Only calling probability functions on the output of `sample` means
         `_inverse` can be implemented as a cache lookup (see the sketch after
         this paragraph).

    See "Example Uses" [above] which shows how these functions are used to
    transform a distribution. (Note: `_forward` could theoretically be
    implemented as a cache lookup but this would require controlling the
    underlying sample generation mechanism.)
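
    For instance, a sketch of the cached round trip (assuming the contrib
    `TransformedDistribution`, `Normal`, and `Exp` APIs):

    ```python
    import tensorflow as tf
    ds = tf.contrib.distributions

    dist = ds.TransformedDistribution(
        distribution=ds.Normal(loc=0., scale=1.),
        bijector=ds.bijectors.Exp())
    y = dist.sample(5)     # Computes y = forward(x) and caches (x, y).
    lp = dist.log_prob(y)  # inverse(y) is served from the cache.
    ```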

  #### Non Injective Transforms

  **WARNING** Handling of non-injective transforms is subject to change.

  Non injective maps `g` are supported, provided their domain `D` can be
  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`.  In particular, this implies that for
  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
  contains exactly `k` distinct points.

  The property `_is_injective` is set to `False` to indicate that the bijector
  is not injective, yet satisfies the above condition.

  The usual bijector API is modified in the case `_is_injective is False` (see
  method docstrings for specifics).  Here we show by example the `AbsoluteValue`
  bijector.  In this case, the domain `D = (-inf, inf)` can be partitioned
  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`.  Let `gi` be the
  restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`.  We will use
  `g1` and `g3` to define bijector methods over `D1` and `D3`.  `D2 = {0}` is
  an oddball in that `g2` is one to one, and the derivative is not well defined.
  Fortunately, when considering transformations of probability densities
  (e.g. in `TransformedDistribution`), sets of measure zero have no effect in
  theory, and only a small effect in 32 or 64 bit precision.  For that reason,
  we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `(0, 0)`,
  which is convenient and results in a left-semicontinuous pdf.

  ```python
  abs = tf.contrib.distributions.bijectors.AbsoluteValue()

  abs.forward(-1.)
  ==> 1.

  abs.forward(1.)
  ==> 1.

  abs.inverse(1.)
  ==> (-1., 1.)

  # The |dX/dY| is constant, == 1.  So Log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1.)
  ==> (0., 0.)

  # Special case handling of 0.
  abs.inverse(0.)
  ==> (0., 0.)

  abs.inverse_log_det_jacobian(0.)
  ==> (0., 0.)
  ```

  """

  @abc.abstractmethod
  def __init__(self,
               event_ndims=None,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform which operates on vector events.
    identity = Identity(event_ndims=1)

    # Create the Y = g(X) = exp(X) transform which operates on matrices.
    exp = Exp(event_ndims=2)
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      event_ndims: number of dimensions associated with event coordinates.
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian is not a
        function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError:  If a member of `graph_parents` is not a `Tensor`.
    """
    self._event_ndims = (
        ops.convert_to_tensor(event_ndims, dtype=dtypes.int32)
        if event_ndims is not None else None)
    self._graph_parents = graph_parents or []
    self._is_constant_jacobian = is_constant_jacobian
    self._validate_args = validate_args
    self._dtype = dtype
    self._from_y = {}
    self._from_x = {}
    # Using abbreviation ildj for "inverse log det Jacobian."
    # This is set (non-`None`) only if `is_constant_jacobian` is `True` and an
    # ildj has been computed.
    self._constant_ildj = None
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
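      # For example, camel_to_snake("CholeskyOuterProduct") returns
      # "cholesky_outer_product".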
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  @property
  def event_ndims(self):
    """Returns the number of event dimensions this bijector operates on."""
    return self._event_ndims

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian is not a function of x.

    Note: Jacobian is either constant for both forward and inverse or neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note:  Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this bijector."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: The name to give this op.

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: The name to give this op.

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing the unique
        `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function."""
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      if self._constant_ildj is not None:
        return self._constant_ildj
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse_log_det_jacobian(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.ildj is not None:
        return mapping.ildj
      try:
        x = None  # Not needed; leave cache as is.
        ildj = self._inverse_log_det_jacobian(y, **kwargs)
      except NotImplementedError as original_exception:
        try:
          x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs)
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(x=x, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return mapping.ildj

  def inverse_log_det_jacobian(self, y, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the tuple of local log det
        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
        of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian`."""
    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      if self._constant_ildj is not None:
        # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
        return -1. * self._constant_ildj
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:
        return self._forward_log_det_jacobian(x, **kwargs)  # No caching.
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.ildj is not None:
        return -mapping.ildj
      try:
        y = None  # Not needed; leave cache as is.
        ildj = -self._forward_log_det_jacobian(x, **kwargs)
      except NotImplementedError as original_exception:
        try:
          y = mapping.y if mapping.y is not None else self._forward(x, **kwargs)
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(y=y, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return -mapping.ildj

  def forward_log_det_jacobian(self, x, name="forward_log_det_jacobian"):
    """Returns the forward_log_det_jacobian.

    Args:
      x: `Tensor`. The input to the "forward" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    return self._call_forward_log_det_jacobian(x, name)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (x.dtype, self.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    if self._constant_ildj is not None:
      # Fold in ildj if known constant Jacobian.
      mapping = mapping.merge(ildj=self._constant_ildj)
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x,y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache requires both x,y to be set, we only need to do one cache
    # lookup since the mapping is always in both or neither.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _event_dims_tensor(self, sample):
    """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`."""
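    # For example, a rank-4 sample with event_ndims=2 yields [2, 3].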
    if self.event_ndims is None:
      raise ValueError("Jacobian cannot be computed with unknown event_ndims")
    static_event_ndims = tensor_util.constant_value(self.event_ndims)
    static_rank = sample.get_shape().ndims
    if static_event_ndims is not None and static_rank is not None:
      return ops.convert_to_tensor(
          static_rank + np.arange(-static_event_ndims, 0).astype(np.int32))

    if static_event_ndims is not None:
      event_range = np.arange(-static_event_ndims, 0).astype(np.int32)
    else:
      event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32)

    if static_rank is not None:
      return event_range + static_rank
    else:
      return event_range + array_ops.rank(sample)
    825