Home | History | Annotate | Download | only in bijectors
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Permutation bijectors."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import tensor_util
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import check_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import nn_ops
     31 from tensorflow.python.ops.distributions import bijector as bijector_lib
     32 
     33 
     34 __all__ = [
     35     "Permute",
     36 ]
     37 
     38 
     39 class Permute(bijector_lib.Bijector):
     40   """Permutes the rightmost dimension of a `Tensor`.
     41 
     42   ```python
     43   tfd = tf.contrib.distributions
     44 
     45   reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
     46 
     47   reverse.forward([-1., 0., 1.])
     48   # ==> [1., 0., -1]
     49 
     50   reverse.inverse([1., 0., -1])
     51   # ==> [-1., 0., 1.]
     52 
     53   reverse.forward_log_det_jacobian(any_value)
     54   # ==> 0.
     55 
     56   reverse.inverse_log_det_jacobian(any_value)
     57   # ==> 0.
     58   ```
     59 
     60   Warning: `tf.estimator` may repeatedly build the graph thus
     61   `Permute(np.random.permutation(event_size)).astype("int32"))` is not a
     62   reliable parameterization (nor would it be even if using `tf.constant`). A
     63   safe alternative is to use `tf.get_variable` to achieve "init once" behavior,
     64   i.e.,
     65 
     66   ```python
     67   def init_once(x, name):
     68     return tf.get_variable(name, initializer=x, trainable=False)
     69 
     70   Permute(permutation=init_once(
     71       np.random.permutation(event_size).astype("int32"),
     72       name="permutation"))
     73   ```
     74 
     75   """
     76 
     77   def __init__(self, permutation, validate_args=False, name=None):
     78     """Creates the `Permute` bijector.
     79 
     80     Args:
     81       permutation: An `int`-like vector-shaped `Tensor` representing the
     82         permutation to apply to the rightmost dimension of the transformed
     83         `Tensor`.
     84       validate_args: Python `bool` indicating whether arguments should be
     85         checked for correctness.
     86       name: Python `str`, name given to ops managed by this object.
     87 
     88     Raises:
     89       TypeError: if `not permutation.dtype.is_integer`.
     90       ValueError: if `permutation` does not contain exactly one of each of
     91         `{0, 1, ..., d}`.
     92     """
     93     with ops.name_scope(name, "permute", values=[permutation]):
     94       permutation = ops.convert_to_tensor(
     95           permutation,
     96           name="permutation")
     97       if not permutation.dtype.is_integer:
     98         raise TypeError("permutation.dtype ({}) should be `int`-like.".format(
     99             permutation.dtype.name))
    100       p = tensor_util.constant_value(permutation)
    101       if p is not None:
    102         if set(p) != set(np.arange(p.size)):
    103           raise ValueError("Permutation over `d` must contain exactly one of "
    104                            "each of `{0, 1, ..., d}`.")
    105       elif validate_args:
    106         p, _ = nn_ops.top_k(-permutation,
    107                             k=array_ops.shape(permutation)[-1],
    108                             sorted=True)
    109         permutation = control_flow_ops.with_dependencies([
    110             check_ops.assert_equal(
    111                 -p, math_ops.range(array_ops.size(p)),
    112                 message=("Permutation over `d` must contain exactly one of "
    113                          "each of `{0, 1, ..., d}`.")),
    114         ], permutation)
    115       self._permutation = permutation
    116       super(Permute, self).__init__(
    117           is_constant_jacobian=True,
    118           validate_args=validate_args,
    119           name=name or "permute")
    120 
    121   @property
    122   def permutation(self):
    123     return self._permutation
    124 
    125   def _forward(self, x):
    126     return array_ops.gather(x, self.permutation, axis=-1)
    127 
    128   def _inverse(self, y):
    129     return array_ops.gather(
    130         y,
    131         array_ops.invert_permutation(self.permutation),
    132         axis=-1)
    133 
    134   def _inverse_log_det_jacobian(self, y):
    135     return constant_op.constant(0., dtype=y.dtype)
    136 
    137   def _forward_log_det_jacobian(self, x):
    138     return constant_op.constant(0., dtype=x.dtype)
    139