1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Bijector.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp 24 from tensorflow.contrib.distributions.python.ops.bijectors.inline import Inline 25 from tensorflow.python.framework import tensor_shape 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.platform import test 29 30 31 class InlineBijectorTest(test.TestCase): 32 """Tests correctness of the inline constructed bijector.""" 33 34 def testBijector(self): 35 with self.test_session(): 36 exp = Exp(event_ndims=1) 37 inline = Inline( 38 forward_fn=math_ops.exp, 39 inverse_fn=math_ops.log, 40 inverse_log_det_jacobian_fn=( 41 lambda y: -math_ops.reduce_sum( # pylint: disable=g-long-lambda 42 math_ops.log(y), reduction_indices=-1)), 43 forward_log_det_jacobian_fn=( 44 lambda x: math_ops.reduce_sum(x, reduction_indices=-1)), 45 name="exp") 46 47 self.assertEqual(exp.name, inline.name) 48 x = [[[1., 2.], [3., 4.], [5., 6.]]] 49 y = np.exp(x) 50 self.assertAllClose(y, inline.forward(x).eval()) 51 self.assertAllClose(x, inline.inverse(y).eval()) 52 self.assertAllClose( 53 -np.sum(np.log(y), axis=-1), 54 inline.inverse_log_det_jacobian(y).eval()) 55 self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(), 56 inline.forward_log_det_jacobian(x).eval()) 57 58 def testShapeGetters(self): 59 with self.test_session(): 60 bijector = Inline( 61 forward_event_shape_tensor_fn=lambda x: array_ops.concat((x, [1]), 0), 62 forward_event_shape_fn=lambda x: x.as_list() + [1], 63 inverse_event_shape_tensor_fn=lambda x: x[:-1], 64 inverse_event_shape_fn=lambda x: x[:-1], 65 name="shape_only") 66 x = tensor_shape.TensorShape([1, 2, 3]) 67 y = tensor_shape.TensorShape([1, 2, 3, 1]) 68 self.assertAllEqual(y, bijector.forward_event_shape(x)) 69 self.assertAllEqual( 70 y.as_list(), 71 bijector.forward_event_shape_tensor(x.as_list()).eval()) 72 self.assertAllEqual(x, bijector.inverse_event_shape(y)) 73 self.assertAllEqual( 74 x.as_list(), 75 bijector.inverse_event_shape_tensor(y.as_list()).eval()) 76 77 78 if __name__ == "__main__": 79 test.main() 80