Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for tf_decorator."""
     16 
     17 # pylint: disable=unused-import
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import functools
     23 
     24 from tensorflow.python.platform import test
     25 from tensorflow.python.platform import tf_logging as logging
     26 from tensorflow.python.util import tf_decorator
     27 from tensorflow.python.util import tf_inspect
     28 
     29 
     30 def test_tfdecorator(decorator_name, decorator_doc=None):
     31 
     32   def make_tf_decorator(target):
     33     return tf_decorator.TFDecorator(decorator_name, target, decorator_doc)
     34 
     35   return make_tf_decorator
     36 
     37 
     38 def test_decorator_increment_first_int_arg(target):
     39   """This test decorator skips past `self` as args[0] in the bound case."""
     40 
     41   def wrapper(*args, **kwargs):
     42     new_args = []
     43     found = False
     44     for arg in args:
     45       if not found and isinstance(arg, int):
     46         new_args.append(arg + 1)
     47         found = True
     48       else:
     49         new_args.append(arg)
     50     return target(*new_args, **kwargs)
     51 
     52   return tf_decorator.make_decorator(target, wrapper)
     53 
     54 
# Plain, undecorated target shared by the tests in this file. Its name
# ('test_function') and its docstring ('Test Function Docstring.') are asserted
# verbatim by TfDecoratorTest, so neither may change without updating those
# assertions.
def test_function(x):
  """Test Function Docstring."""
  return x + 1
     58 
     59 
# Fixture with three stacked decorators. Decorators apply bottom-up, so
# 'decorator 1' is outermost and 'decorator 3' is innermost; the unwrap tests
# assert exactly this outermost-to-innermost order. Calling the decorated
# function with an int increments it first and then doubles it.
@test_tfdecorator('decorator 1')
@test_decorator_increment_first_int_arg
@test_tfdecorator('decorator 3', 'decorator 3 documentation')
def test_decorated_function(x):
  """Test Decorated Function Docstring."""
  return x * 2
     66 
     67 
# Class fixture wrapped in a TFDecorator named 'decorator'. The tests
# instantiate it through the decorator and then read its attribute, property,
# and methods.
@test_tfdecorator('decorator')
class TestDecoratedClass(object):
  """Test Decorated Class."""

  def __init__(self, two_attr=2):
    # Stored so tests can read an instance attribute from a decorated class.
    self.two_attr = two_attr

  @property
  def two_prop(self):
    return 2

  def two_func(self):
    return 2

  # The decorator increments the first int among the call arguments (skipping
  # `self`) before this body runs. The tests assert this method's name and
  # docstring verbatim, so keep both unchanged.
  @test_decorator_increment_first_int_arg
  def return_params(self, a, b, c):
    """Return parameters."""
    return [a, b, c]
     86 
     87 
     88 class TfDecoratorTest(test.TestCase):
     89 
     90   def testInitCapturesTarget(self):
     91     self.assertIs(test_function,
     92                   tf_decorator.TFDecorator('', test_function).decorated_target)
     93 
     94   def testInitCapturesDecoratorName(self):
     95     self.assertEqual('decorator name',
     96                      tf_decorator.TFDecorator('decorator name',
     97                                               test_function).decorator_name)
     98 
     99   def testInitCapturesDecoratorDoc(self):
    100     self.assertEqual('decorator doc',
    101                      tf_decorator.TFDecorator('', test_function,
    102                                               'decorator doc').decorator_doc)
    103 
    104   def testInitCapturesNonNoneArgspec(self):
    105     argspec = tf_inspect.ArgSpec(
    106         args=['a', 'b', 'c'],
    107         varargs=None,
    108         keywords=None,
    109         defaults=(1, 'hello'))
    110     self.assertIs(argspec,
    111                   tf_decorator.TFDecorator('', test_function, '',
    112                                            argspec).decorator_argspec)
    113 
    114   def testInitSetsDecoratorNameToTargetName(self):
    115     self.assertEqual('test_function',
    116                      tf_decorator.TFDecorator('', test_function).__name__)
    117 
    118   def testInitSetsDecoratorDocToTargetDoc(self):
    119     self.assertEqual('Test Function Docstring.',
    120                      tf_decorator.TFDecorator('', test_function).__doc__)
    121 
    122   def testCallingATFDecoratorCallsTheTarget(self):
    123     self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123))
    124 
    125   def testCallingADecoratedFunctionCallsTheTarget(self):
    126     self.assertEqual((2 + 1) * 2, test_decorated_function(2))
    127 
    128   def testInitializingDecoratedClassWithInitParamsDoesntRaise(self):
    129     try:
    130       TestDecoratedClass(2)
    131     except TypeError:
    132       self.assertFail()
    133 
    134   def testReadingClassAttributeOnDecoratedClass(self):
    135     self.assertEqual(2, TestDecoratedClass().two_attr)
    136 
    137   def testCallingClassMethodOnDecoratedClass(self):
    138     self.assertEqual(2, TestDecoratedClass().two_func())
    139 
    140   def testReadingClassPropertyOnDecoratedClass(self):
    141     self.assertEqual(2, TestDecoratedClass().two_prop)
    142 
    143   def testNameOnBoundProperty(self):
    144     self.assertEqual('return_params',
    145                      TestDecoratedClass().return_params.__name__)
    146 
    147   def testDocstringOnBoundProperty(self):
    148     self.assertEqual('Return parameters.',
    149                      TestDecoratedClass().return_params.__doc__)
    150 
    151 
    152 def test_wrapper(*args, **kwargs):
    153   return test_function(*args, **kwargs)
    154 
    155 
    156 class TfMakeDecoratorTest(test.TestCase):
    157 
    158   def testAttachesATFDecoratorAttr(self):
    159     decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    160     decorator = getattr(decorated, '_tf_decorator')
    161     self.assertIsInstance(decorator, tf_decorator.TFDecorator)
    162 
    163   def testAttachesWrappedAttr(self):
    164     decorated = tf_decorator.make_decorator(test_function, test_wrapper)
    165     wrapped_attr = getattr(decorated, '__wrapped__')
    166     self.assertIs(test_function, wrapped_attr)
    167 
    168   def testSetsTFDecoratorNameToDecoratorNameArg(self):
    169     decorated = tf_decorator.make_decorator(test_function, test_wrapper,
    170                                             'test decorator name')
    171     decorator = getattr(decorated, '_tf_decorator')
    172     self.assertEqual('test decorator name', decorator.decorator_name)
    173 
    174   def testSetsTFDecoratorDocToDecoratorDocArg(self):
    175     decorated = tf_decorator.make_decorator(
    176         test_function, test_wrapper, decorator_doc='test decorator doc')
    177     decorator = getattr(decorated, '_tf_decorator')
    178     self.assertEqual('test decorator doc', decorator.decorator_doc)
    179 
    180   def testSetsTFDecoratorArgSpec(self):
    181     argspec = tf_inspect.ArgSpec(
    182         args=['a', 'b', 'c'],
    183         varargs=None,
    184         keywords=None,
    185         defaults=(1, 'hello'))
    186     decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '',
    187                                             argspec)
    188     decorator = getattr(decorated, '_tf_decorator')
    189     self.assertEqual(argspec, decorator.decorator_argspec)
    190 
    191   def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self):
    192 
    193     def test_decorator_name(wrapper):
    194       return tf_decorator.make_decorator(test_function, wrapper)
    195 
    196     decorated = test_decorator_name(test_wrapper)
    197     decorator = getattr(decorated, '_tf_decorator')
    198     self.assertEqual('test_decorator_name', decorator.decorator_name)
    199 
    200   def testCompatibleWithNamelessCallables(self):
    201 
    202     class Callable(object):
    203 
    204       def __call__(self):
    205         pass
    206 
    207     callable_object = Callable()
    208     # Smoke test: This should not raise an exception, even though
    209     # `callable_object` does not have a `__name__` attribute.
    210     _ = tf_decorator.make_decorator(callable_object, test_wrapper)
    211 
    212     partial = functools.partial(test_function, x=1)
    213     # Smoke test: This should not raise an exception, even though `partial` does
    214     # not have `__name__`, `__module__`, and `__doc__` attributes.
    215     _ = tf_decorator.make_decorator(partial, test_wrapper)
    216 
    217 
    218 class TfDecoratorUnwrapTest(test.TestCase):
    219 
    220   def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self):
    221     decorators, _ = tf_decorator.unwrap(test_function)
    222     self.assertEqual(0, len(decorators))
    223 
    224   def testUnwrapReturnsUndecoratedFunctionAsTarget(self):
    225     _, target = tf_decorator.unwrap(test_function)
    226     self.assertIs(test_function, target)
    227 
    228   def testUnwrapReturnsFinalFunctionAsTarget(self):
    229     self.assertEqual((4 + 1) * 2, test_decorated_function(4))
    230     _, target = tf_decorator.unwrap(test_decorated_function)
    231     self.assertTrue(tf_inspect.isfunction(target))
    232     self.assertEqual(4 * 2, target(4))
    233 
    234   def testUnwrapReturnsListOfUniqueTFDecorators(self):
    235     decorators, _ = tf_decorator.unwrap(test_decorated_function)
    236     self.assertEqual(3, len(decorators))
    237     self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator))
    238     self.assertTrue(isinstance(decorators[1], tf_decorator.TFDecorator))
    239     self.assertTrue(isinstance(decorators[2], tf_decorator.TFDecorator))
    240     self.assertIsNot(decorators[0], decorators[1])
    241     self.assertIsNot(decorators[1], decorators[2])
    242     self.assertIsNot(decorators[2], decorators[0])
    243 
    244   def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
    245     decorators, _ = tf_decorator.unwrap(test_decorated_function)
    246     self.assertEqual('decorator 1', decorators[0].decorator_name)
    247     self.assertEqual('test_decorator_increment_first_int_arg',
    248                      decorators[1].decorator_name)
    249     self.assertEqual('decorator 3', decorators[2].decorator_name)
    250     self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc)
    251 
    252   def testUnwrapBoundMethods(self):
    253     test_decorated_class = TestDecoratedClass()
    254     self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3))
    255     decorators, target = tf_decorator.unwrap(test_decorated_class.return_params)
    256     self.assertEqual('test_decorator_increment_first_int_arg',
    257                      decorators[0].decorator_name)
    258     self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3))
    259 
    260 
# When this file is executed directly (rather than imported), hand control to
# the `main` entry point of the imported TensorFlow `test` module.
if __name__ == '__main__':
  test.main()
    263