# lib2to3/refactor.py
      1 # Copyright 2006 Google, Inc. All Rights Reserved.
      2 # Licensed to PSF under a Contributor Agreement.
      3 
      4 """Refactoring framework.
      5 
      6 Used as a main program, this can refactor any number of files and/or
      7 recursively descend down directories.  Imported as a module, this
      8 provides infrastructure to write your own refactoring tool.
      9 """
     10 
     11 from __future__ import with_statement
     12 
     13 __author__ = "Guido van Rossum <guido (at] python.org>"
     14 
     15 
     16 # Python imports
     17 import os
     18 import sys
     19 import logging
     20 import operator
     21 import collections
     22 import StringIO
     23 from itertools import chain
     24 
     25 # Local imports
     26 from .pgen2 import driver, tokenize, token
     27 from .fixer_util import find_root
     28 from . import pytree, pygram
     29 from . import btm_utils as bu
     30 from . import btm_matcher as bm
     31 
     32 
     33 def get_all_fix_names(fixer_pkg, remove_prefix=True):
     34     """Return a sorted list of all available fix names in the given package."""
     35     pkg = __import__(fixer_pkg, [], [], ["*"])
     36     fixer_dir = os.path.dirname(pkg.__file__)
     37     fix_names = []
     38     for name in sorted(os.listdir(fixer_dir)):
     39         if name.startswith("fix_") and name.endswith(".py"):
     40             if remove_prefix:
     41                 name = name[4:]
     42             fix_names.append(name[:-3])
     43     return fix_names
     44 
     45 
     46 class _EveryNode(Exception):
     47     pass
     48 
     49 
     50 def _get_head_types(pat):
     51     """ Accepts a pytree Pattern Node and returns a set
     52         of the pattern types which will match first. """
     53 
     54     if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
     55         # NodePatters must either have no type and no content
     56         #   or a type and content -- so they don't get any farther
     57         # Always return leafs
     58         if pat.type is None:
     59             raise _EveryNode
     60         return set([pat.type])
     61 
     62     if isinstance(pat, pytree.NegatedPattern):
     63         if pat.content:
     64             return _get_head_types(pat.content)
     65         raise _EveryNode # Negated Patterns don't have a type
     66 
     67     if isinstance(pat, pytree.WildcardPattern):
     68         # Recurse on each node in content
     69         r = set()
     70         for p in pat.content:
     71             for x in p:
     72                 r.update(_get_head_types(x))
     73         return r
     74 
     75     raise Exception("Oh no! I don't understand pattern %s" %(pat))
     76 
     77 
     78 def _get_headnode_dict(fixer_list):
     79     """ Accepts a list of fixers and returns a dictionary
     80         of head node type --> fixer list.  """
     81     head_nodes = collections.defaultdict(list)
     82     every = []
     83     for fixer in fixer_list:
     84         if fixer.pattern:
     85             try:
     86                 heads = _get_head_types(fixer.pattern)
     87             except _EveryNode:
     88                 every.append(fixer)
     89             else:
     90                 for node_type in heads:
     91                     head_nodes[node_type].append(fixer)
     92         else:
     93             if fixer._accept_type is not None:
     94                 head_nodes[fixer._accept_type].append(fixer)
     95             else:
     96                 every.append(fixer)
     97     for node_type in chain(pygram.python_grammar.symbol2number.itervalues(),
     98                            pygram.python_grammar.tokens):
     99         head_nodes[node_type].extend(every)
    100     return dict(head_nodes)
    101 
    102 
    103 def get_fixers_from_package(pkg_name):
    104     """
    105     Return the fully qualified names for fixers in the package pkg_name.
    106     """
    107     return [pkg_name + "." + fix_name
    108             for fix_name in get_all_fix_names(pkg_name, False)]
    109 
    110 def _identity(obj):
    111     return obj
    112 
    113 if sys.version_info < (3, 0):
    114     import codecs
    115     _open_with_encoding = codecs.open
    116     # codecs.open doesn't translate newlines sadly.
    117     def _from_system_newlines(input):
    118         return input.replace(u"\r\n", u"\n")
    119     def _to_system_newlines(input):
    120         if os.linesep != "\n":
    121             return input.replace(u"\n", os.linesep)
    122         else:
    123             return input
    124 else:
    125     _open_with_encoding = open
    126     _from_system_newlines = _identity
    127     _to_system_newlines = _identity
    128 
    129 
