Home | History | Annotate | Download | only in util
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Generate __all__ from a module docstring."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import re as _re
     22 import sys as _sys
     23 
     24 from tensorflow.python.util import tf_inspect as _tf_inspect
     25 
     26 
     27 _reference_pattern = _re.compile(r'^@@(\w+)$', flags=_re.MULTILINE)
     28 
     29 
     30 def make_all(module_name, doc_string_modules=None):
     31   """Generates `__all__` from the docstring of one or more modules.
     32 
     33   Usage: `make_all(__name__)` or
     34   `make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
     35   modules must each a docstring, and `__all__` will contain all symbols with
     36   `@@` references, where that symbol currently exists in the module named
     37   `module_name`.
     38 
     39   Args:
     40     module_name: The name of the module (usually `__name__`).
     41     doc_string_modules: a list of modules from which to take docstring.
     42     If None, then a list containing only the module named `module_name` is used.
     43 
     44   Returns:
     45     A list suitable for use as `__all__`.
     46   """
     47   if doc_string_modules is None:
     48     doc_string_modules = [_sys.modules[module_name]]
     49   cur_members = set([name for name, _
     50                      in _tf_inspect.getmembers(_sys.modules[module_name])])
     51 
     52   results = set()
     53   for doc_module in doc_string_modules:
     54     results.update([m.group(1)
     55                     for m in _reference_pattern.finditer(doc_module.__doc__)
     56                     if m.group(1) in cur_members])
     57   return list(results)
     58 
     59 # Hidden attributes are attributes that have been hidden by
     60 # `remove_undocumented`. They can be re-instated by `reveal_undocumented`.
     61 # This maps symbol names to a tuple, containing:
     62 #   (module object, attribute value)
     63 _HIDDEN_ATTRIBUTES = {}
     64 
     65 
     66 def reveal_undocumented(symbol_name, target_module=None):
     67   """Reveals a symbol that was previously removed by `remove_undocumented`.
     68 
     69   This should be used by tensorflow internal tests only. It explicitly
     70   defeats the encapsulation afforded by `remove_undocumented`.
     71 
     72   It throws an exception when the symbol was not hidden in the first place.
     73 
     74   Args:
     75     symbol_name: a string representing the full absolute path of the symbol.
     76     target_module: if specified, the module in which to restore the symbol.
     77   """
     78   if symbol_name not in _HIDDEN_ATTRIBUTES:
     79     raise LookupError('Symbol %s is not a hidden symbol' % symbol_name)
     80   symbol_basename = symbol_name.split('.')[-1]
     81   (original_module, attr_value) = _HIDDEN_ATTRIBUTES[symbol_name]
     82   if not target_module: target_module = original_module
     83   setattr(target_module, symbol_basename, attr_value)
     84 
     85 
     86 def remove_undocumented(module_name, allowed_exception_list=None,
     87                         doc_string_modules=None):
     88   """Removes symbols in a module that are not referenced by a docstring.
     89 
     90   Args:
     91     module_name: the name of the module (usually `__name__`).
     92     allowed_exception_list: a list of names that should not be removed.
     93     doc_string_modules: a list of modules from which to take the docstrings.
     94     If None, then a list containing only the module named `module_name` is used.
     95 
     96     Furthermore, if a symbol previously added with `add_to_global_whitelist`,
     97     then it will always be allowed. This is useful for internal tests.
     98 
     99   Returns:
    100     None
    101   """
    102   current_symbols = set(dir(_sys.modules[module_name]))
    103   should_have = make_all(module_name, doc_string_modules)
    104   should_have += allowed_exception_list or []
    105   extra_symbols = current_symbols - set(should_have)
    106   target_module = _sys.modules[module_name]
    107   for extra_symbol in extra_symbols:
    108     # Skip over __file__, etc. Also preserves internal symbols.
    109     if extra_symbol.startswith('_'): continue
    110     fully_qualified_name = module_name + '.' + extra_symbol
    111     _HIDDEN_ATTRIBUTES[fully_qualified_name] = (target_module,
    112                                                 getattr(target_module,
    113                                                         extra_symbol))
    114     delattr(target_module, extra_symbol)
    115 
    116 
    117 __all__ = [
    118     'make_all',
    119     'remove_undocumented',
    120     'reveal_undocumented',
    121 ]
    122