# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multiple_dispatch.

Exercises run_cond and run_while with both plain Python values and
TensorFlow tensors as inputs.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.platform import test


def _above_threshold(value, threshold, scale):
  """Loop condition: keep iterating while `value` exceeds `threshold`."""
  del scale  # Not needed to decide whether to continue.
  return value > threshold


def _scale_down(value, threshold, scale):
  """Loop body: multiply `value` by `scale`; pass the other vars through."""
  return value * scale, threshold, scale


class MultipleDispatchTest(test.TestCase):
  """Checks run_cond / run_while on Python values and on TF tensors."""

  def test_run_cond_python(self):
    # Python bool predicates: the dispatched result is compared directly,
    # with no session involved.
    def on_true():
      return 2.0

    def on_false():
      return 3.0

    for predicate, expected in ((True, 2.0), (False, 3.0)):
      self.assertEqual(
          multiple_dispatch.run_cond(predicate, on_true, on_false), expected)

  def test_run_cond_tf(self):
    # Tensor predicates: each branch yields a one-element tensor, and the
    # result must be evaluated in a session before comparison.
    def on_true():
      return constant([2.0])

    def on_false():
      return constant([3.0])

    with Session() as sess:
      for predicate, expected in ((True, 2.0), (False, 3.0)):
        # The predicate tensor is built inside the session block, once per
        # case, right before dispatching on it.
        picked = multiple_dispatch.run_cond(
            constant(predicate), on_true, on_false)
        self.assertEqual(sess.run(picked), expected)

  def test_run_while_python(self):
    # Starting at 3.0 and halving: with threshold 1.0 the loop stops at 0.75;
    # with threshold 4.0 the body never runs and 3.0 is returned unchanged.
    cases = (
        ([3.0, 1.0, 0.5], 0.75),
        ([3.0, 4.0, 0.5], 3.0),
    )
    for loop_vars, expected in cases:
      final_value, _, _ = multiple_dispatch.run_while(
          _above_threshold, _scale_down, loop_vars)
      self.assertEqual(final_value, expected)

  def test_run_while_tf(self):
    # Same scenarios as the Python variant, but the first loop variable is a
    # tensor, so the result has to be evaluated in a session.
    with Session() as sess:
      for threshold, expected in ((1.0, 0.75), (4.0, 3.0)):
        final_value, _, _ = multiple_dispatch.run_while(
            _above_threshold, _scale_down, [constant(3.0), threshold, 0.5])
        self.assertEqual(sess.run(final_value), expected)


if __name__ == '__main__':
  test.main()