def _detect_future_features(source):
    """Return the frozenset of feature names imported from __future__
    at the top of *source* (e.g. "print_function").

    Scanning stops at the first token that cannot be part of the
    module prologue (docstring plus __future__ imports).
    """
    have_docstring = False
    gen = tokenize.generate_tokens(StringIO.StringIO(source).readline)
    def advance():
        # Return (type, value) of the next token.
        tok = gen.next()
        return tok[0], tok[1]
    # Token types that never terminate the prologue.
    ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                # At most one docstring may precede the imports.
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == u"from":
                tp, value = advance()
                if tp != token.NAME or value != u"__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != u"import":
                    break
                tp, value = advance()
                # Optional parenthesized import list.
                if tp == token.OP and value == u"(":
                    tp, value = advance()
                # Collect comma-separated feature names.
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != u",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        # Ran off the end of the token stream; return what we found.
        pass
    return frozenset(features)
    168 
    169 
class FixerError(Exception):
    """A fixer could not be loaded.

    Raised by RefactoringTool.get_fixers() when a fixer module lacks the
    expected Fix* class or declares an illegal traversal order.
    """
    172 
    173 
class RefactoringTool(object):
    """Drives refactoring: loads fixers, parses sources, applies matching
    fixers to the parse tree, and optionally writes results back.

    Subclasses may override the log_*() / print_output() hooks to
    customize reporting.
    """

    # Defaults merged with the caller-supplied options dict in __init__.
    _default_options = {"print_function" : False,
                        "write_unchanged_files" : False}

    CLASS_PREFIX = "Fix" # The prefix for fixer classes
    FILE_PREFIX = "fix_" # The prefix for modules with a fixer within

    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are explicit.
        """
        self.fixers = fixer_names
        self.explicit = explicit or []
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)
        if self.options["print_function"]:
            self.grammar = pygram.python_grammar_no_print_statement
        else:
            self.grammar = pygram.python_grammar
        # When this is True, the refactor*() methods will call write_file() for
        # files processed even if they were not changed during refactoring. If
        # and only if the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()


        self.files = []  # List of files that were or should be modified

        # BM-compatible fixers are handled by the bottom-up matcher; the
        # rest ("bmi" = bottom-matcher incompatible) use the traditional
        # per-node traversal in refactor_tree().
        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = [] # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
                # remove fixers that will be handled by the bottom-up
                # matcher
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)



    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            parts = fix_name.split("_")
            # Derive the class name from the module name, e.g.
            # "fix_foo_bar" -> "FixFooBar".
            class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" % (fix_name, class_name))
            fixer = fix_class(self.options, self.fixer_log)
            if fixer.explicit and self.explicit is not True and \
                    fix_mod_path not in self.explicit:
                self.log_message("Skipping implicit fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        key_func = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=key_func)
        post_order_fixers.sort(key=key_func)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Called when an error occurs."""
        # The base implementation re-raises the exception currently being
        # handled; subclasses may record it instead (see summarize()).
        raise

    def log_message(self, msg, *args):
        """Hook to log a message."""
        if args:
            msg = msg % args
        self.logger.info(msg)

    def log_debug(self, msg, *args):
        """Hook to log a debug-level message."""
        if args:
            msg = msg % args
        self.logger.debug(msg)

    def print_output(self, old_text, new_text, filename, equal):
        """Called with the old version, new version, and filename of a
        refactored file."""
        pass

    def refactor(self, items, write=False, doctests_only=False):
        """Refactor a list of files and directories."""

        for dir_or_file in items:
            if os.path.isdir(dir_or_file):
                self.refactor_dir(dir_or_file, write, doctests_only)
            else:
                self.refactor_file(dir_or_file, write, doctests_only)

    def refactor_dir(self, dir_name, write=False, doctests_only=False):
        """Descends down a directory and refactor every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        py_ext = os.extsep + "py"
        for dirpath, dirnames, filenames in os.walk(dir_name):
            self.log_debug("Descending into %s", dirpath)
            # Sort for a deterministic traversal order.
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                if (not name.startswith(".") and
                    os.path.splitext(name)[1] == py_ext):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname, write, doctests_only)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]

    def _read_python_source(self, filename):
        """
        Do our best to decode a Python source file correctly.

        Returns (source_text, encoding), or (None, None) when the file
        cannot be opened.
        """
        try:
            f = open(filename, "rb")
        except IOError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return None, None
        try:
            # Sniff the PEP 263 coding cookie / BOM from the raw bytes.
            encoding = tokenize.detect_encoding(f.readline)[0]
        finally:
            f.close()
        with _open_with_encoding(filename, "r", encoding=encoding) as f:
            return _from_system_newlines(f.read()), encoding

    def refactor_file(self, filename, write=False, doctests_only=False):
        """Refactors a file."""
        input, encoding = self._read_python_source(filename)
        if input is None:
            # Reading the file failed.
            return
        input += u"\n" # Silence certain parse errors
        if doctests_only:
            self.log_debug("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(input, filename)
            if self.write_unchanged_files or output != input:
                self.processed_file(output, filename, input, write, encoding)
            else:
                self.log_debug("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(input, filename)
            if self.write_unchanged_files or (tree and tree.was_changed):
                # The [:-1] is to take off the \n we added earlier
                self.processed_file(unicode(tree)[:-1], filename,
                                    write=write, encoding=encoding)
            else:
                self.log_debug("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        features = _detect_future_features(data)
        if "print_function" in features:
            # The source opted in to print-as-function; parse accordingly.
            self.driver.grammar = pygram.python_grammar_no_print_statement
        try:
            tree = self.driver.parse_string(data)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        finally:
            # Always restore the default grammar for the next input.
            self.driver.grammar = self.grammar
        tree.future_features = features
        self.log_debug("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree

    def refactor_stdin(self, doctests_only=False):
        """Refactors source read from standard input (results are only
        reported through processed_file(), never written back)."""
        input = sys.stdin.read()
        if doctests_only:
            self.log_debug("Refactoring doctests in stdin")
            output = self.refactor_docstring(input, "<stdin>")
            if self.write_unchanged_files or output != input:
                self.processed_file(output, "<stdin>", input)
            else:
                self.log_debug("No doctest changes in stdin")
        else:
            tree = self.refactor_string(input, "<stdin>")
            if self.write_unchanged_files or (tree and tree.was_changed):
                self.processed_file(unicode(tree), "<stdin>", input)
            else:
                self.log_debug("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        For compatible patterns the bottom matcher module is
        used. Otherwise the tree is traversed node-to-node for
        matches.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                  to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """

        for fixer in chain(self.pre_order, self.post_order):
            fixer.start_tree(tree, name)

        #use traditional matching for the incompatible fixers
        self.traverse_by(self.bmi_pre_order_heads, tree.pre_order())
        self.traverse_by(self.bmi_post_order_heads, tree.post_order())

        # obtain a set of candidate nodes
        match_set = self.BM.run(tree.leaves())

        # Keep going until no candidate nodes remain: transforms may add
        # new matches for the code they introduce (see below).
        while any(match_set.values()):
            for fixer in self.BM.fixers:
                if fixer in match_set and match_set[fixer]:
                    #sort by depth; apply fixers from bottom(of the AST) to top
                    match_set[fixer].sort(key=pytree.Base.depth, reverse=True)

                    if fixer.keep_line_order:
                        #some fixers(eg fix_imports) must be applied
                        #with the original file's line order
                        match_set[fixer].sort(key=pytree.Base.get_lineno)

                    # Iterate over a copy: the set is mutated as we go.
                    for node in list(match_set[fixer]):
                        if node in match_set[fixer]:
                            match_set[fixer].remove(node)

                        try:
                            find_root(node)
                        except ValueError:
                            # this node has been cut off from a
                            # previous transformation ; skip
                            continue

                        if node.fixers_applied and fixer in node.fixers_applied:
                            # do not apply the same fixer again
                            continue

                        results = fixer.match(node)

                        if results:
                            new = fixer.transform(node, results)
                            if new is not None:
                                node.replace(new)
                                #new.fixers_applied.append(fixer)
                                for node in new.post_order():
                                    # do not apply the fixer again to
                                    # this or any subnode
                                    if not node.fixers_applied:
                                        node.fixers_applied = []
                                    node.fixers_applied.append(fixer)

                                # update the original match set for
                                # the added code
                                new_matches = self.BM.run(new.leaves())
                                for fxr in new_matches:
                                    if not fxr in match_set:
                                        match_set[fxr]=[]

                                    match_set[fxr].extend(new_matches[fxr])

        for fixer in chain(self.pre_order, self.post_order):
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a list of fixer instances.
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None:
                        node.replace(new)
                        # Let subsequent fixers match the replacement.
                        node = new

    def processed_file(self, new_text, filename, old_text=None, write=False,
                       encoding=None):
        """
        Called when a file has been refactored and there may be changes.
        """
        self.files.append(filename)
        if old_text is None:
            # Re-read the original so we can diff/compare against it.
            old_text = self._read_python_source(filename)[0]
            if old_text is None:
                return
        equal = old_text == new_text
        self.print_output(old_text, new_text, filename, equal)
        if equal:
            self.log_debug("No changes to %s", filename)
            if not self.write_unchanged_files:
                return
        if write:
            self.write_file(new_text, filename, old_text, encoding)
        else:
            self.log_debug("Not writing changes to %s", filename)

    def write_file(self, new_text, filename, old_text, encoding=None):
        """Writes a string to a file.

        It first shows a unified diff between the old text and the new text, and
        then rewrites the file; the latter is only done if the write option is
        set.
        """
        try:
            f = _open_with_encoding(filename, "w", encoding=encoding)
        except os.error as err:
            self.log_error("Can't create %s: %s", filename, err)
            return
        try:
            f.write(_to_system_newlines(new_text))
        except os.error as err:
            self.log_error("Can't write %s: %s", filename, err)
        finally:
            f.close()
        self.log_debug("Wrote changes to %s", filename)
        self.wrote = True

    # Doctest prompt strings recognized by refactor_docstring().
    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None
        block_lineno = None
        indent = None
        lineno = 0
        for line in input.splitlines(True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                # Start of a new doctest; flush any block in progress.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + u"\n")):
                # Continuation line of the current doctest.
                block.append(line)
            else:
                # Ordinary text; flush any block in progress.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return u"".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip(u"\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = unicode(tree).splitlines(True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == [u"\n"] * (lineno-1), clipped
            if not new[-1].endswith(u"\n"):
                new[-1] += u"\n"
            # Re-attach the prompts that gen_lines() stripped off.
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        """Logs a summary of files modified, warnings, and errors."""
        if self.wrote:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                self.log_message(msg, *args, **kwds)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        tree = self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
        tree.future_features = frozenset()
        return tree

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
        for type, value, (line0, col0), (line1, col1), line_text in tokens:
            # Shift token line numbers to the doctest's position in the file.
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield type, value, (line0, col0), (line1, col1), line_text


    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + u"\n":
                yield u"\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            # After the first line, only PS2 continuations are expected.
            prefix = prefix2
        while True:
            # tokenize wants an endless line stream; pad with empties.
            yield ""
    689 
    690 
    691 class MultiprocessingUnsupported(Exception):
    692     pass
    693 
    694 
class MultiprocessRefactoringTool(RefactoringTool):
    """RefactoringTool that can farm refactor_file() calls out to a pool
    of worker processes via a JoinableQueue."""

    def __init__(self, *args, **kwargs):
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
        # Both stay None unless a multi-process refactor() is in progress.
        self.queue = None
        self.output_lock = None

    def refactor(self, items, write=False, doctests_only=False,
                 num_processes=1):
        """Refactor *items*, optionally spreading the work over
        *num_processes* worker processes."""
        if num_processes == 1:
            # Single-process mode: plain inherited behavior.
            return super(MultiprocessRefactoringTool, self).refactor(
                items, write, doctests_only)
        try:
            import multiprocessing
        except ImportError:
            raise MultiprocessingUnsupported
        if self.queue is not None:
            raise RuntimeError("already doing multiple processes")
        self.queue = multiprocessing.JoinableQueue()
        self.output_lock = multiprocessing.Lock()
        processes = [multiprocessing.Process(target=self._child)
                     for i in xrange(num_processes)]
        try:
            for p in processes:
                p.start()
            # With self.queue set, the overridden refactor_file() below
            # enqueues tasks instead of refactoring inline.
            super(MultiprocessRefactoringTool, self).refactor(items, write,
                                                              doctests_only)
        finally:
            self.queue.join()
            # One None sentinel per worker tells it to exit.
            for i in xrange(num_processes):
                self.queue.put(None)
            for p in processes:
                if p.is_alive():
                    p.join()
            self.queue = None

    def _child(self):
        # Worker loop: process queued (args, kwargs) tasks until a None
        # sentinel arrives.
        task = self.queue.get()
        while task is not None:
            args, kwargs = task
            try:
                super(MultiprocessRefactoringTool, self).refactor_file(
                    *args, **kwargs)
            finally:
                # Mark done even on failure so queue.join() can't hang.
                self.queue.task_done()
            task = self.queue.get()

    def refactor_file(self, *args, **kwargs):
        """Enqueue the file during a multi-process run; otherwise
        refactor it inline via the base class."""
        if self.queue is not None:
            self.queue.put((args, kwargs))
        else:
            return super(MultiprocessRefactoringTool, self).refactor_file(
                *args, **kwargs)
    748