# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base TFDecorator class and utility functions for working with decorators.

There are two ways to create decorators that TensorFlow can introspect into.
This is important for documentation generation purposes, so that function
signatures aren't obscured by the (*args, **kwds) signature that decorators
often provide.

1. Call `tf_decorator.make_decorator` on your wrapper function. If your
decorator is stateless, or can capture all of the variables it needs to work
with through lexical closure, this is the simplest option. Create your wrapper
function as usual, but instead of returning it, return
`tf_decorator.make_decorator(target, your_wrapper)`. This will attach some
decorator introspection metadata onto your wrapper and return it.

Example:

  def print_hello_before_calling(target):
    def wrapper(*args, **kwargs):
      print('hello')
      return target(*args, **kwargs)
    return tf_decorator.make_decorator(target, wrapper)

2. Derive from TFDecorator. If your decorator needs to be stateful, you can
implement it in terms of a TFDecorator. Store whatever state you need in your
derived class, and implement the `__call__` method to do your work before
calling into your target. You can retrieve the target via
`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
parameters it needs.

Example:

  class CallCounter(tf_decorator.TFDecorator):
    def __init__(self, target):
      super(CallCounter, self).__init__('count_calls', target)
      self.call_count = 0

    def __call__(self, *args, **kwargs):
      self.call_count += 1
      return super(CallCounter, self).decorated_target(*args, **kwargs)

  def count_calls(target):
    return CallCounter(target)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools as _functools
import traceback as _traceback


def make_decorator(target,
                   decorator_func,
                   decorator_name=None,
                   decorator_doc='',
                   decorator_argspec=None):
  """Make a decorator from a wrapper and a target.

  `decorator_func` is modified in place. A `TFDecorator` describing this
  application is attached as its `_tf_decorator` attribute, which `unwrap`
  follows. The target's naming and documentation metadata are copied onto it
  so that introspection reports the target rather than the wrapper.

  Args:
    target: The final callable to be wrapped.
    decorator_func: The wrapper function.
    decorator_name: The name of the decorator. If `None`, the name of the
      function calling make_decorator.
    decorator_doc: Documentation specific to this application of
      `decorator_func` to `target`.
    decorator_argspec: The new callable signature of this decorator.

  Returns:
    The `decorator_func` argument with new metadata attached.
  """
  if decorator_name is None:
    # extract_stack(limit=2) returns [caller's frame, this frame], oldest
    # first, so index 0 is the function that called make_decorator.
    frame = _traceback.extract_stack(limit=2)[0]
    # frame name is tuple[2] in python2, and object.name in python3
    decorator_name = getattr(frame, 'name', frame[2])  # Caller's name
  decorator = TFDecorator(decorator_name, target, decorator_doc,
                          decorator_argspec)
  setattr(decorator_func, '_tf_decorator', decorator)
  # Objects that are callables (e.g., a functools.partial object) may not have
  # the following attributes.
  if hasattr(target, '__name__'):
    decorator_func.__name__ = target.__name__
  if hasattr(target, '__qualname__'):
    # Python 3 only. Keep the qualified name in sync with `__name__`, as
    # functools.update_wrapper does. Otherwise reprs and doc tools still show
    # the wrapper's own qualified name.
    decorator_func.__qualname__ = target.__qualname__
  if hasattr(target, '__module__'):
    decorator_func.__module__ = target.__module__
  if hasattr(target, '__doc__'):
    # TFDecorator.__init__ has already chosen between `decorator_doc` and the
    # target's own docstring.
    decorator_func.__doc__ = decorator.__doc__
  decorator_func.__wrapped__ = target
  return decorator_func


def unwrap(maybe_tf_decorator):
  """Unwraps an object into a list of TFDecorators and a final target.

  Args:
    maybe_tf_decorator: Any callable object.

  Returns:
    A tuple whose first element is a list of TFDecorator-derived objects that
    were applied to the final callable target, and whose second element is the
    final undecorated callable target. If the `maybe_tf_decorator` parameter is
    not decorated by any TFDecorators, the first tuple element will be an empty
    list. The `TFDecorator` list is ordered from outermost to innermost
    decorators.
  """
  decorators = []
  cur = maybe_tf_decorator
  while True:
    # A TFDecorator instance is itself one layer. A function produced by
    # make_decorator carries its layer in the `_tf_decorator` attribute.
    if isinstance(cur, TFDecorator):
      decorators.append(cur)
    elif hasattr(cur, '_tf_decorator'):
      decorators.append(getattr(cur, '_tf_decorator'))
    else:
      break
    cur = decorators[-1].decorated_target
  return decorators, cur


class TFDecorator(object):
  """Base class for all TensorFlow decorators.

  TFDecorator captures and exposes the wrapped target, and provides details
  about the current decorator.
  """

  def __init__(self,
               decorator_name,
               target,
               decorator_doc='',
               decorator_argspec=None):
    """Records the target and the metadata describing this decorator.

    Args:
      decorator_name: The name of the decorator.
      target: The callable being decorated.
      decorator_doc: Documentation specific to this application of the
        decorator. When non-empty it becomes this object's `__doc__`.
        Otherwise the target's docstring is used, or `''` if it has none.
      decorator_argspec: The new callable signature of this decorator.
    """
    self._decorated_target = target
    self._decorator_name = decorator_name
    self._decorator_doc = decorator_doc
    self._decorator_argspec = decorator_argspec
    if hasattr(target, '__name__'):
      self.__name__ = target.__name__
    if self._decorator_doc:
      self.__doc__ = self._decorator_doc
    elif hasattr(target, '__doc__') and target.__doc__:
      self.__doc__ = target.__doc__
    else:
      self.__doc__ = ''

  def __get__(self, obj, objtype=None):
    """Supports using a TFDecorator as a method in a class body.

    Args:
      obj: The instance the attribute was looked up on, or `None` when it was
        looked up on the class itself.
      objtype: The owner class. Unused. It defaults to `None` so that the
        one-argument form `descriptor.__get__(obj)` also works.

    Returns:
      `self` when looked up on the class, like a plain Python 3 function, so
      `Cls.method(instance, ...)` passes `instance` through unchanged.
      Otherwise, a callable that invokes `self.__call__` with `obj` as the
      first positional argument.
    """
    if obj is None:
      return self
    return _functools.partial(self.__call__, obj)

  def __call__(self, *args, **kwargs):
    """Calls the decorated target with the given arguments."""
    return self._decorated_target(*args, **kwargs)

  @property
  def decorated_target(self):
    """The callable this decorator wraps."""
    return self._decorated_target

  @property
  def decorator_name(self):
    """The name given to this decorator."""
    return self._decorator_name

  @property
  def decorator_doc(self):
    """Documentation specific to this application of the decorator."""
    return self._decorator_doc

  @property
  def decorator_argspec(self):
    """The new callable signature of this decorator, or `None`."""
    return self._decorator_argspec