Home | History | Annotate | Download | only in py_utils
      1 # Copyright 2012 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import fnmatch
      6 import importlib
      7 import inspect
      8 import os
      9 import re
     10 import sys
     11 
     12 from py_utils import camel_case
     13 
     14 
     15 def DiscoverModules(start_dir, top_level_dir, pattern='*'):
     16   """Discover all modules in |start_dir| which match |pattern|.
     17 
     18   Args:
     19     start_dir: The directory to recursively search.
     20     top_level_dir: The top level of the package, for importing.
     21     pattern: Unix shell-style pattern for filtering the filenames to import.
     22 
     23   Returns:
     24     list of modules.
     25   """
     26   # start_dir and top_level_dir must be consistent with each other.
     27   start_dir = os.path.realpath(start_dir)
     28   top_level_dir = os.path.realpath(top_level_dir)
     29 
     30   modules = []
     31   sub_paths = list(os.walk(start_dir))
     32   # We sort the directories & file paths to ensure a deterministic ordering when
     33   # traversing |top_level_dir|.
     34   sub_paths.sort(key=lambda paths_tuple: paths_tuple[0])
     35   for dir_path, _, filenames in sub_paths:
     36     # Sort the directories to walk recursively by the directory path.
     37     filenames.sort()
     38     for filename in filenames:
     39       # Filter out unwanted filenames.
     40       if filename.startswith('.') or filename.startswith('_'):
     41         continue
     42       if os.path.splitext(filename)[1] != '.py':
     43         continue
     44       if not fnmatch.fnmatch(filename, pattern):
     45         continue
     46 
     47       # Find the module.
     48       module_rel_path = os.path.relpath(
     49           os.path.join(dir_path, filename), top_level_dir)
     50       module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])
     51 
     52       # Import the module.
     53       try:
     54         # Make sure that top_level_dir is the first path in the sys.path in case
     55         # there are naming conflict in module parts.
     56         original_sys_path = sys.path[:]
     57         sys.path.insert(0, top_level_dir)
     58         module = importlib.import_module(module_name)
     59         modules.append(module)
     60       finally:
     61         sys.path = original_sys_path
     62   return modules
     63 
     64 
     65 def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
     66   for k in classes_by_key_1:
     67     if k in classes_by_key_2:
     68       assert classes_by_key_1[k] is classes_by_key_2[k], (
     69           'Found conflicting classes for the same key: '
     70           'key=%s, class_1=%s, class_2=%s' % (
     71               k, classes_by_key_1[k], classes_by_key_2[k]))
     72 
     73 
     74 # TODO(dtu): Normalize all discoverable classes to have corresponding module
     75 # and class names, then always index by class name.
     76 def DiscoverClasses(start_dir,
     77                     top_level_dir,
     78                     base_class,
     79                     pattern='*',
     80                     index_by_class_name=True,
     81                     directly_constructable=False):
     82   """Discover all classes in |start_dir| which subclass |base_class|.
     83 
     84   Base classes that contain subclasses are ignored by default.
     85 
     86   Args:
     87     start_dir: The directory to recursively search.
     88     top_level_dir: The top level of the package, for importing.
     89     base_class: The base class to search for.
     90     pattern: Unix shell-style pattern for filtering the filenames to import.
     91     index_by_class_name: If True, use class name converted to
     92         lowercase_with_underscores instead of module name in return dict keys.
     93     directly_constructable: If True, will only return classes that can be
     94         constructed without arguments
     95 
     96   Returns:
     97     dict of {module_name: class} or {underscored_class_name: class}
     98   """
     99   modules = DiscoverModules(start_dir, top_level_dir, pattern)
    100   classes = {}
    101   for module in modules:
    102     new_classes = DiscoverClassesInModule(
    103         module, base_class, index_by_class_name, directly_constructable)
    104     # TODO(nednguyen): we should remove index_by_class_name once
    105     # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
    106     # naming collisions to reduce the number of smoked benchmark tests.
    107     # crbug.com/548652
    108     if index_by_class_name:
    109       AssertNoKeyConflicts(classes, new_classes)
    110     classes = dict(classes.items() + new_classes.items())
    111   return classes
    112 
    113 
    114 # TODO(nednguyen): we should remove index_by_class_name once
    115 # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied
    116 # naming collisions to reduce the number of smoked benchmark tests.
    117 # crbug.com/548652
    118 def DiscoverClassesInModule(module,
    119                             base_class,
    120                             index_by_class_name=False,
    121                             directly_constructable=False):
    122   """Discover all classes in |module| which subclass |base_class|.
    123 
    124   Base classes that contain subclasses are ignored by default.
    125 
    126   Args:
    127     module: The module to search.
    128     base_class: The base class to search for.
    129     index_by_class_name: If True, use class name converted to
    130         lowercase_with_underscores instead of module name in return dict keys.
    131 
    132   Returns:
    133     dict of {module_name: class} or {underscored_class_name: class}
    134   """
    135   classes = {}
    136   for _, obj in inspect.getmembers(module):
    137     # Ensure object is a class.
    138     if not inspect.isclass(obj):
    139       continue
    140     # Include only subclasses of base_class.
    141     if not issubclass(obj, base_class):
    142       continue
    143     # Exclude the base_class itself.
    144     if obj is base_class:
    145       continue
    146     # Exclude protected or private classes.
    147     if obj.__name__.startswith('_'):
    148       continue
    149     # Include only the module in which the class is defined.
    150     # If a class is imported by another module, exclude those duplicates.
    151     if obj.__module__ != module.__name__:
    152       continue
    153 
    154     if index_by_class_name:
    155       key_name = camel_case.ToUnderscore(obj.__name__)
    156     else:
    157       key_name = module.__name__.split('.')[-1]
    158     if not directly_constructable or IsDirectlyConstructable(obj):
    159       if key_name in classes and index_by_class_name:
    160         assert classes[key_name] is obj, (
    161             'Duplicate key_name with different objs detected: '
    162             'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj))
    163       else:
    164         classes[key_name] = obj
    165 
    166   return classes
    167 
    168 
    169 def IsDirectlyConstructable(cls):
    170   """Returns True if instance of |cls| can be construct without arguments."""
    171   assert inspect.isclass(cls)
    172   if not hasattr(cls, '__init__'):
    173     # Case |class A: pass|.
    174     return True
    175   if cls.__init__ is object.__init__:
    176     # Case |class A(object): pass|.
    177     return True
    178   # Case |class (object):| with |__init__| other than |object.__init__|.
    179   args, _, _, defaults = inspect.getargspec(cls.__init__)
    180   if defaults is None:
    181     defaults = ()
    182   # Return true if |self| is only arg without a default.
    183   return len(args) == len(defaults) + 1
    184 
    185 
    186 _COUNTER = [0]
    187 
    188 
    189 def _GetUniqueModuleName():
    190   _COUNTER[0] += 1
    191   return "module_" + str(_COUNTER[0])
    192