# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """PowerTransform bijector."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import ops
     22 from tensorflow.python.framework import tensor_util
     23 from tensorflow.python.ops import check_ops
     24 from tensorflow.python.ops import control_flow_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops.distributions import bijector
     27 
     28 
     29 __all__ = [
     30     "PowerTransform",
     31 ]


class PowerTransform(bijector.Bijector):
  """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`.

  The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps
  inputs from `[0, inf)` to `[-1/c, inf)`; that classical transform is the
  `inverse` of this bijector.

  This bijector is equivalent to the `Exp` bijector when `c=0`.
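
  #### Examples

  A minimal usage sketch; `power=0.5` below is an arbitrary illustrative value
  and `x` is assumed to be a floating-point `Tensor` with all entries `>= -2.`:

  ```python
  power_transform = PowerTransform(power=0.5)
  y = power_transform.forward(x)   # (1. + 0.5 * x)**2.
  x = power_transform.inverse(y)   # (y**0.5 - 1.) / 0.5
  ```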
     42   """

  def __init__(self,
               power=0.,
               event_ndims=0,
               validate_args=False,
               name="power_transform"):
    """Instantiates the `PowerTransform` bijector.

    Args:
      power: Python `float` scalar indicating the transform power, i.e.,
        `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`.
      event_ndims: Python `int` scalar indicating the number of dimensions
        associated with a particular draw from the distribution.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: if `power < 0` or is not known statically.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    with self._name_scope("init", values=[power]):
      power = tensor_util.constant_value(
          ops.convert_to_tensor(power, name="power"))
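    # `power` must be statically known: `_forward`, `_inverse`, and the
    # Jacobian methods branch in Python on `self.power == 0.`, which is not
    # possible for a symbolic (graph-only) value.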
    if power is None or power < 0:
      raise ValueError("`power` must be a non-negative TF constant.")
    self._power = power
    super(PowerTransform, self).__init__(
        event_ndims=event_ndims,
        validate_args=validate_args,
        name=name)

  @property
  def power(self):
    """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`."""
    return self._power

  def _forward(self, x):
    x = self._maybe_assert_valid_x(x)
    if self.power == 0.:
      return math_ops.exp(x)
    # If large x accuracy is an issue, consider using:
    # (1. + x * self.power)**(1. / self.power) when x >> 1.
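    # exp(log1p(x * c) / c) == (1. + x * c)**(1. / c); the log1p form is the
    # more accurate of the two when x * c is near zero.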
    return math_ops.exp(math_ops.log1p(x * self.power) / self.power)

  def _inverse(self, y):
    y = self._maybe_assert_valid_y(y)
    if self.power == 0.:
      return math_ops.log(y)
    # If large y accuracy is an issue, consider using:
    # (y**self.power - 1.) / self.power when y >> 1.
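    # expm1(log(y) * c) / c == (y**c - 1.) / c; the expm1 form is the more
    # accurate of the two when y**c is close to 1.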
    return math_ops.expm1(math_ops.log(y) * self.power) / self.power

  def _inverse_log_det_jacobian(self, y):
    y = self._maybe_assert_valid_y(y)
    event_dims = self._event_dims_tensor(y)
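    # The inverse is X = (Y**c - 1.) / c, so dX/dY = Y**(c - 1.) and
    # log|dX/dY| = (c - 1.) * log(Y); this also covers c = 0, where
    # X = log(Y) and log|dX/dY| = -log(Y).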
    return (self.power - 1.) * math_ops.reduce_sum(
        math_ops.log(y), axis=event_dims)

  def _forward_log_det_jacobian(self, x):
    x = self._maybe_assert_valid_x(x)
    event_dims = self._event_dims_tensor(x)
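    # The forward is Y = (1. + X * c)**(1. / c), so dY/dX =
    # (1. + X * c)**(1. / c - 1.) and log|dY/dX| = (1. / c - 1.) * log1p(X * c).
    # As c -> 0 this tends to X, matching the `Exp` branch below.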
    if self.power == 0.:
      return math_ops.reduce_sum(x, axis=event_dims)
    return (1. / self.power - 1.) * math_ops.reduce_sum(
        math_ops.log1p(x * self.power),
        axis=event_dims)

  def _maybe_assert_valid_x(self, x):
    if not self.validate_args or self.power == 0.:
      return x
    is_valid = check_ops.assert_non_negative(
        1. + self.power * x,
        message="Forward transformation input must be at least {}.".format(
            -1. / self.power))
    return control_flow_ops.with_dependencies([is_valid], x)

  def _maybe_assert_valid_y(self, y):
    if not self.validate_args:
      return y
    is_valid = check_ops.assert_positive(
        y, message="Inverse transformation input must be greater than 0.")
    return control_flow_ops.with_dependencies([is_valid], y)