Home | History | Annotate | Download | only in tools
      1 #!/usr/bin/env python
      2 
      3 import os
      4 import re
      5 import sys
      6 
      7 def fail_with_usage():
      8   sys.stderr.write("usage: java-layers.py DEPENDENCY_FILE SOURCE_DIRECTORIES...\n")
      9   sys.stderr.write("\n")
     10   sys.stderr.write("Enforces layering between java packages.  Scans\n")
     11   sys.stderr.write("DIRECTORY and prints errors when the packages violate\n")
     12   sys.stderr.write("the rules defined in the DEPENDENCY_FILE.\n")
     13   sys.stderr.write("\n")
     14   sys.stderr.write("Prints a warning when an unknown package is encountered\n")
     15   sys.stderr.write("on the assumption that it should fit somewhere into the\n")
     16   sys.stderr.write("layering.\n")
     17   sys.stderr.write("\n")
     18   sys.stderr.write("DEPENDENCY_FILE format\n")
     19   sys.stderr.write("  - # starts comment\n")
     20   sys.stderr.write("  - Lines consisting of two java package names:  The\n")
     21   sys.stderr.write("    first package listed must not contain any references\n")
     22   sys.stderr.write("    to any classes present in the second package, or any\n")
     23   sys.stderr.write("    of its dependencies.\n")
     24   sys.stderr.write("  - Lines consisting of one java package name:  The\n")
     25   sys.stderr.write("    packge is assumed to be a high level package and\n")
     26   sys.stderr.write("    nothing may depend on it.\n")
     27   sys.stderr.write("  - Lines consisting of a dash (+) followed by one java\n")
     28   sys.stderr.write("    package name: The package is considered a low level\n")
     29   sys.stderr.write("    package and may not import any of the other packages\n")
     30   sys.stderr.write("    listed in the dependency file.\n")
     31   sys.stderr.write("  - Lines consisting of a plus (-) followed by one java\n")
     32   sys.stderr.write("    package name: The package is considered \'legacy\'\n")
     33   sys.stderr.write("    and excluded from errors.\n")
     34   sys.stderr.write("\n")
     35   sys.exit(1)
     36 
     37 class Dependency:
     38   def __init__(self, filename, lineno, lower, top, lowlevel, legacy):
     39     self.filename = filename
     40     self.lineno = lineno
     41     self.lower = lower
     42     self.top = top
     43     self.lowlevel = lowlevel
     44     self.legacy = legacy
     45     self.uppers = []
     46     self.transitive = set()
     47 
     48   def matches(self, imp):
     49     for d in self.transitive:
     50       if imp.startswith(d):
     51         return True
     52     return False
     53 
     54 class Dependencies:
     55   def __init__(self, deps):
     56     def recurse(obj, dep, visited):
     57       global err
     58       if dep in visited:
     59         sys.stderr.write("%s:%d: Circular dependency found:\n"
     60             % (dep.filename, dep.lineno))
     61         for v in visited:
     62           sys.stderr.write("%s:%d:    Dependency: %s\n"
     63               % (v.filename, v.lineno, v.lower))
     64         err = True
     65         return
     66       visited.append(dep)
     67       for upper in dep.uppers:
     68         obj.transitive.add(upper)
     69         if upper in deps:
     70           recurse(obj, deps[upper], visited)
     71     self.deps = deps
     72     self.parts = [(dep.lower.split('.'),dep) for dep in deps.itervalues()]
     73     # transitive closure of dependencies
     74     for dep in deps.itervalues():
     75       recurse(dep, dep, [])
     76     # disallow everything from the low level components
     77     for dep in deps.itervalues():
     78       if dep.lowlevel:
     79         for d in deps.itervalues():
     80           if dep != d and not d.legacy:
     81             dep.transitive.add(d.lower)
     82     # disallow the 'top' components everywhere but in their own package
     83     for dep in deps.itervalues():
     84       if dep.top and not dep.legacy:
     85         for d in deps.itervalues():
     86           if dep != d and not d.legacy:
     87             d.transitive.add(dep.lower)
     88     for dep in deps.itervalues():
     89       dep.transitive = set([x+"." for x in dep.transitive])
     90     if False:
     91       for dep in deps.itervalues():
     92         print "-->", dep.lower, "-->", dep.transitive
     93 
     94   # Lookup the dep object for the given package.  If pkg is a subpackage
     95   # of one with a rule, that one will be returned.  If no matches are found,
     96   # None is returned.
     97   def lookup(self, pkg):
     98     # Returns the number of parts that match
     99     def compare_parts(parts, pkg):
    100       if len(parts) > len(pkg):
    101         return 0
    102       n = 0
    103       for i in range(0, len(parts)):
    104         if parts[i] != pkg[i]:
    105           return 0
    106         n = n + 1
    107       return n
    108     pkg = pkg.split(".")
    109     matched = 0
    110     result = None
    111     for (parts,dep) in self.parts:
    112       x = compare_parts(parts, pkg)
    113       if x > matched:
    114         matched = x
    115         result = dep
    116     return result
    117 
    118 def parse_dependency_file(filename):
    119   global err
    120   f = file(filename)
    121   lines = f.readlines()
    122   f.close()
    123   def lineno(s, i):
    124     i[0] = i[0] + 1
    125     return (i[0],s)
    126   n = [0]
    127   lines = [lineno(x,n) for x in lines]
    128   lines = [(n,s.split("#")[0].strip()) for (n,s) in lines]
    129   lines = [(n,s) for (n,s) in lines if len(s) > 0]
    130   lines = [(n,s.split()) for (n,s) in lines]
    131   deps = {}
    132   for n,words in lines:
    133     if len(words) == 1:
    134       lower = words[0]
    135       top = True
    136       legacy = False
    137       lowlevel = False
    138       if lower[0] == '+':
    139         lower = lower[1:]
    140         top = False
    141         lowlevel = True
    142       elif lower[0] == '-':
    143         lower = lower[1:]
    144         legacy = True
    145       if lower in deps:
    146         sys.stderr.write(("%s:%d: Package '%s' already defined on"
    147             + " line %d.\n") % (filename, n, lower, deps[lower].lineno))
    148         err = True
    149       else:
    150         deps[lower] = Dependency(filename, n, lower, top, lowlevel, legacy)
    151     elif len(words) == 2:
    152       lower = words[0]
    153       upper = words[1]
    154       if lower in deps:
    155         dep = deps[lower]
    156         if dep.top:
    157           sys.stderr.write(("%s:%d: Can't add dependency to top level package "
    158             + "'%s'\n") % (filename, n, lower))
    159           err = True
    160       else:
    161         dep = Dependency(filename, n, lower, False, False, False)
    162         deps[lower] = dep
    163       dep.uppers.append(upper)
    164     else:
    165       sys.stderr.write("%s:%d: Too many words on line starting at \'%s\'\n" % (
    166           filename, n, words[2]))
    167       err = True
    168   return Dependencies(deps)
    169 
    170 def find_java_files(srcs):
    171   result = []
    172   for d in srcs:
    173     if d[0] == '@':
    174       f = file(d[1:])
    175       result.extend([fn for fn in [s.strip() for s in f.readlines()]
    176           if len(fn) != 0])
    177       f.close()
    178     else:
    179       for root, dirs, files in os.walk(d):
    180         result.extend([os.sep.join((root,f)) for f in files
    181             if f.lower().endswith(".java")])
    182   return result
    183 
    184 COMMENTS = re.compile("//.*?\n|/\*.*?\*/", re.S)
    185 PACKAGE = re.compile("package\s+(.*)")
    186 IMPORT = re.compile("import\s+(.*)")
    187 
    188 def examine_java_file(deps, filename):
    189   global err
    190   # Yes, this is a crappy java parser.  Write a better one if you want to.
    191   f = file(filename)
    192   text = f.read()
    193   f.close()
    194   text = COMMENTS.sub("", text)
    195   index = text.find("{")
    196   if index < 0:
    197     sys.stderr.write(("%s: Error: Unable to parse java. Can't find class "
    198         + "declaration.\n") % filename)
    199     err = True
    200     return
    201   text = text[0:index]
    202   statements = [s.strip() for s in text.split(";")]
    203   # First comes the package declaration.  Then iterate while we see import
    204   # statements.  Anything else is either bad syntax that we don't care about
    205   # because the compiler will fail, or the beginning of the class declaration.
    206   m = PACKAGE.match(statements[0])
    207   if not m:
    208     sys.stderr.write(("%s: Error: Unable to parse java. Missing package "
    209         + "statement.\n") % filename)
    210     err = True
    211     return
    212   pkg = m.group(1)
    213   imports = []
    214   for statement in statements[1:]:
    215     m = IMPORT.match(statement)
    216     if not m:
    217       break
    218     imports.append(m.group(1))
    219   # Do the checking
    220   if False:
    221     print filename
    222     print "'%s' --> %s" % (pkg, imports)
    223   dep = deps.lookup(pkg)
    224   if not dep:
    225     sys.stderr.write(("%s: Error: Package does not appear in dependency file: "
    226       + "%s\n") % (filename, pkg))
    227     err = True
    228     return
    229   for imp in imports:
    230     if dep.matches(imp):
    231       sys.stderr.write("%s: Illegal import in package '%s' of '%s'\n"
    232           % (filename, pkg, imp))
    233       err = True
    234 
# Module-wide failure flag.  The parsing and checking functions set it to
# True (through "global err") after reporting a problem; main() exits with
# status 1 when it is set.
err = False
    236 
    237 def main(argv):
    238   if len(argv) < 3:
    239     fail_with_usage()
    240   deps = parse_dependency_file(argv[1])
    241 
    242   if err:
    243     sys.exit(1)
    244 
    245   java = find_java_files(argv[2:])
    246   for filename in java:
    247     examine_java_file(deps, filename)
    248 
    249   if err:
    250     sys.stderr.write("%s: Using this file as dependency file.\n" % argv[1])
    251     sys.exit(1)
    252 
    253   sys.exit(0)
    254 
# Run the checker only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
  main(sys.argv)
    257 
    258