1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for tf_decorator.""" 16 17 # pylint: disable=unused-import 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import functools 23 24 from tensorflow.python.platform import test 25 from tensorflow.python.platform import tf_logging as logging 26 from tensorflow.python.util import tf_decorator 27 from tensorflow.python.util import tf_inspect 28 29 30 def test_tfdecorator(decorator_name, decorator_doc=None): 31 32 def make_tf_decorator(target): 33 return tf_decorator.TFDecorator(decorator_name, target, decorator_doc) 34 35 return make_tf_decorator 36 37 38 def test_decorator_increment_first_int_arg(target): 39 """This test decorator skips past `self` as args[0] in the bound case.""" 40 41 def wrapper(*args, **kwargs): 42 new_args = [] 43 found = False 44 for arg in args: 45 if not found and isinstance(arg, int): 46 new_args.append(arg + 1) 47 found = True 48 else: 49 new_args.append(arg) 50 return target(*new_args, **kwargs) 51 52 return tf_decorator.make_decorator(target, wrapper) 53 54 55 def test_function(x): 56 """Test Function Docstring.""" 57 return x + 1 58 59 60 @test_tfdecorator('decorator 1') 61 @test_decorator_increment_first_int_arg 62 @test_tfdecorator('decorator 3', 'decorator 3 documentation') 63 def test_decorated_function(x): 64 """Test Decorated Function Docstring.""" 65 return x * 2 66 67 68 @test_tfdecorator('decorator') 69 class TestDecoratedClass(object): 70 """Test Decorated Class.""" 71 72 def __init__(self, two_attr=2): 73 self.two_attr = two_attr 74 75 @property 76 def two_prop(self): 77 return 2 78 79 def two_func(self): 80 return 2 81 82 @test_decorator_increment_first_int_arg 83 def return_params(self, a, b, c): 84 """Return parameters.""" 85 return [a, b, c] 86 87 88 class TfDecoratorTest(test.TestCase): 89 90 def testInitCapturesTarget(self): 91 self.assertIs(test_function, 92 tf_decorator.TFDecorator('', test_function).decorated_target) 93 94 def testInitCapturesDecoratorName(self): 95 self.assertEqual('decorator name', 96 tf_decorator.TFDecorator('decorator name', 97 test_function).decorator_name) 98 99 def testInitCapturesDecoratorDoc(self): 100 self.assertEqual('decorator doc', 101 tf_decorator.TFDecorator('', test_function, 102 'decorator doc').decorator_doc) 103 104 def testInitCapturesNonNoneArgspec(self): 105 argspec = tf_inspect.ArgSpec( 106 args=['a', 'b', 'c'], 107 varargs=None, 108 keywords=None, 109 defaults=(1, 'hello')) 110 self.assertIs(argspec, 111 tf_decorator.TFDecorator('', test_function, '', 112 argspec).decorator_argspec) 113 114 def testInitSetsDecoratorNameToTargetName(self): 115 self.assertEqual('test_function', 116 tf_decorator.TFDecorator('', test_function).__name__) 117 118 def testInitSetsDecoratorDocToTargetDoc(self): 119 self.assertEqual('Test Function Docstring.', 120 tf_decorator.TFDecorator('', test_function).__doc__) 121 122 def testCallingATFDecoratorCallsTheTarget(self): 123 self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123)) 124 125 def testCallingADecoratedFunctionCallsTheTarget(self): 126 self.assertEqual((2 + 1) * 2, test_decorated_function(2)) 127 128 def testInitializingDecoratedClassWithInitParamsDoesntRaise(self): 129 try: 130 TestDecoratedClass(2) 131 except TypeError: 132 self.assertFail() 133 134 def testReadingClassAttributeOnDecoratedClass(self): 135 self.assertEqual(2, TestDecoratedClass().two_attr) 136 137 def testCallingClassMethodOnDecoratedClass(self): 138 self.assertEqual(2, TestDecoratedClass().two_func()) 139 140 def testReadingClassPropertyOnDecoratedClass(self): 141 self.assertEqual(2, TestDecoratedClass().two_prop) 142 143 def testNameOnBoundProperty(self): 144 self.assertEqual('return_params', 145 TestDecoratedClass().return_params.__name__) 146 147 def testDocstringOnBoundProperty(self): 148 self.assertEqual('Return parameters.', 149 TestDecoratedClass().return_params.__doc__) 150 151 152 def test_wrapper(*args, **kwargs): 153 return test_function(*args, **kwargs) 154 155 156 class TfMakeDecoratorTest(test.TestCase): 157 158 def testAttachesATFDecoratorAttr(self): 159 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 160 decorator = getattr(decorated, '_tf_decorator') 161 self.assertIsInstance(decorator, tf_decorator.TFDecorator) 162 163 def testAttachesWrappedAttr(self): 164 decorated = tf_decorator.make_decorator(test_function, test_wrapper) 165 wrapped_attr = getattr(decorated, '__wrapped__') 166 self.assertIs(test_function, wrapped_attr) 167 168 def testSetsTFDecoratorNameToDecoratorNameArg(self): 169 decorated = tf_decorator.make_decorator(test_function, test_wrapper, 170 'test decorator name') 171 decorator = getattr(decorated, '_tf_decorator') 172 self.assertEqual('test decorator name', decorator.decorator_name) 173 174 def testSetsTFDecoratorDocToDecoratorDocArg(self): 175 decorated = tf_decorator.make_decorator( 176 test_function, test_wrapper, decorator_doc='test decorator doc') 177 decorator = getattr(decorated, '_tf_decorator') 178 self.assertEqual('test decorator doc', decorator.decorator_doc) 179 180 def testSetsTFDecoratorArgSpec(self): 181 argspec = tf_inspect.ArgSpec( 182 args=['a', 'b', 'c'], 183 varargs=None, 184 keywords=None, 185 defaults=(1, 'hello')) 186 decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '', 187 argspec) 188 decorator = getattr(decorated, '_tf_decorator') 189 self.assertEqual(argspec, decorator.decorator_argspec) 190 191 def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self): 192 193 def test_decorator_name(wrapper): 194 return tf_decorator.make_decorator(test_function, wrapper) 195 196 decorated = test_decorator_name(test_wrapper) 197 decorator = getattr(decorated, '_tf_decorator') 198 self.assertEqual('test_decorator_name', decorator.decorator_name) 199 200 def testCompatibleWithNamelessCallables(self): 201 202 class Callable(object): 203 204 def __call__(self): 205 pass 206 207 callable_object = Callable() 208 # Smoke test: This should not raise an exception, even though 209 # `callable_object` does not have a `__name__` attribute. 210 _ = tf_decorator.make_decorator(callable_object, test_wrapper) 211 212 partial = functools.partial(test_function, x=1) 213 # Smoke test: This should not raise an exception, even though `partial` does 214 # not have `__name__`, `__module__`, and `__doc__` attributes. 215 _ = tf_decorator.make_decorator(partial, test_wrapper) 216 217 218 class TfDecoratorUnwrapTest(test.TestCase): 219 220 def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self): 221 decorators, _ = tf_decorator.unwrap(test_function) 222 self.assertEqual(0, len(decorators)) 223 224 def testUnwrapReturnsUndecoratedFunctionAsTarget(self): 225 _, target = tf_decorator.unwrap(test_function) 226 self.assertIs(test_function, target) 227 228 def testUnwrapReturnsFinalFunctionAsTarget(self): 229 self.assertEqual((4 + 1) * 2, test_decorated_function(4)) 230 _, target = tf_decorator.unwrap(test_decorated_function) 231 self.assertTrue(tf_inspect.isfunction(target)) 232 self.assertEqual(4 * 2, target(4)) 233 234 def testUnwrapReturnsListOfUniqueTFDecorators(self): 235 decorators, _ = tf_decorator.unwrap(test_decorated_function) 236 self.assertEqual(3, len(decorators)) 237 self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator)) 238 self.assertTrue(isinstance(decorators[1], tf_decorator.TFDecorator)) 239 self.assertTrue(isinstance(decorators[2], tf_decorator.TFDecorator)) 240 self.assertIsNot(decorators[0], decorators[1]) 241 self.assertIsNot(decorators[1], decorators[2]) 242 self.assertIsNot(decorators[2], decorators[0]) 243 244 def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self): 245 decorators, _ = tf_decorator.unwrap(test_decorated_function) 246 self.assertEqual('decorator 1', decorators[0].decorator_name) 247 self.assertEqual('test_decorator_increment_first_int_arg', 248 decorators[1].decorator_name) 249 self.assertEqual('decorator 3', decorators[2].decorator_name) 250 self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc) 251 252 def testUnwrapBoundMethods(self): 253 test_decorated_class = TestDecoratedClass() 254 self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3)) 255 decorators, target = tf_decorator.unwrap(test_decorated_class.return_params) 256 self.assertEqual('test_decorator_increment_first_int_arg', 257 decorators[0].decorator_name) 258 self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3)) 259 260 261 if __name__ == '__main__': 262 test.main() 263