#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#

__all__ = ['Pool', 'ThreadPool']

#
# Imports
#

import threading
import queue
import itertools
import collections
import os
import time
import traceback

# If threading is available then ThreadPool should be provided.  Therefore
# we avoid top-level imports which are liable to fail on some systems.
from . import util
from . import get_context, TimeoutError

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return list(map(*args))

def starmapstar(args):
    return list(itertools.starmap(args[0], args[1]))

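# A hedged, illustrative sketch (not part of the original module): each map
# task the pool queues is a (func, chunk) pair, and mapstar/starmapstar
# apply func across the whole chunk in a single worker call.
def _demo_map_helpers():
    assert mapstar((abs, (-1, -2, 3))) == [1, 2, 3]
    assert starmapstar((pow, [(2, 3), (2, 4)])) == [8, 16]
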
#
# Hack to embed stringification of remote traceback in local traceback
#

class RemoteTraceback(Exception):
    def __init__(self, tb):
        self.tb = tb
    def __str__(self):
        return self.tb

class ExceptionWithTraceback:
    def __init__(self, exc, tb):
        tb = traceback.format_exception(type(exc), exc, tb)
        tb = ''.join(tb)
        self.exc = exc
        self.tb = '\n"""\n%s"""' % tb
    def __reduce__(self):
        return rebuild_exc, (self.exc, self.tb)

def rebuild_exc(exc, tb):
    exc.__cause__ = RemoteTraceback(tb)
    return exc

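# Hedged sketch: round-tripping an ExceptionWithTraceback through pickle
# (as the result queue does) rebuilds the original exception with the
# remote traceback text attached as its __cause__.
def _demo_remote_traceback():
    import pickle
    try:
        1 / 0
    except ZeroDivisionError as e:
        wrapped = ExceptionWithTraceback(e, e.__traceback__)
    exc = pickle.loads(pickle.dumps(wrapped))
    assert isinstance(exc, ZeroDivisionError)
    assert isinstance(exc.__cause__, RemoteTraceback)
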
#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possibly unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)


def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
           wrap_exception=False):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, OSError):
            util.debug('worker got EOFError or OSError -- exiting')
            break

        if task is None:
            util.debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            if wrap_exception:
                e = ExceptionWithTraceback(e, e.__traceback__)
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            util.debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))
        completed += 1
    util.debug('worker exiting after %d tasks' % completed)

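# A minimal sketch of the worker protocol, under the assumption that plain
# queue.Queue objects can stand in for the pool's pipes: tasks arrive as
# (job, i, func, args, kwds) tuples, results leave as (job, i, (ok, value)),
# and a None sentinel shuts the worker down.  Not part of the original file.
def _demo_worker_protocol():
    inq, outq = queue.Queue(), queue.Queue()
    inq.put((0, 0, divmod, (7, 3), {}))
    inq.put(None)                      # sentinel -- stop after one task
    worker(inq, outq)
    job, i, (ok, value) = outq.get()
    assert ok and value == (2, 1)
    return value
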
#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of applying functions to arguments.
    '''
    _wrap_exception = True

    def Process(self, *args, **kwds):
        return self._ctx.Process(*args, **kwds)

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None, context=None):
        self._ctx = context or get_context()
        self._setup_queues()
        self._taskqueue = queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not callable(initializer):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue,
                  self._pool, self._cache)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = util.Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )

    def _join_exited_workers(self):
        """Clean up after any worker processes which have exited due to
        reaching their specified lifetime.  Returns True if any workers were
        cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                util.debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild,
                                   self._wrap_exception)
                            )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            util.debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        self._inqueue = self._ctx.SimpleQueue()
        self._outqueue = self._ctx.SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `func(*args, **kwds)`.
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

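    # Illustrative usage (a hedged sketch, not part of the original file):
    #
    #     with Pool(4) as p:
    #         p.apply(pow, (2, 5))        # blocks until done; returns 32
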
    def map(self, func, iterable, chunksize=None):
        '''
        Apply `func` to each element in `iterable`, collecting the results
        in a list that is returned.
        '''
        return self._map_async(func, iterable, mapstar, chunksize).get()

    def starmap(self, func, iterable, chunksize=None):
        '''
        Like the `map()` method, but the elements of `iterable` are expected
        to be iterables as well and will be unpacked as arguments.  Hence
        `func` and `(a, b)` becomes `func(a, b)`.
        '''
        return self._map_async(func, iterable, starmapstar, chunksize).get()

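    # Illustrative usage (hedged sketch):
    #
    #     with Pool(4) as p:
    #         p.map(abs, [-1, -2, -3])          # [1, 2, 3]
    #         p.starmap(pow, [(2, 3), (3, 2)])  # [8, 9]
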
    def starmap_async(self, func, iterable, chunksize=None, callback=None,
            error_callback=None):
        '''
        Asynchronous version of `starmap()` method.
        '''
        return self._map_async(func, iterable, starmapstar, chunksize,
                               callback, error_callback)

    def imap(self, func, iterable, chunksize=1):
        '''
        Like `map()` but returning the results lazily as an iterator -- can
        be MUCH slower than `Pool.map()`.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

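    # Hedged sketch: imap() yields results lazily in submission order,
    # while imap_unordered() yields them as they complete:
    #
    #     with Pool(4) as p:
    #         for v in p.imap(abs, [-3, -1, -2]):
    #             print(v)                 # 3, 1, 2 -- in input order
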
    def apply_async(self, func, args=(), kwds={}, callback=None,
            error_callback=None):
        '''
        Asynchronous version of `apply()` method.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        result = ApplyResult(self._cache, callback, error_callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

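    # Hedged sketch of the async pattern: apply_async() returns an
    # ApplyResult at once; get() blocks for the value, and any callback
    # runs on the result-handler thread once the value arrives:
    #
    #     with Pool(4) as p:
    #         res = p.apply_async(pow, (2, 10))
    #         res.get(timeout=10)          # 1024
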
    def map_async(self, func, iterable, chunksize=None, callback=None,
            error_callback=None):
        '''
        Asynchronous version of `map()` method.
        '''
        return self._map_async(func, iterable, mapstar, chunksize, callback,
            error_callback)

    def _map_async(self, func, iterable, mapper, chunksize=None, callback=None,
            error_callback=None):
        '''
        Helper function to implement map, starmap and their async counterparts.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback,
                           error_callback=error_callback)
        self._taskqueue.put((((result._job, i, mapper, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result

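    # Worked example of the chunksize heuristic above (hedged, illustrative):
    # with 100 items and 4 workers, divmod(100, 4 * 4) == (6, 4); the nonzero
    # remainder bumps chunksize to 7, giving ceil(100 / 7) == 15 task batches
    # spread across the workers.
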
    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the pool
        # is terminated.
        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        util.debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool, cache):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            task = None
            i = -1
            try:
                for i, task in enumerate(taskseq):
                    if thread._state:
                        util.debug('task handler found thread._state != RUN')
                        break
                    try:
                        put(task)
                    except Exception as e:
                        job, ind = task[:2]
                        try:
                            cache[job]._set(ind, (False, e))
                        except KeyError:
                            pass
                else:
                    # every task in this batch was submitted without a break
                    if set_length:
                        util.debug('doing set_length()')
                        set_length(i+1)
                    continue
                break
            except Exception as ex:
                job, ind = task[:2] if task else (0, 0)
                if job in cache:
                    cache[job]._set(ind + 1, (False, ex))
                if set_length:
                    util.debug('doing set_length()')
                    set_length(i+1)
        else:
            # iter(taskqueue.get, None) was exhausted, i.e. the None
            # sentinel arrived and there are no more task batches
            util.debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            util.debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            util.debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except OSError:
            util.debug('task handler got OSError when sending sentinels')

        util.debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (OSError, EOFError):
                util.debug('result handler got EOFError/OSError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                util.debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                util.debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (OSError, EOFError):
                util.debug('result handler got EOFError/OSError -- exiting')
                return

            if task is None:
                util.debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            util.debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (OSError, EOFError):
                pass

        util.debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)

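    # Hedged sketch: _get_tasks slices the iterable into (func, chunk) pairs
    # of at most `size` items, which mapstar/starmapstar later unpack, e.g.
    #
    #     list(Pool._get_tasks(abs, range(5), 2))
    #     # [(abs, (0, 1)), (abs, (2, 3)), (abs, (4,))]
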
    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        util.debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        util.debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        util.debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        util.debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        util.debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        util.debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        util.debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join()

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            util.debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        util.debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join()

        util.debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join()

        if pool and hasattr(pool[0], 'terminate'):
            util.debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    util.debug('cleaning up worker %d' % p.pid)
                    p.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.terminate()

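# Hedged end-to-end sketch (not part of the original module): the context
# manager form calls terminate() on exit, and only picklable callables such
# as builtins or module-level functions can be sent to worker processes.
def _demo_pool_usage():
    with Pool(processes=2) as p:
        squares = p.map(abs, [-1, -2, -3])      # [1, 2, 3]
        async_res = p.apply_async(pow, (2, 8))
        eight_bit = async_res.get(timeout=10)   # 256
    return squares, eight_bit
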
#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback, error_callback):
        self._event = threading.Event()
        self._job = next(job_counter)
        self._cache = cache
        self._callback = callback
        self._error_callback = error_callback
        cache[self._job] = self

    def ready(self):
        return self._event.is_set()

    def successful(self):
        assert self.ready()
        return self._success

    def wait(self, timeout=None):
        self._event.wait(timeout)

    def get(self, timeout=None):
        self.wait(timeout)
        if not self.ready():
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        if self._error_callback and not self._success:
            self._error_callback(self._value)
        self._event.set()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805

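# Hedged sketch of the ApplyResult lifecycle without a pool: the result
# handler thread normally calls _set(); driving it by hand shows how
# ready()/successful()/get() change state.  Not part of the original file.
def _demo_apply_result():
    cache = {}
    res = ApplyResult(cache, callback=None, error_callback=None)
    assert not res.ready() and res._job in cache
    res._set(0, (True, 'done'))        # what _handle_results would do
    assert res.ready() and res.successful()
    return res.get()                   # 'done'
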
#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback, error_callback):
        ApplyResult.__init__(self, cache, callback,
                             error_callback=error_callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._event.set()
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        self._number_left -= 1
        success, result = success_result
        if success and self._success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._event.set()
        else:
            if not success and self._success:
                # only store first exception
                self._success = False
                self._value = result
            if self._number_left == 0:
                # only consider the result ready once all jobs are done
                if self._error_callback:
                    self._error_callback(self._value)
                del self._cache[self._job]
                self._event.set()

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = next(job_counter)
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        with self._cond:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        with self._cond:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]

    def _set_length(self, length):
        with self._cond:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        with self._cond:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]

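# Hedged sketch: IMapIterator reassembles out-of-order results, parking an
# early arrival in _unsorted until its predecessors show up.  Here _set()
# is driven by hand, standing in for the result-handler thread.
def _demo_imap_iterator_ordering():
    cache = {}
    it = IMapIterator(cache)
    it._set_length(2)
    it._set(1, (True, 'second'))       # arrives early -- buffered
    it._set(0, (True, 'first'))        # releases both items, in order
    return [next(it), next(it)]        # ['first', 'second']
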
#
# Class representing a thread-based pool -- same interface as Pool
#

class ThreadPool(Pool):
    _wrap_exception = False

    @staticmethod
    def Process(*args, **kwds):
        from .dummy import Process
        return Process(*args, **kwds)

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = queue.Queue()
        self._outqueue = queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        with inqueue.not_empty:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
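
# Hedged usage sketch: ThreadPool shares Pool's interface but runs tasks in
# threads, so unpicklable callables such as lambdas are fine.
def _demo_thread_pool():
    with ThreadPool(2) as p:
        return p.map(lambda x: x * x, range(4))    # [0, 1, 4, 9]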
    768