# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ConditionalBijector Tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import ConditionalBijector
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test


class _TestBijector(ConditionalBijector):
  """Stub bijector whose private hooks report how they were called.

  Each hook ignores its first positional argument and raises a `ValueError`
  whose args are the hook's public name followed by the `arg1` and `arg2`
  keyword values it received. The test below inspects that error to check
  which hook ran and with which conditioning values.
  """

  def __init__(self):
    super(_TestBijector, self).__init__(
        event_ndims=0,
        graph_parents=[],
        is_constant_jacobian=True,
        validate_args=False,
        dtype=dtypes.float32,
        name="test_bijector")

  # The hooks never compute anything: raising is the only way they
  # communicate the received keyword arguments back to the test.

  def _forward(self, _, arg1, arg2):
    """Raise `ValueError("forward", arg1, arg2)`."""
    raise ValueError("forward", arg1, arg2)

  def _forward_log_det_jacobian(self, _, arg1, arg2):
    """Raise `ValueError("forward_log_det_jacobian", arg1, arg2)`."""
    raise ValueError("forward_log_det_jacobian", arg1, arg2)

  def _inverse(self, _, arg1, arg2):
    """Raise `ValueError("inverse", arg1, arg2)`."""
    raise ValueError("inverse", arg1, arg2)

  def _inverse_log_det_jacobian(self, _, arg1, arg2):
    """Raise `ValueError("inverse_log_det_jacobian", arg1, arg2)`."""
    raise ValueError("inverse_log_det_jacobian", arg1, arg2)


class ConditionalBijectorTest(test.TestCase):
  """Checks that conditioning kwargs reach the bijector's private hooks."""

  def testConditionalBijector(self):
    """Each public method must surface the matching hook's `ValueError`.

    The expected message pattern is the public method name followed by the
    two conditioning values, in the order the stub places them in the error.
    """
    bijector = _TestBijector()
    conditioning = dict(arg1="b1", arg2="b2")
    public_methods = (
        "forward",
        "inverse",
        "inverse_log_det_jacobian",
        "forward_log_det_jacobian",
    )
    for method_name in public_methods:
      # Look the method up before entering the assertion context so that only
      # the call itself is expected to raise.
      bound_method = getattr(bijector, method_name)
      expected_pattern = method_name + ".*b1.*b2"
      with self.assertRaisesRegexp(ValueError, expected_pattern):
        bound_method(1.0, **conditioning)


if __name__ == "__main__":
  test.main()