Home | History | Annotate | Download | only in bijectors
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Real NVP bijector."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import constant_op
     22 from tensorflow.python.framework import ops
     23 from tensorflow.python.layers import core as layers
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops import nn_ops
     27 from tensorflow.python.ops import template as template_ops
     28 from tensorflow.python.ops.distributions import bijector as bijector_lib
     29 
     30 
     31 __all__ = [
     32     "RealNVP",
     33     "real_nvp_default_template"
     34 ]
     35 
     36 
     37 class RealNVP(bijector_lib.Bijector):
     38   """RealNVP "affine coupling layer" for vector-valued events.
     39 
     40   Real NVP models a normalizing flow on a `D`-dimensional distribution via a
     41   single `D-d`-dimensional conditional distribution [1]:
     42 
     43   `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])`
     44   `y[0:d] = x[0:d]`
     45 
     46   The last `D-d` units are scaled and shifted based on the first `d` units only,
     47   while the first `d` units are 'masked' and left unchanged. Real NVP's
     48   `shift_and_log_scale_fn` computes vector-valued quantities. For
     49   scale-and-shift transforms that do not depend on any masked units, i.e.
     50   `d=0`, use the `tfb.Affine` bijector with learned parameters instead.
     51 
     52   Masking is currently only supported for base distributions with
     53   `event_ndims=1`. For more sophisticated masking schemes like checkerboard or
     54   channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired
     55   masked units into the first `d` units. For base distributions with
     56   `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape.
     57 
     58   Recall that the MAF bijector [2] implements a normalizing flow via an
     59   autoregressive transformation. MAF and IAF have opposite computational
     60   tradeoffs - MAF can train all units in parallel but must sample units
     61   sequentially, while IAF must train units sequentially but can sample in
     62   parallel. In contrast, Real NVP can compute both forward and inverse
     63   computations in parallel. However, the lack of an autoregressive
     64   transformations makes it less expressive on a per-bijector basis.
     65 
     66   A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or
     67   "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable
     68   with the arguments to `forward` and `inverse`, i.e., such that the
     69   calculations in `forward`, `inverse` [below] are possible. For convenience,
     70   `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn`
     71   function.
     72 
     73   NICE [3] is a special case of the Real NVP bijector which discards the scale
     74   transformation, resulting in a constant-time inverse-log-determinant-Jacobian.
     75   To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should
     76   return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in
     77   the `RealNVP` constructor. Calling `real_nvp_default_template` with
     78   `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`.
     79 
     80   Caching: the scalar input depth `D` of the base distribution is not known at
     81   construction time. The first call to any of `forward(x)`, `inverse(x)`,
     82   `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes
     83   `D`, which is re-used in subsequent calls. This shape must be known prior to
     84   graph execution (which is the case if using tf.layers).
     85 
     86   #### Example Use
     87 
     88   ```python
     89   tfd = tf.contrib.distributions
     90   tfb = tfd.bijectors
     91 
     92   # A common choice for a normalizing flow is to use a Gaussian for the base
     93   # distribution. (However, any continuous distribution would work.) E.g.,
     94   nvp = tfd.TransformedDistribution(
     95       distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.])),
     96       bijector=tfb.RealNVP(
     97           num_masked=2,
     98           shift_and_log_scale_fn=tfb.real_nvp_default_template(
     99               hidden_layers=[512, 512])))
    100 
    101   x = nvp.sample()
    102   nvp.log_prob(x)
    103   nvp.log_prob(0.)
    104   ```
    105 
    106   For more examples, see [4].
    107 
    108   [1]: "Density Estimation using Real NVP."
    109        Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017.
    110        https://arxiv.org/abs/1605.08803
    111 
    112   [2]: "Masked Autoregressive Flow for Density Estimation."
    113        George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
    114        https://arxiv.org/abs/1705.07057
    115 
    116   [3]: "NICE: Non-linear Independent Components Estimation."
    117        Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015.
    118        https://arxiv.org/abs/1410.8516
    119 
    120   [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows."
    121        Eric Jang. Blog post. January 2018.
    122        http://blog.evjang.com/2018/01/nf2.html
    123   """
    124 
    125   def __init__(self,
    126                num_masked,
    127                shift_and_log_scale_fn,
    128                is_constant_jacobian=False,
    129                validate_args=False,
    130                name=None):
    131     """Creates the Real NVP or NICE bijector.
    132 
    133     Args:
    134       num_masked: Python `int` indicating that the first `d` units of the event
    135         should be masked. Must be in the closed interval `[1, D-1]`, where `D`
    136         is the event size of the base distribution.
    137       shift_and_log_scale_fn: Python `callable` which computes `shift` and
    138         `log_scale` from both the forward domain (`x`) and the inverse domain
    139         (`y`). Calculation must respect the "autoregressive property" (see class
    140         docstring). Suggested default
    141         `masked_autoregressive_default_template(hidden_layers=...)`.
    142         Typically the function contains `tf.Variables` and is wrapped using
    143         `tf.make_template`. Returning `None` for either (both) `shift`,
    144         `log_scale` is equivalent to (but more efficient than) returning zero.
    145       is_constant_jacobian: Python `bool`. Default: `False`. When `True` the
    146         implementation assumes `log_scale` does not depend on the forward domain
    147         (`x`) or inverse domain (`y`) values. (No validation is made;
    148         `is_constant_jacobian=False` is always safe but possibly computationally
    149         inefficient.)
    150       validate_args: Python `bool` indicating whether arguments should be
    151         checked for correctness.
    152       name: Python `str`, name given to ops managed by this object.
    153 
    154     Raises:
    155       ValueError: If num_masked < 1.
    156     """
    157     name = name or "real_nvp"
    158     if num_masked <= 0:
    159       raise ValueError("num_masked must be a positive integer.")
    160     self._num_masked = num_masked
    161     # At construction time, we don't know input_depth.
    162     self._input_depth = None
    163     self._shift_and_log_scale_fn = shift_and_log_scale_fn
    164     super(RealNVP, self).__init__(
    165         event_ndims=1,
    166         is_constant_jacobian=is_constant_jacobian,
    167         validate_args=validate_args,
    168         name=name)
    169 
    170   def _cache_input_depth(self, x):
    171     if self._input_depth is None:
    172       self._input_depth = x.shape.with_rank_at_least(1)[-1].value
    173       if self._input_depth is None:
    174         raise NotImplementedError(
    175             "Rightmost dimension must be known prior to graph execution.")
    176       if self._num_masked >= self._input_depth:
    177         raise ValueError(
    178             "Number of masked units must be smaller than the event size.")
    179 
    180   def _forward(self, x):
    181     self._cache_input_depth(x)
    182     # Performs scale and shift.
    183     x0, x1 = x[:, :self._num_masked], x[:, self._num_masked:]
    184     shift, log_scale = self._shift_and_log_scale_fn(
    185         x0, self._input_depth - self._num_masked)
    186     y1 = x1
    187     if log_scale is not None:
    188       y1 *= math_ops.exp(log_scale)
    189     if shift is not None:
    190       y1 += shift
    191     y = array_ops.concat([x0, y1], axis=-1)
    192     return y
    193 
    194   def _inverse(self, y):
    195     self._cache_input_depth(y)
    196     # Performs un-shift and un-scale.
    197     y0, y1 = y[:, :self._num_masked], y[:, self._num_masked:]
    198     shift, log_scale = self._shift_and_log_scale_fn(
    199         y0, self._input_depth - self._num_masked)
    200     x1 = y1
    201     if shift is not None:
    202       x1 -= shift
    203     if log_scale is not None:
    204       x1 *= math_ops.exp(-log_scale)
    205     x = array_ops.concat([y0, x1], axis=-1)
    206     return x
    207 
    208   def _inverse_log_det_jacobian(self, y):
    209     self._cache_input_depth(y)
    210     y0 = y[:, :self._num_masked]
    211     _, log_scale = self._shift_and_log_scale_fn(
    212         y0, self._input_depth - self._num_masked)
    213     if log_scale is None:
    214       return constant_op.constant(0., dtype=y.dtype, name="ildj")
    215     return -math_ops.reduce_sum(log_scale, axis=-1)
    216 
    217   def _forward_log_det_jacobian(self, x):
    218     self._cache_input_depth(x)
    219     x0 = x[:, :self._num_masked]
    220     _, log_scale = self._shift_and_log_scale_fn(
    221         x0, self._input_depth - self._num_masked)
    222     if log_scale is None:
    223       return constant_op.constant(0., dtype=x.dtype, name="ildj")
    224     return math_ops.reduce_sum(log_scale, axis=-1)
    225 
    226 
    227 def real_nvp_default_template(
    228     hidden_layers,
    229     shift_only=False,
    230     activation=nn_ops.relu,
    231     name=None,
    232     *args,
    233     **kwargs):
    234   """Build a scale-and-shift function using a multi-layer neural network.
    235 
    236   This will be wrapped in a make_template to ensure the variables are only
    237   created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d`
    238   dimensional outputs `loc` ("mu") and `log_scale` ("alpha").
    239 
    240   Arguments:
    241     hidden_layers: Python `list`-like of non-negative integer, scalars
    242       indicating the number of units in each hidden layer. Default: `[512, 512].
    243     shift_only: Python `bool` indicating if only the `shift` term shall be
    244       computed (i.e. NICE bijector). Default: `False`.
    245     activation: Activation function (callable). Explicitly setting to `None`
    246       implies a linear activation.
    247     name: A name for ops managed by this function. Default:
    248       "real_nvp_default_template".
    249     *args: `tf.layers.dense` arguments.
    250     **kwargs: `tf.layers.dense` keyword arguments.
    251 
    252   Returns:
    253     shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]).
    254     log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]).
    255 
    256   Raises:
    257     NotImplementedError: if rightmost dimension of `inputs` is unknown prior to
    258       graph execution.
    259   """
    260 
    261   with ops.name_scope(name, "real_nvp_default_template"):
    262     def _fn(x, output_units):
    263       """Fully connected MLP parameterized via `real_nvp_template`."""
    264       for units in hidden_layers:
    265         x = layers.dense(
    266             inputs=x,
    267             units=units,
    268             activation=activation,
    269             *args,
    270             **kwargs)
    271       x = layers.dense(
    272           inputs=x,
    273           units=(1 if shift_only else 2) * output_units,
    274           activation=None,
    275           *args,
    276           **kwargs)
    277       if shift_only:
    278         return x, None
    279       shift, log_scale = array_ops.split(x, 2, axis=-1)
    280       return shift, log_scale
    281     return template_ops.make_template(
    282         "real_nvp_default_template", _fn)
    283