# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Real NVP bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
from tensorflow.python.ops.distributions import bijector as bijector_lib


__all__ = [
    "RealNVP",
    "real_nvp_default_template"
]


class RealNVP(bijector_lib.Bijector):
  """RealNVP "affine coupling layer" for vector-valued events.

  Real NVP models a normalizing flow on a `D`-dimensional distribution via a
  single `D-d`-dimensional conditional distribution [1]:

  `y[d:D] = x[d:D] * math_ops.exp(log_scale_fn(x[0:d])) + shift_fn(x[0:d])`
  `y[0:d] = x[0:d]`

  The last `D-d` units are scaled and shifted based on the first `d` units only,
  while the first `d` units are 'masked' and left unchanged. Real NVP's
  `shift_and_log_scale_fn` computes vector-valued quantities. For
  scale-and-shift transforms that do not depend on any masked units, i.e.
  `d=0`, use the `tfb.Affine` bijector with learned parameters instead.

  Masking is currently only supported for base distributions with
  `event_ndims=1`. For more sophisticated masking schemes like checkerboard or
  channel-wise masking [1], use the `tfb.Permute` bijector to re-order desired
  masked units into the first `d` units. For base distributions with
  `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape.

  Recall that the MAF bijector [2] implements a normalizing flow via an
  autoregressive transformation. MAF and IAF have opposite computational
  tradeoffs - MAF can train all units in parallel but must sample units
  sequentially, while IAF must train units sequentially but can sample in
  parallel. In contrast, Real NVP can compute both forward and inverse
  computations in parallel. However, the lack of an autoregressive
  transformation makes it less expressive on a per-bijector basis.

  A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or
  "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each is broadcastable
  with the arguments to `forward` and `inverse`, i.e., such that the
  calculations in `forward`, `inverse` [below] are possible. For convenience,
  `real_nvp_default_template` is offered as a possible `shift_and_log_scale_fn`
  function.

  NICE [3] is a special case of the Real NVP bijector which discards the scale
  transformation, resulting in a constant-time inverse-log-determinant-Jacobian.
  To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should
  return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in
  the `RealNVP` constructor. Calling `real_nvp_default_template` with
  `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`.

  Caching: the scalar input depth `D` of the base distribution is not known at
  construction time. The first call to any of `forward(x)`, `inverse(x)`,
  `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes
  `D`, which is re-used in subsequent calls. This shape must be known prior to
  graph execution (which is the case if using tf.layers).

  #### Example Use

  ```python
  tfd = tf.contrib.distributions
  tfb = tfd.bijectors

  # A common choice for a normalizing flow is to use a Gaussian for the base
  # distribution. (However, any continuous distribution would work.) E.g.,
  nvp = tfd.TransformedDistribution(
      distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.]),
      bijector=tfb.RealNVP(
          num_masked=2,
          shift_and_log_scale_fn=tfb.real_nvp_default_template(
              hidden_layers=[512, 512])))

  x = nvp.sample()
  nvp.log_prob(x)
  nvp.log_prob([[0., 0., 0.]])
  ```

  For more examples, see [4].

  [1]: "Density Estimation using Real NVP."
       Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017.
       https://arxiv.org/abs/1605.08803

  [2]: "Masked Autoregressive Flow for Density Estimation."
       George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017.
       https://arxiv.org/abs/1705.07057

  [3]: "NICE: Non-linear Independent Components Estimation."
       Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015.
       https://arxiv.org/abs/1410.8516

  [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows."
       Eric Jang. Blog post. January 2018.
       http://blog.evjang.com/2018/01/nf2.html
  """

  def __init__(self,
               num_masked,
               shift_and_log_scale_fn,
               is_constant_jacobian=False,
               validate_args=False,
               name=None):
    """Creates the Real NVP or NICE bijector.

    Args:
      num_masked: Python `int` indicating that the first `d` units of the event
        should be masked. Must be in the closed interval `[1, D-1]`, where `D`
        is the event size of the base distribution. The lower bound is checked
        here; the upper bound is checked on first use, once `D` is known.
      shift_and_log_scale_fn: Python `callable` which computes `shift` and
        `log_scale` from the masked units only. It is called as
        `shift_and_log_scale_fn(z[..., :d], D - d)`, where `z` is either the
        forward-domain input (`x`) or the inverse-domain input (`y`); because
        the masked units are identical in both domains, `forward` and
        `inverse` see the same conditioning values. Suggested default
        `real_nvp_default_template(hidden_layers=...)`.
        Typically the function contains `tf.Variables` and is wrapped using
        `tf.make_template`. Returning `None` for either (both) `shift`,
        `log_scale` is equivalent to (but more efficient than) returning zero.
      is_constant_jacobian: Python `bool`. Default: `False`. When `True` the
        implementation assumes `log_scale` does not depend on the forward domain
        (`x`) or inverse domain (`y`) values. (No validation is made;
        `is_constant_jacobian=False` is always safe but possibly computationally
        inefficient.)
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      ValueError: If num_masked < 1.
    """
    name = name or "real_nvp"
    if num_masked <= 0:
      raise ValueError("num_masked must be a positive integer.")
    self._num_masked = num_masked
    # At construction time, we don't know input_depth; it is memoized by
    # `_cache_input_depth` on the first forward/inverse/log-det call.
    self._input_depth = None
    self._shift_and_log_scale_fn = shift_and_log_scale_fn
    super(RealNVP, self).__init__(
        event_ndims=1,
        is_constant_jacobian=is_constant_jacobian,
        validate_args=validate_args,
        name=name)

  def _cache_input_depth(self, x):
    """Memoizes the event size `D` (the rightmost dimension of `x`).

    The first successful call validates `num_masked` against `D` and only then
    stores `D`, so an invalid configuration keeps raising instead of being
    silently accepted on later calls. Later calls check that any statically
    known rightmost dimension still matches the memoized `D`.

    Args:
      x: `Tensor` of rank at least 1 whose rightmost dimension is the event.

    Raises:
      NotImplementedError: if `D` is not statically known on the first call.
      ValueError: if `num_masked >= D`, or if a later input's statically known
        rightmost dimension differs from the memoized `D`.
    """
    depth = x.shape.with_rank_at_least(1)[-1].value
    if self._input_depth is not None:
      if depth is not None and depth != self._input_depth:
        raise ValueError(
            "Rightmost dimension ({}) does not match the event size ({}) "
            "memoized on first use.".format(depth, self._input_depth))
      return
    if depth is None:
      raise NotImplementedError(
          "Rightmost dimension must be known prior to graph execution.")
    if self._num_masked >= depth:
      raise ValueError(
          "Number of masked units must be smaller than the event size.")
    self._input_depth = depth

  def _split_and_condition(self, z):
    """Splits `z` into masked/unmasked units and evaluates the coupling fn.

    Uses `...` indexing so any number of leading batch dimensions (including
    none) is supported; only the rightmost (event) dimension is split.

    Args:
      z: `Tensor` from either the forward or inverse domain.

    Returns:
      z0: the first `num_masked` units (passed through unchanged).
      z1: the remaining `D - num_masked` units (scaled and shifted).
      shift: `shift` returned by `shift_and_log_scale_fn`, or `None`.
      log_scale: `log_scale` returned by `shift_and_log_scale_fn`, or `None`.
    """
    self._cache_input_depth(z)
    z0 = z[..., :self._num_masked]
    z1 = z[..., self._num_masked:]
    shift, log_scale = self._shift_and_log_scale_fn(
        z0, self._input_depth - self._num_masked)
    return z0, z1, shift, log_scale

  def _forward(self, x):
    # Performs scale and shift on the unmasked units only.
    x0, x1, shift, log_scale = self._split_and_condition(x)
    y1 = x1
    if log_scale is not None:
      y1 *= math_ops.exp(log_scale)
    if shift is not None:
      y1 += shift
    return array_ops.concat([x0, y1], axis=-1)

  def _inverse(self, y):
    # Performs un-shift and un-scale, in the reverse order of `_forward`.
    y0, y1, shift, log_scale = self._split_and_condition(y)
    x1 = y1
    if shift is not None:
      x1 -= shift
    if log_scale is not None:
      x1 *= math_ops.exp(-log_scale)
    return array_ops.concat([y0, x1], axis=-1)

  def _inverse_log_det_jacobian(self, y):
    # The Jacobian is triangular with diagonal exp(log_scale) on the unmasked
    # units and 1 on the masked ones, so the log-det is a sum over log_scale.
    _, _, _, log_scale = self._split_and_condition(y)
    if log_scale is None:
      return constant_op.constant(0., dtype=y.dtype, name="ildj")
    return -math_ops.reduce_sum(log_scale, axis=-1)

  def _forward_log_det_jacobian(self, x):
    _, _, _, log_scale = self._split_and_condition(x)
    if log_scale is None:
      return constant_op.constant(0., dtype=x.dtype, name="fldj")
    return math_ops.reduce_sum(log_scale, axis=-1)


