#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

__all__ = ['Pool']

#
# Imports
#

import threading
import Queue
import itertools
import collections
import time

from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2
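# State transitions: close() moves a running pool from RUN to CLOSE (no new
# tasks are accepted, queued work still completes), while terminate() forces
# RUN or CLOSE to TERMINATE (outstanding work is abandoned); see close(),
# terminate() and join() below.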

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return map(*args)

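# mapstar() unpacks a (func, chunk) pair produced by Pool._get_tasks():
# for example, mapstar((abs, (-1, 2, -3))) == map(abs, (-1, 2, -3)) == [1, 2, 3]
# under Python 2's eager map().
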
#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<MaybeEncodingError: %s>" % str(self)


def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))
        completed += 1
    debug('worker exiting after %d tasks' % completed)

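# Task/result wire protocol used above: each task on inqueue is a tuple
# (job, i, func, args, kwds); each reply on outqueue is
# (job, i, (success, value)), where success is False if func raised, in
# which case value is the exception (or a MaybeEncodingError when the real
# result could not be pickled).
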
#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin
    '''
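    # Minimal usage sketch (illustrative, not part of the original source;
    # `square` is a hypothetical module-level function -- it must be defined
    # at module level so worker processes can unpickle it):
    #
    #     def square(x):
    #         return x * x
    #
    #     if __name__ == '__main__':
    #         pool = Pool(processes=4)
    #         print pool.map(square, range(10))    # [0, 1, 4, ..., 81]
    #         res = pool.apply_async(square, (7,))
    #         print res.get(timeout=5)             # 49
    #         pool.close()
    #         pool.join()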
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None):
        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not hasattr(initializer, '__call__'):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue, self._pool)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )

    def _join_exited_workers(self):
        """Clean up after any worker processes which have exited due to
        reaching their specified lifetime.  Returns True if any workers were
        cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild)
                            )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv
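        # _quick_put/_quick_get talk directly to the underlying pipe ends,
        # skipping SimpleQueue's lock wrappers; this appears safe because
        # only the task-handler thread writes to inqueue and only the
        # result-handler thread reads from outqueue within this process.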

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than `Pool.map()`
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0
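        # Worked example of the default heuristic: with 4 workers and
        # len(iterable) == 100, divmod(100, 4 * 4) == (6, 4), so chunksize
        # becomes 7 and the input is split into 15 mapstar batches
        # (14 of 7 items plus one of 2), roughly four batches per worker.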

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result

    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the pool
        # is terminated.
        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            i = -1
            for i, task in enumerate(taskseq):
                if thread._state:
                    debug('task handler found thread._state != RUN')
                    break
                try:
                    put(task)
                except IOError:
                    debug('could not put task on queue')
                    break
            else:
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
                continue
            break
        else:
            debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)

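    # For example, list(Pool._get_tasks(f, range(5), 2)) yields
    # [(f, (0, 1)), (f, (2, 3)), (f, (4,))]: fixed-size slices of the
    # iterable, with a short final batch.
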
    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join(1e100)

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join(1e100)

        debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    debug('cleaning up worker %d' % p.pid)
                    p.join()

#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        return self._ready

    def successful(self):
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        self._cond.acquire()
        try:
            if not self._ready:
                self._cond.wait(timeout)
        finally:
            self._cond.release()

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._cond.acquire()
        try:
            self._ready = True
            self._cond.notify()
        finally:
            self._cond.release()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805

#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._ready = True
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)
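            # i.e. ceil(length / chunksize): one _set() call per mapstar
            # batch is expected, and the final batch may be short.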

    def _set(self, i, success_result):
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._cond.acquire()
                try:
                    self._ready = True
                    self._cond.notify()
                finally:
                    self._cond.release()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._cond.acquire()
            try:
                self._ready = True
                self._cond.notify()
            finally:
                self._cond.release()

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        self._cond.acquire()
        try:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
        finally:
            self._cond.release()

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # Python 3 style alias for next()

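    # _set() runs on the result-handler thread and may receive (index, item)
    # pairs out of order; early arrivals are parked in self._unsorted until
    # self._index catches up, so next() always yields results in submission
    # order.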
    def _set(self, i, obj):
        self._cond.acquire()
        try:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

    def _set_length(self, length):
        self._cond.acquire()
        try:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]
        finally:
            self._cond.release()

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

#
# Class representing a thread pool
#

class ThreadPool(Pool):

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = Queue.Queue()
        self._outqueue = Queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        inqueue.not_empty.acquire()
        try:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
        finally:
            inqueue.not_empty.release()
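
# Illustrative ThreadPool sketch (not part of the original module): the API
# mirrors Pool but runs tasks in daemon threads of the current process, so
# nothing is pickled and closures or unpicklable objects work fine.
#
#     tp = ThreadPool(2)
#     print tp.map(len, ['ab', 'cde', ''])    # [2, 3, 0]
#     tp.close()
#     tp.join()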
    732