# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Permutation bijectors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector as bijector_lib


__all__ = [
    "Permute",
]


class Permute(bijector_lib.Bijector):
  """Permutes the rightmost dimension of a `Tensor`.

  ```python
  tfd = tf.contrib.distributions

  reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])

  reverse.forward([-1., 0., 1.])
  # ==> [1., 0., -1]

  reverse.inverse([1., 0., -1])
  # ==> [-1., 0., 1.]

  reverse.forward_log_det_jacobian(any_value)
  # ==> 0.

  reverse.inverse_log_det_jacobian(any_value)
  # ==> 0.
  ```

  Warning: `tf.estimator` may repeatedly build the graph, so
  `Permute(np.random.permutation(event_size).astype("int32"))` is not a
  reliable parameterization (nor would it be even if using `tf.constant`),
  since each graph build would draw a fresh random permutation. A safe
  alternative is to use `tf.get_variable` to achieve "init once" behavior,
  i.e.,

  ```python
  def init_once(x, name):
    return tf.get_variable(name, initializer=x, trainable=False)

  Permute(permutation=init_once(
      np.random.permutation(event_size).astype("int32"),
      name="permutation"))
  ```

  """

  def __init__(self, permutation, validate_args=False, name=None):
    """Creates the `Permute` bijector.

    Args:
      permutation: An `int`-like vector-shaped `Tensor` representing the
        permutation to apply to the rightmost dimension of the transformed
        `Tensor`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness. Only consulted when the value of
        `permutation` is not known at graph-construction time; in that case
        an assertion op is attached to `permutation` and fails when run.
      name: Python `str`, name given to ops managed by this object.

    Raises:
      TypeError: if `not permutation.dtype.is_integer`.
      ValueError: if the value of `permutation` is statically known and does
        not contain exactly one of each of `{0, 1, ..., d - 1}`, where `d` is
        the number of elements of `permutation`.
    """
    with ops.name_scope(name, "permute", values=[permutation]):
      permutation = ops.convert_to_tensor(
          permutation,
          name="permutation")
      if not permutation.dtype.is_integer:
        raise TypeError("permutation.dtype ({}) should be `int`-like.".format(
            permutation.dtype.name))
      p = tensor_util.constant_value(permutation)
      if p is not None:
        # The value is available now, so check it eagerly in Python
        # (independently of `validate_args`): the distinct values must be
        # exactly 0, 1, ..., p.size - 1.
        if set(p) != set(np.arange(p.size)):
          raise ValueError("Permutation over `d` must contain exactly one of "
                           "each of `{0, 1, ..., d}`.")
      elif validate_args:
        # `top_k` of the negated values, sorted, gives the values in
        # ascending order once negated back; a valid permutation must then
        # equal [0, 1, ..., d - 1].
        p, _ = nn_ops.top_k(-permutation,
                            k=array_ops.shape(permutation)[-1],
                            sorted=True)
        # `math_ops.range` produces int32; cast to the permutation's dtype so
        # the equality check also builds for e.g. int64 permutations.
        expected = math_ops.cast(
            math_ops.range(array_ops.size(p)), dtype=p.dtype)
        permutation = control_flow_ops.with_dependencies([
            check_ops.assert_equal(
                -p, expected,
                message=("Permutation over `d` must contain exactly one of "
                         "each of `{0, 1, ..., d}`.")),
        ], permutation)
      self._permutation = permutation
      super(Permute, self).__init__(
          is_constant_jacobian=True,
          validate_args=validate_args,
          name=name or "permute")

  @property
  def permutation(self):
    """Integer `Tensor` of indices gathered along the rightmost dimension."""
    return self._permutation

  def _forward(self, x):
    # y[..., i] = x[..., permutation[i]].
    return array_ops.gather(x, self.permutation, axis=-1)

  def _inverse(self, y):
    # Gathering with the inverse permutation undoes `_forward`.
    return array_ops.gather(
        y,
        array_ops.invert_permutation(self.permutation),
        axis=-1)

  def _inverse_log_det_jacobian(self, y):
    # A permutation only reorders coordinates (its Jacobian is a permutation
    # matrix with determinant +/-1), so log|det J| is identically zero.
    return constant_op.constant(0., dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    # See `_inverse_log_det_jacobian`: the volume change is always zero.
    return constant_op.constant(0., dtype=x.dtype)