def real_nvp_default_template(
    hidden_layers,
    shift_only=False,
    activation=nn_ops.relu,
    name=None,
    *args,
    **kwargs):
  """Build a scale-and-shift function using a multi-layer neural network.

  This will be wrapped in a make_template to ensure the variables are only
  created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d`
  dimensional outputs `loc` ("mu") and `log_scale` ("alpha").

  Arguments:
    hidden_layers: Python `list`-like of non-negative integer, scalars
      indicating the number of units in each hidden layer, e.g. `[512, 512]`.
    shift_only: Python `bool` indicating if only the `shift` term shall be
      computed (i.e. NICE bijector). Default: `False`.
    activation: Activation function (callable). Explicitly setting to `None`
      implies a linear activation.
    name: A name for ops managed by this function, also used as the name of
      the returned template (and hence of its variable scope). Default:
      "real_nvp_default_template".
    *args: `tf.layers.dense` arguments. NOTE: these are forwarded positionally
      ahead of the explicit `inputs=` keyword, so a non-empty `*args` collides
      with `inputs` and raises `TypeError`; pass options via `**kwargs`.
    **kwargs: `tf.layers.dense` keyword arguments.

  Returns:
    A template callable `shift_and_log_scale_fn(x, output_units)` returning
    `(shift, log_scale)`, each with `output_units` units in the last dimension
    (the "mu" and "alpha" of [2], respectively). `log_scale` is `None` when
    `shift_only=True`.
  """
  template_name = name or "real_nvp_default_template"
  with ops.name_scope(name, "real_nvp_default_template"):
    def _fn(x, output_units):
      """Fully connected MLP parameterized via `real_nvp_default_template`."""
      for units in hidden_layers:
        x = layers.dense(
            inputs=x,
            units=units,
            activation=activation,
            *args,
            **kwargs)
      # The final layer emits shift (and log_scale) stacked along the last
      # axis; it is linear so log_scale is unconstrained.
      x = layers.dense(
          inputs=x,
          units=(1 if shift_only else 2) * output_units,
          activation=None,
          *args,
          **kwargs)
      if shift_only:
        return x, None
      shift, log_scale = array_ops.split(x, 2, axis=-1)
      return shift, log_scale
    return template_ops.make_template(template_name, _fn)