Home | History | Annotate | Download | only in bijectors
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Bijector."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
     24 from tensorflow.contrib.distributions.python.ops.bijectors.inline import Inline
     25 from tensorflow.python.framework import tensor_shape
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class InlineBijectorTest(test.TestCase):
     32   """Tests correctness of the inline constructed bijector."""
     33 
     34   def testBijector(self):
     35     with self.test_session():
     36       exp = Exp(event_ndims=1)
     37       inline = Inline(
     38           forward_fn=math_ops.exp,
     39           inverse_fn=math_ops.log,
     40           inverse_log_det_jacobian_fn=(
     41               lambda y: -math_ops.reduce_sum(  # pylint: disable=g-long-lambda
     42                   math_ops.log(y), reduction_indices=-1)),
     43           forward_log_det_jacobian_fn=(
     44               lambda x: math_ops.reduce_sum(x, reduction_indices=-1)),
     45           name="exp")
     46 
     47       self.assertEqual(exp.name, inline.name)
     48       x = [[[1., 2.], [3., 4.], [5., 6.]]]
     49       y = np.exp(x)
     50       self.assertAllClose(y, inline.forward(x).eval())
     51       self.assertAllClose(x, inline.inverse(y).eval())
     52       self.assertAllClose(
     53           -np.sum(np.log(y), axis=-1),
     54           inline.inverse_log_det_jacobian(y).eval())
     55       self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(),
     56                           inline.forward_log_det_jacobian(x).eval())
     57 
     58   def testShapeGetters(self):
     59     with self.test_session():
     60       bijector = Inline(
     61           forward_event_shape_tensor_fn=lambda x: array_ops.concat((x, [1]), 0),
     62           forward_event_shape_fn=lambda x: x.as_list() + [1],
     63           inverse_event_shape_tensor_fn=lambda x: x[:-1],
     64           inverse_event_shape_fn=lambda x: x[:-1],
     65           name="shape_only")
     66       x = tensor_shape.TensorShape([1, 2, 3])
     67       y = tensor_shape.TensorShape([1, 2, 3, 1])
     68       self.assertAllEqual(y, bijector.forward_event_shape(x))
     69       self.assertAllEqual(
     70           y.as_list(),
     71           bijector.forward_event_shape_tensor(x.as_list()).eval())
     72       self.assertAllEqual(x, bijector.inverse_event_shape(y))
     73       self.assertAllEqual(
     74           x.as_list(),
     75           bijector.inverse_event_shape_tensor(y.as_list()).eval())
     76 
     77 
     78 if __name__ == "__main__":
     79   test.main()
     80