1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Chain bijector.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import itertools 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.ops.distributions import bijector 25 26 27 __all__ = [ 28 "Chain", 29 ] 30 31 32 class Chain(bijector.Bijector): 33 """Bijector which applies a sequence of bijectors. 34 35 Example Use: 36 37 ```python 38 chain = Chain([Exp(), Softplus()], name="one_plus_exp") 39 ``` 40 41 Results in: 42 43 * Forward: 44 45 ```python 46 exp = Exp() 47 softplus = Softplus() 48 Chain([exp, softplus]).forward(x) 49 = exp.forward(softplus.forward(x)) 50 = tf.exp(tf.log(1. + tf.exp(x))) 51 = 1. + tf.exp(x) 52 ``` 53 54 * Inverse: 55 56 ```python 57 exp = Exp() 58 softplus = Softplus() 59 Chain([exp, softplus]).inverse(y) 60 = softplus.inverse(exp.inverse(y)) 61 = tf.log(tf.exp(tf.log(y)) - 1.) 62 = tf.log(y - 1.) 63 ``` 64 65 """ 66 67 def __init__(self, bijectors=None, validate_args=False, name=None): 68 """Instantiates `Chain` bijector. 69 70 Args: 71 bijectors: Python `list` of bijector instances. An empty list makes this 72 bijector equivalent to the `Identity` bijector. 73 validate_args: Python `bool` indicating whether arguments should be 74 checked for correctness. 75 name: Python `str`, name given to ops managed by this object. Default: 76 E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. 77 78 Raises: 79 ValueError: if bijectors have different dtypes. 80 """ 81 if bijectors is None: 82 bijectors = () 83 self._bijectors = bijectors 84 85 for a_bijector in bijectors: 86 if not a_bijector._is_injective: # pylint: disable=protected-access 87 raise NotImplementedError( 88 "Invert is not implemented for non-injective bijector ({})".format( 89 a_bijector.name)) 90 91 dtype = list(set([b.dtype for b in bijectors])) 92 if len(dtype) > 2: 93 raise ValueError("incompatible dtypes: %s" % dtype) 94 elif len(dtype) == 2: 95 dtype = dtype[1] if dtype[0] is None else dtype[0] 96 event_ndims = bijectors[0].event_ndims 97 elif len(dtype) == 1: 98 dtype = dtype[0] 99 event_ndims = bijectors[0].event_ndims 100 else: 101 dtype = None 102 event_ndims = None 103 104 super(Chain, self).__init__( 105 graph_parents=list(itertools.chain.from_iterable( 106 b.graph_parents for b in bijectors)), 107 is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), 108 validate_args=validate_args, 109 dtype=dtype, 110 event_ndims=event_ndims, 111 name=name or ("identity" if not bijectors else 112 "_of_".join(["chain"] + [b.name for b in bijectors]))) 113 114 @property 115 def bijectors(self): 116 return self._bijectors 117 118 def _shape_helper(self, func_name, input_shape, reverse): 119 new_shape = input_shape 120 for b in reversed(self.bijectors) if reverse else self.bijectors: 121 func = getattr(b, func_name, None) 122 if func is None: 123 raise ValueError("unable to call %s on bijector %s (%s)" % 124 (func_name, b.name, func)) 125 new_shape = func(new_shape) 126 return new_shape 127 128 def _forward_event_shape(self, input_shape): 129 return self._shape_helper("forward_event_shape", input_shape, 130 reverse=True) 131 132 def _forward_event_shape_tensor(self, input_shape): 133 return self._shape_helper( 134 "forward_event_shape_tensor", input_shape, reverse=True) 135 136 def _inverse_event_shape(self, output_shape): 137 return self._shape_helper("inverse_event_shape", output_shape, 138 reverse=False) 139 140 def _inverse_event_shape_tensor(self, output_shape): 141 return self._shape_helper("inverse_event_shape_tensor", output_shape, 142 reverse=False) 143 144 def _inverse(self, y, **kwargs): 145 for b in self.bijectors: 146 y = b.inverse(y, **kwargs.get(b.name, {})) 147 return y 148 149 def _inverse_log_det_jacobian(self, y, **kwargs): 150 ildj = constant_op.constant(0., dtype=y.dtype, 151 name="inverse_log_det_jacobian") 152 for b in self.bijectors: 153 ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) 154 y = b.inverse(y, **kwargs.get(b.name, {})) 155 return ildj 156 157 def _forward(self, x, **kwargs): 158 for b in reversed(self.bijectors): 159 x = b.forward(x, **kwargs.get(b.name, {})) 160 return x 161 162 def _forward_log_det_jacobian(self, x, **kwargs): 163 fldj = constant_op.constant(0., dtype=x.dtype, 164 name="forward_log_det_jacobian") 165 for b in reversed(self.bijectors): 166 fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) 167 x = b.forward(x, **kwargs.get(b.name, {})) 168 return fldj 169