Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for tf_contextlib."""
     16 
     17 # pylint: disable=unused-import
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 from tensorflow.python.platform import test
     23 from tensorflow.python.util import tf_contextlib
     24 from tensorflow.python.util import tf_decorator
     25 from tensorflow.python.util import tf_inspect
     26 
     27 
     28 @tf_contextlib.contextmanager
     29 def test_yield_append_before_and_after_yield(x, before, after):
     30   x.append(before)
     31   yield
     32   x.append(after)
     33 
     34 
     35 @tf_contextlib.contextmanager
     36 def test_yield_return_x_plus_1(x):
     37   yield x + 1
     38 
     39 
     40 @tf_contextlib.contextmanager
     41 def test_params_and_defaults(a, b=2, c=True, d='hello'):
     42   return [a, b, c, d]
     43 
     44 
     45 class TfContextlibTest(test.TestCase):
     46 
     47   def testRunsCodeBeforeYield(self):
     48     x = []
     49     with test_yield_append_before_and_after_yield(x, 'before', ''):
     50       self.assertEqual('before', x[-1])
     51 
     52   def testRunsCodeAfterYield(self):
     53     x = []
     54     with test_yield_append_before_and_after_yield(x, '', 'after'):
     55       pass
     56     self.assertEqual('after', x[-1])
     57 
     58   def testNestedWith(self):
     59     x = []
     60     with test_yield_append_before_and_after_yield(x, 'before', 'after'):
     61       with test_yield_append_before_and_after_yield(x, 'inner', 'outer'):
     62         with test_yield_return_x_plus_1(1) as var:
     63           x.append(var)
     64     self.assertEqual(['before', 'inner', 2, 'outer', 'after'], x)
     65 
     66   def testMultipleCallsOfSeparateInstances(self):
     67     x = []
     68     with test_yield_append_before_and_after_yield(x, 1, 2):
     69       pass
     70     with test_yield_append_before_and_after_yield(x, 3, 4):
     71       pass
     72     self.assertEqual([1, 2, 3, 4], x)
     73 
     74   def testReturnsResultFromYield(self):
     75     with test_yield_return_x_plus_1(3) as result:
     76       self.assertEqual(4, result)
     77 
     78   def testUnwrapContextManager(self):
     79     decorators, target = tf_decorator.unwrap(test_params_and_defaults)
     80     self.assertEqual(1, len(decorators))
     81     self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator))
     82     self.assertEqual('contextmanager', decorators[0].decorator_name)
     83     self.assertFalse(isinstance(target, tf_decorator.TFDecorator))
     84 
     85   def testGetArgSpecReturnsWrappedArgSpec(self):
     86     argspec = tf_inspect.getargspec(test_params_and_defaults)
     87     self.assertEqual(['a', 'b', 'c', 'd'], argspec.args)
     88     self.assertEqual((2, True, 'hello'), argspec.defaults)
     89 
     90 
     91 if __name__ == '__main__':
     92   test.main()
     93