# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for exporting TensorFlow symbols to the API.

Exporting a function or a class:

To export a function or a class, use the tf_export decorator. For example:
```python
@tf_export('foo', 'bar.foo')
def foo(...):
  ...
```

If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
```python
foo = get_foo(...)
tf_export('foo', 'bar.foo')(foo)
```


Exporting a constant:
```python
foo = 1
tf_export("consts.foo").export_constant(__name__, 'foo')
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from tensorflow.python.util import tf_decorator


class SymbolAlreadyExposedError(Exception):
  """Raised when adding API names to symbol that already has API names."""


class tf_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):
    """Records the API names (the first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyword arguments. Only 'overrides' is read; any
        other keyword is ignored.
        overrides: List of symbols whose existing API exports are replaced by
          this one (their exports are removed when the decorator is applied).
          Note: passing overrides has no effect on exporting a constant.
    """
    self._overrides = kwargs['overrides'] if 'overrides' in kwargs else []
    self._names = args

  def __call__(self, func):
    """Applies the export to a symbol.

    Exports of the overridden symbols are removed first, and only then is
    `func` checked for existing exports, so a symbol listed in its own
    `overrides` can be re-exported under new names.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input symbol, unchanged, whose undecorated target now carries the
      `_tf_api_names` attribute.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names.
      AttributeError: Raised when an overridden symbol has no
        `_tf_api_names` attribute of its own to remove.
    """
    # Strip the exports from every symbol this export supersedes.
    for overridden in self._overrides:
      stripped = tf_decorator.unwrap(overridden)[1]
      delattr(stripped, '_tf_api_names')

    target = tf_decorator.unwrap(func)[1]

    # Membership is tested on the symbol's own __dict__ rather than with
    # hasattr so that a subclass inheriting _tf_api_names from its parent is
    # still treated as not yet exported.
    if '_tf_api_names' not in target.__dict__:
      # Complete the export by creating the attribute.
      setattr(target, '_tf_api_names', self._names)
      return func

    # pylint: disable=protected-access
    raise SymbolAlreadyExposedError(
        'Symbol %s is already exposed as %s.' %
        (target.__name__, target._tf_api_names))
    # pylint: enable=protected-access

  def export_constant(self, module_name, name):
    """Stores export information for constants/string literals.

    The information is appended, as a `(api_names, name)` tuple, to the
    `_tf_api_constants` list of the module in which the constant is defined;
    the list is created on first use.

    For example:
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at. It is
        looked up in `sys.modules`.
      name: (string) Current constant name.
    """
    owner = sys.modules[module_name]
    # pylint: disable=protected-access
    if hasattr(owner, '_tf_api_constants'):
      registry = owner._tf_api_constants
    else:
      registry = owner._tf_api_constants = []
    # pylint: enable=protected-access
    registry.append((self._names, name))