Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities for exporting TensorFlow symbols to the API.
     16 
     17 Exporting a function or a class:
     18 
     19 To export a function or a class, use the tf_export decorator. For example:
     20 ```python
     21 @tf_export('foo', 'bar.foo')
     22 def foo(...):
     23   ...
     24 ```
     25 
     26 If a function is assigned to a variable, you can export it by calling
     27 tf_export explicitly. For example:
     28 ```python
     29 foo = get_foo(...)
     30 tf_export('foo', 'bar.foo')(foo)
     31 ```
     32 
     33 
     34 Exporting a constant:
     35 ```python
     36 foo = 1
     37 tf_export("consts.foo").export_constant(__name__, 'foo')
     38 ```
     39 """
     40 from __future__ import absolute_import
     41 from __future__ import division
     42 from __future__ import print_function
     43 
     44 import sys
     45 
     46 from tensorflow.python.util import tf_decorator
     47 
     48 
     49 class SymbolAlreadyExposedError(Exception):
     50   """Raised when adding API names to symbol that already has API names."""
     51   pass
     52 
     53 
     54 class tf_export(object):  # pylint: disable=invalid-name
     55   """Provides ways to export symbols to the TensorFlow API."""
     56 
     57   def __init__(self, *args, **kwargs):
     58     """Export under the names *args (first one is considered canonical).
     59 
     60     Args:
     61       *args: API names in dot delimited format.
     62       **kwargs: Optional keyed arguments. Currently only supports 'overrides'
     63         argument. overrides: List of symbols that this is overriding
     64         (those overrided api exports will be removed). Note: passing overrides
     65         has no effect on exporting a constant.
     66     """
     67     self._names = args
     68     self._overrides = kwargs.get('overrides', [])
     69 
     70   def __call__(self, func):
     71     """Calls this decorator.
     72 
     73     Args:
     74       func: decorated symbol (function or class).
     75 
     76     Returns:
     77       The input function with _tf_api_names attribute set.
     78 
     79     Raises:
     80       SymbolAlreadyExposedError: Raised when a symbol already has API names.
     81     """
     82     # Undecorate overridden names
     83     for f in self._overrides:
     84       _, undecorated_f = tf_decorator.unwrap(f)
     85       del undecorated_f._tf_api_names  # pylint: disable=protected-access
     86 
     87     _, undecorated_func = tf_decorator.unwrap(func)
     88 
     89     # Check for an existing api. We check if attribute name is in
     90     # __dict__ instead of using hasattr to verify that subclasses have
     91     # their own _tf_api_names as opposed to just inheriting it.
     92     if '_tf_api_names' in undecorated_func.__dict__:
     93       # pylint: disable=protected-access
     94       raise SymbolAlreadyExposedError(
     95           'Symbol %s is already exposed as %s.' %
     96           (undecorated_func.__name__, undecorated_func._tf_api_names))
     97       # pylint: enable=protected-access
     98 
     99     # Complete the export by creating/overriding attribute
    100     # pylint: disable=protected-access
    101     undecorated_func._tf_api_names = self._names
    102     # pylint: enable=protected-access
    103     return func
    104 
    105   def export_constant(self, module_name, name):
    106     """Store export information for constants/string literals.
    107 
    108     Export information is stored in the module where constants/string literals
    109     are defined.
    110 
    111     e.g.
    112     ```python
    113     foo = 1
    114     bar = 2
    115     tf_export("consts.foo").export_constant(__name__, 'foo')
    116     tf_export("consts.bar").export_constant(__name__, 'bar')
    117     ```
    118 
    119     Args:
    120       module_name: (string) Name of the module to store constant at.
    121       name: (string) Current constant name.
    122     """
    123     module = sys.modules[module_name]
    124     if not hasattr(module, '_tf_api_constants'):
    125       module._tf_api_constants = []  # pylint: disable=protected-access
    126     # pylint: disable=protected-access
    127     module._tf_api_constants.append((self._names, name))
    128 
    129