# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for tf_contextlib."""

# pylint: disable=unused-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


@tf_contextlib.contextmanager
def test_yield_append_before_and_after_yield(x, before, after):
  """Appends `before` to list `x` on entry and `after` to `x` on exit.

  `after` is appended only on normal exit (there is no try/finally), which is
  all the tests below rely on.
  """
  x.append(before)
  yield
  x.append(after)


@tf_contextlib.contextmanager
def test_yield_return_x_plus_1(x):
  """Yields `x + 1`, the value bound by `with ... as`."""
  yield x + 1


@tf_contextlib.contextmanager
def test_params_and_defaults(a, b=2, c=True, d='hello'):
  """Fixture whose signature (one positional arg plus defaults) is inspected.

  The tests mainly examine this function's argspec and decorator chain. It
  yields rather than returns so that the decorated object is also a working
  context manager. A generator-based context manager (presumably what
  tf_contextlib.contextmanager builds on, via contextlib) fails on entry if
  the wrapped function returns a plain value instead of a generator.
  """
  yield [a, b, c, d]


class TfContextlibTest(test.TestCase):

  def testRunsCodeBeforeYield(self):
    x = []
    with test_yield_append_before_and_after_yield(x, 'before', ''):
      # Checked inside the `with`: after exit, x[-1] would be ''.
      self.assertEqual('before', x[-1])

  def testRunsCodeAfterYield(self):
    x = []
    with test_yield_append_before_and_after_yield(x, '', 'after'):
      pass
    self.assertEqual('after', x[-1])

  def testNestedWith(self):
    x = []
    with test_yield_append_before_and_after_yield(x, 'before', 'after'):
      with test_yield_append_before_and_after_yield(x, 'inner', 'outer'):
        with test_yield_return_x_plus_1(1) as var:
          x.append(var)
    # Checked after every block has exited, so both exit appends are present
    # and reflect LIFO unwinding order.
    self.assertEqual(['before', 'inner', 2, 'outer', 'after'], x)

  def testMultipleCallsOfSeparateInstances(self):
    x = []
    with test_yield_append_before_and_after_yield(x, 1, 2):
      pass
    with test_yield_append_before_and_after_yield(x, 3, 4):
      pass
    self.assertEqual([1, 2, 3, 4], x)

  def testReturnsResultFromYield(self):
    with test_yield_return_x_plus_1(3) as result:
      self.assertEqual(4, result)

  def testDecoratedFunctionWithDefaultsIsUsable(self):
    # Defaults from the wrapped signature must apply when the context manager
    # is actually entered, not only when its argspec is inspected.
    with test_params_and_defaults(1) as result:
      self.assertEqual([1, 2, True, 'hello'], result)

  def testUnwrapContextManager(self):
    decorators, target = tf_decorator.unwrap(test_params_and_defaults)
    self.assertEqual(1, len(decorators))
    self.assertIsInstance(decorators[0], tf_decorator.TFDecorator)
    self.assertEqual('contextmanager', decorators[0].decorator_name)
    self.assertNotIsInstance(target, tf_decorator.TFDecorator)

  def testGetArgSpecReturnsWrappedArgSpec(self):
    argspec = tf_inspect.getargspec(test_params_and_defaults)
    self.assertEqual(['a', 'b', 'c', 'd'], argspec.args)
    self.assertEqual((2, True, 'hello'), argspec.defaults)


if __name__ == '__main__':
  test.main()