Home | History | Annotate | Download | only in bijectors
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chain bijector."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.ops.distributions import bijector
     25 
     26 
     27 __all__ = [
     28     "Chain",
     29 ]
     30 
     31 
     32 class Chain(bijector.Bijector):
     33   """Bijector which applies a sequence of bijectors.
     34 
     35   Example Use:
     36 
     37   ```python
     38   chain = Chain([Exp(), Softplus()], name="one_plus_exp")
     39   ```
     40 
     41   Results in:
     42 
     43   * Forward:
     44 
     45    ```python
     46    exp = Exp()
     47    softplus = Softplus()
     48    Chain([exp, softplus]).forward(x)
     49    = exp.forward(softplus.forward(x))
     50    = tf.exp(tf.log(1. + tf.exp(x)))
     51    = 1. + tf.exp(x)
     52    ```
     53 
     54   * Inverse:
     55 
     56    ```python
     57    exp = Exp()
     58    softplus = Softplus()
     59    Chain([exp, softplus]).inverse(y)
     60    = softplus.inverse(exp.inverse(y))
     61    = tf.log(tf.exp(tf.log(y)) - 1.)
     62    = tf.log(y - 1.)
     63    ```
     64 
     65   """
     66 
     67   def __init__(self, bijectors=None, validate_args=False, name=None):
     68     """Instantiates `Chain` bijector.
     69 
     70     Args:
     71       bijectors: Python `list` of bijector instances. An empty list makes this
     72         bijector equivalent to the `Identity` bijector.
     73       validate_args: Python `bool` indicating whether arguments should be
     74         checked for correctness.
     75       name: Python `str`, name given to ops managed by this object. Default:
     76         E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`.
     77 
     78     Raises:
     79       ValueError: if bijectors have different dtypes.
     80     """
     81     if bijectors is None:
     82       bijectors = ()
     83     self._bijectors = bijectors
     84 
     85     for a_bijector in bijectors:
     86       if not a_bijector._is_injective:  # pylint: disable=protected-access
     87         raise NotImplementedError(
     88             "Invert is not implemented for non-injective bijector ({})".format(
     89                 a_bijector.name))
     90 
     91     dtype = list(set([b.dtype for b in bijectors]))
     92     if len(dtype) > 2:
     93       raise ValueError("incompatible dtypes: %s" % dtype)
     94     elif len(dtype) == 2:
     95       dtype = dtype[1] if dtype[0] is None else dtype[0]
     96       event_ndims = bijectors[0].event_ndims
     97     elif len(dtype) == 1:
     98       dtype = dtype[0]
     99       event_ndims = bijectors[0].event_ndims
    100     else:
    101       dtype = None
    102       event_ndims = None
    103 
    104     super(Chain, self).__init__(
    105         graph_parents=list(itertools.chain.from_iterable(
    106             b.graph_parents for b in bijectors)),
    107         is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors),
    108         validate_args=validate_args,
    109         dtype=dtype,
    110         event_ndims=event_ndims,
    111         name=name or ("identity" if not bijectors else
    112                       "_of_".join(["chain"] + [b.name for b in bijectors])))
    113 
    114   @property
    115   def bijectors(self):
    116     return self._bijectors
    117 
    118   def _shape_helper(self, func_name, input_shape, reverse):
    119     new_shape = input_shape
    120     for b in reversed(self.bijectors) if reverse else self.bijectors:
    121       func = getattr(b, func_name, None)
    122       if func is None:
    123         raise ValueError("unable to call %s on bijector %s (%s)" %
    124                          (func_name, b.name, func))
    125       new_shape = func(new_shape)
    126     return new_shape
    127 
    128   def _forward_event_shape(self, input_shape):
    129     return self._shape_helper("forward_event_shape", input_shape,
    130                               reverse=True)
    131 
    132   def _forward_event_shape_tensor(self, input_shape):
    133     return self._shape_helper(
    134         "forward_event_shape_tensor", input_shape, reverse=True)
    135 
    136   def _inverse_event_shape(self, output_shape):
    137     return self._shape_helper("inverse_event_shape", output_shape,
    138                               reverse=False)
    139 
    140   def _inverse_event_shape_tensor(self, output_shape):
    141     return self._shape_helper("inverse_event_shape_tensor", output_shape,
    142                               reverse=False)
    143 
    144   def _inverse(self, y, **kwargs):
    145     for b in self.bijectors:
    146       y = b.inverse(y, **kwargs.get(b.name, {}))
    147     return y
    148 
    149   def _inverse_log_det_jacobian(self, y, **kwargs):
    150     ildj = constant_op.constant(0., dtype=y.dtype,
    151                                 name="inverse_log_det_jacobian")
    152     for b in self.bijectors:
    153       ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {}))
    154       y = b.inverse(y, **kwargs.get(b.name, {}))
    155     return ildj
    156 
    157   def _forward(self, x, **kwargs):
    158     for b in reversed(self.bijectors):
    159       x = b.forward(x, **kwargs.get(b.name, {}))
    160     return x
    161 
    162   def _forward_log_det_jacobian(self, x, **kwargs):
    163     fldj = constant_op.constant(0., dtype=x.dtype,
    164                                 name="forward_log_det_jacobian")
    165     for b in reversed(self.bijectors):
    166       fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {}))
    167       x = b.forward(x, **kwargs.get(b.name, {}))
    168     return fldj
    169