Home | History | Annotate | Download | only in utils
      1 # Copyright 2013 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """Thread and ThreadGroup that reraise exceptions on the main thread."""
      6 # pylint: disable=W0212
      7 
      8 import logging
      9 import sys
     10 import threading
     11 import time
     12 import traceback
     13 
     14 from devil.utils import watchdog_timer
     15 
     16 
     17 class TimeoutError(Exception):
     18   """Module-specific timeout exception."""
     19   pass
     20 
     21 
     22 def LogThreadStack(thread, error_log_func=logging.critical):
     23   """Log the stack for the given thread.
     24 
     25   Args:
     26     thread: a threading.Thread instance.
     27     error_log_func: Logging function when logging errors.
     28   """
     29   stack = sys._current_frames()[thread.ident]
     30   error_log_func('*' * 80)
     31   error_log_func('Stack dump for thread %r', thread.name)
     32   error_log_func('*' * 80)
     33   for filename, lineno, name, line in traceback.extract_stack(stack):
     34     error_log_func('File: "%s", line %d, in %s', filename, lineno, name)
     35     if line:
     36       error_log_func('  %s', line.strip())
     37   error_log_func('*' * 80)
     38 
     39 
     40 class ReraiserThread(threading.Thread):
     41   """Thread class that can reraise exceptions."""
     42 
     43   def __init__(self, func, args=None, kwargs=None, name=None):
     44     """Initialize thread.
     45 
     46     Args:
     47       func: callable to call on a new thread.
     48       args: list of positional arguments for callable, defaults to empty.
     49       kwargs: dictionary of keyword arguments for callable, defaults to empty.
     50       name: thread name, defaults to Thread-N.
     51     """
     52     if not name and func.__name__ != '<lambda>':
     53       name = func.__name__
     54     super(ReraiserThread, self).__init__(name=name)
     55     if not args:
     56       args = []
     57     if not kwargs:
     58       kwargs = {}
     59     self.daemon = True
     60     self._func = func
     61     self._args = args
     62     self._kwargs = kwargs
     63     self._ret = None
     64     self._exc_info = None
     65     self._thread_group = None
     66 
     67   def ReraiseIfException(self):
     68     """Reraise exception if an exception was raised in the thread."""
     69     if self._exc_info:
     70       raise self._exc_info[0], self._exc_info[1], self._exc_info[2]
     71 
     72   def GetReturnValue(self):
     73     """Reraise exception if present, otherwise get the return value."""
     74     self.ReraiseIfException()
     75     return self._ret
     76 
     77   # override
     78   def run(self):
     79     """Overrides Thread.run() to add support for reraising exceptions."""
     80     try:
     81       self._ret = self._func(*self._args, **self._kwargs)
     82     except:  # pylint: disable=W0702
     83       self._exc_info = sys.exc_info()
     84 
     85 
     86 class ReraiserThreadGroup(object):
     87   """A group of ReraiserThread objects."""
     88 
     89   def __init__(self, threads=None):
     90     """Initialize thread group.
     91 
     92     Args:
     93       threads: a list of ReraiserThread objects; defaults to empty.
     94     """
     95     self._threads = []
     96     # Set when a thread from one group has called JoinAll on another. It is used
     97     # to detect when a there is a TimeoutRetryThread active that links to the
     98     # current thread.
     99     self.blocked_parent_thread_group = None
    100     if threads:
    101       for thread in threads:
    102         self.Add(thread)
    103 
    104   def Add(self, thread):
    105     """Add a thread to the group.
    106 
    107     Args:
    108       thread: a ReraiserThread object.
    109     """
    110     assert thread._thread_group is None
    111     thread._thread_group = self
    112     self._threads.append(thread)
    113 
    114   def StartAll(self, will_block=False):
    115     """Start all threads.
    116 
    117     Args:
    118       will_block: Whether the calling thread will subsequently block on this
    119         thread group. Causes the active ReraiserThreadGroup (if there is one)
    120         to be marked as blocking on this thread group.
    121     """
    122     if will_block:
    123       # Multiple threads blocking on the same outer thread should not happen in
    124       # practice.
    125       assert not self.blocked_parent_thread_group
    126       self.blocked_parent_thread_group = CurrentThreadGroup()
    127     for thread in self._threads:
    128       thread.start()
    129 
    130   def _JoinAll(self, watcher=None, timeout=None):
    131     """Join all threads without stack dumps.
    132 
    133     Reraises exceptions raised by the child threads and supports breaking
    134     immediately on exceptions raised on the main thread.
    135 
    136     Args:
    137       watcher: Watchdog object providing the thread timeout. If none is
    138           provided, the thread will never be timed out.
    139       timeout: An optional number of seconds to wait before timing out the join
    140           operation. This will not time out the threads.
    141     """
    142     if watcher is None:
    143       watcher = watchdog_timer.WatchdogTimer(None)
    144     alive_threads = self._threads[:]
    145     end_time = (time.time() + timeout) if timeout else None
    146     try:
    147       while alive_threads and (end_time is None or end_time > time.time()):
    148         for thread in alive_threads[:]:
    149           if watcher.IsTimedOut():
    150             raise TimeoutError('Timed out waiting for %d of %d threads.' %
    151                                (len(alive_threads), len(self._threads)))
    152           # Allow the main thread to periodically check for interrupts.
    153           thread.join(0.1)
    154           if not thread.isAlive():
    155             alive_threads.remove(thread)
    156       # All threads are allowed to complete before reraising exceptions.
    157       for thread in self._threads:
    158         thread.ReraiseIfException()
    159     finally:
    160       self.blocked_parent_thread_group = None
    161 
    162   def IsAlive(self):
    163     """Check whether any of the threads are still alive.
    164 
    165     Returns:
    166       Whether any of the threads are still alive.
    167     """
    168     return any(t.isAlive() for t in self._threads)
    169 
    170   def JoinAll(self, watcher=None, timeout=None,
    171               error_log_func=logging.critical):
    172     """Join all threads.
    173 
    174     Reraises exceptions raised by the child threads and supports breaking
    175     immediately on exceptions raised on the main thread. Unfinished threads'
    176     stacks will be logged on watchdog timeout.
    177 
    178     Args:
    179       watcher: Watchdog object providing the thread timeout. If none is
    180           provided, the thread will never be timed out.
    181       timeout: An optional number of seconds to wait before timing out the join
    182           operation. This will not time out the threads.
    183       error_log_func: Logging function when logging errors.
    184     """
    185     try:
    186       self._JoinAll(watcher, timeout)
    187     except TimeoutError:
    188       error_log_func('Timed out. Dumping threads.')
    189       for thread in (t for t in self._threads if t.isAlive()):
    190         LogThreadStack(thread, error_log_func=error_log_func)
    191       raise
    192 
    193   def GetAllReturnValues(self, watcher=None):
    194     """Get all return values, joining all threads if necessary.
    195 
    196     Args:
    197       watcher: same as in |JoinAll|. Only used if threads are alive.
    198     """
    199     if any([t.isAlive() for t in self._threads]):
    200       self.JoinAll(watcher)
    201     return [t.GetReturnValue() for t in self._threads]
    202 
    203 
    204 def CurrentThreadGroup():
    205   """Returns the ReraiserThreadGroup that owns the running thread.
    206 
    207   Returns:
    208     The current thread group, otherwise None.
    209   """
    210   current_thread = threading.current_thread()
    211   if isinstance(current_thread, ReraiserThread):
    212     return current_thread._thread_group  # pylint: disable=no-member
    213   return None
    214 
    215 
    216 def RunAsync(funcs, watcher=None):
    217   """Executes the given functions in parallel and returns their results.
    218 
    219   Args:
    220     funcs: List of functions to perform on their own threads.
    221     watcher: Watchdog object providing timeout, by default waits forever.
    222 
    223   Returns:
    224     A list of return values in the order of the given functions.
    225   """
    226   thread_group = ReraiserThreadGroup(ReraiserThread(f) for f in funcs)
    227   thread_group.StartAll(will_block=True)
    228   return thread_group.GetAllReturnValues(watcher=watcher)
    229