# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""PowerTransform bijector."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector


__all__ = [
    "PowerTransform",
]


class PowerTransform(bijector.Bijector):
  """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`.

  The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps
  inputs from `[0, inf]` to `[-1/c, inf]`; that mapping corresponds to the
  `inverse` of this bijector.

  When `c=0` this bijector is equivalent to the `Exp` bijector: the forward
  transform is `exp(X)` and the inverse is `log(Y)`.
  """

  def __init__(self,
               power=0.,
               event_ndims=0,
               validate_args=False,
               name="power_transform"):
    """Instantiates the `PowerTransform` bijector.

    Args:
      power: Python `float` scalar indicating the transform power, i.e.,
        `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. Its value
        must be statically known and non-negative.
      event_ndims: Python scalar indicating the number of dimensions associated
        with a particular draw from the distribution.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `power < 0` or is not known statically.
    """
    # These attributes are assigned before the base-class constructor runs
    # because `_name_scope` is used below, ahead of `super().__init__`.
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    with self._name_scope("init", values=[power]):
      power_tensor = ops.convert_to_tensor(power, name="power")
      # `constant_value` yields None when the value cannot be determined
      # statically.
      static_power = tensor_util.constant_value(power_tensor)
    if static_power is None or static_power < 0:
      raise ValueError("`power` must be a non-negative TF constant.")
    self._power = static_power
    super(PowerTransform, self).__init__(
        event_ndims=event_ndims,
        validate_args=validate_args,
        name=name)

  @property
  def power(self):
    """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`."""
    return self._power

  def _forward(self, x):
    """Returns `exp(x)` if `c == 0`, else `exp(log1p(x * c) / c)`."""
    x = self._maybe_assert_valid_x(x)
    c = self.power
    if c == 0.:
      return math_ops.exp(x)
    # If accuracy for large x is a concern, an alternative form is:
    #   (1. + x * c)**(1. / c)  when x >> 1.
    scaled_log = math_ops.log1p(x * c) / c
    return math_ops.exp(scaled_log)

  def _inverse(self, y):
    """Returns `log(y)` if `c == 0`, else `expm1(log(y) * c) / c`."""
    y = self._maybe_assert_valid_y(y)
    c = self.power
    log_y = math_ops.log(y)
    if c == 0.:
      return log_y
    # If accuracy for large y is a concern, an alternative form is:
    #   (y**c - 1.) / c  when y >> 1.
    return math_ops.expm1(log_y * c) / c

  def _inverse_log_det_jacobian(self, y):
    """Returns `(c - 1) * sum(log(y))`, summed over the event dimensions."""
    y = self._maybe_assert_valid_y(y)
    reduce_axes = self._event_dims_tensor(y)
    summed_log_y = math_ops.reduce_sum(math_ops.log(y), axis=reduce_axes)
    return (self.power - 1.) * summed_log_y

  def _forward_log_det_jacobian(self, x):
    """Returns `(1 / c - 1) * sum(log1p(x * c))`, or `sum(x)` if `c == 0`.

    Sums are taken over the event dimensions.
    """
    x = self._maybe_assert_valid_x(x)
    reduce_axes = self._event_dims_tensor(x)
    c = self.power
    if c == 0.:
      return math_ops.reduce_sum(x, axis=reduce_axes)
    scale = 1. / c - 1.
    summed_log1p = math_ops.reduce_sum(
        math_ops.log1p(x * c),
        axis=reduce_axes)
    return scale * summed_log1p

  def _maybe_assert_valid_x(self, x):
    """Returns `x`, gated on `1 + c * x >= 0` when validation applies.

    The assertion is attached only if `validate_args` is set and `c != 0`;
    otherwise `x` is returned unchanged.
    """
    if self.validate_args and self.power != 0.:
      lower_bound = -1. / self.power
      is_valid = check_ops.assert_non_negative(
          1. + self.power * x,
          message="Forward transformation input must be at least {}.".format(
              lower_bound))
      return control_flow_ops.with_dependencies([is_valid], x)
    return x

  def _maybe_assert_valid_y(self, y):
    """Returns `y`, gated on `y > 0` when `validate_args` is set."""
    if self.validate_args:
      is_valid = check_ops.assert_positive(
          y, message="Inverse transformation input must be greater than 0.")
      return control_flow_ops.with_dependencies([is_valid], y)
    return y