#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

__all__ = ['Pool']

#
# Imports
#

import threading
import Queue
import itertools
import collections
import time

from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return map(*args)

#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<MaybeEncodingError: %s>" % str(self)


def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))
        completed += 1
    debug('worker exiting after %d tasks' % completed)
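
# The tuples passed through the queues above form the pool's wire
# protocol.  As an illustration (not part of the original module), a
# hypothetical call `pool.apply_async(pow, (2, 10))` travels as:
#
#     task   = (job, None, pow, (2, 10), {})    # picked up by worker()
#     result = (job, None, (True, 1024))        # sent back on outqueue
#
# On failure the flag flips and the exception instance rides in place
# of the value, e.g. (job, None, (False, SomeError(...))).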

#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin
    '''
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None):
        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not hasattr(initializer, '__call__'):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue, self._pool)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )
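
    # A rough sketch of the plumbing wired up by __init__ (explanatory
    # aid, not part of the original source):
    #
    #     apply*/map*/imap* -> _taskqueue -> _task_handler -> _inqueue
    #         -> worker processes -> _outqueue -> _result_handler
    #         -> _cache[job]._set(...)
    #
    # _worker_handler runs alongside, reaping workers that have exited
    # (e.g. after maxtasksperchild tasks) and starting replacements so
    # the pool stays at self._processes members.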

    def _join_exited_workers(self):
        """Cleanup after any worker processes which have exited due to reaching
        their specified lifetime.  Returns True if any workers were cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild)
                            )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than `Pool.map()`
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result
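
    # Worked example of the default chunking above (illustrative only):
    # with len(iterable) == 100 and 4 pool processes, divmod(100, 16)
    # yields chunksize, extra = (6, 4); extra is nonzero, so chunksize
    # becomes 7 and the 100 items ship as ceil(100/7) == 15 tasks.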

    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the pool
        # is terminated.
        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            i = -1
            for i, task in enumerate(taskseq):
                if thread._state:
                    debug('task handler found thread._state != RUN')
                    break
                try:
                    put(task)
                except IOError:
                    debug('could not put task on queue')
                    break
            else:
                # inner loop ran to completion: report the true task count
                # to the result object, then fetch the next task sequence
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
                continue
            # inner loop broke out early -- stop handling tasks altogether
            break
        else:
            debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)
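
    # For instance (illustrative only), given some function f:
    #
    #     list(Pool._get_tasks(f, range(5), 2))
    #     == [(f, (0, 1)), (f, (2, 3)), (f, (4,))]
    #
    # mapstar() later unpacks each pair as map(f, (0, 1)), and MapResult
    # splices every returned chunk back into place by task index.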

    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join(1e100)

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join(1e100)

        debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    debug('cleaning up worker %d' % p.pid)
                    p.join()
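
#
# Shutdown, in brief: the orderly path is close() followed by join(),
# which lets queued tasks drain; terminate() -- or garbage collection,
# via the Finalize registered in __init__ -- goes through
# _terminate_pool() and kills outstanding workers.  A minimal sketch,
# assuming a picklable module-level function work():
#
#     pool = Pool(4)
#     pool.map(work, range(100))
#     pool.close()        # no more submissions
#     pool.join()         # wait for handlers and workers to exit
#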

#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        return self._ready

    def successful(self):
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        self._cond.acquire()
        try:
            if not self._ready:
                self._cond.wait(timeout)
        finally:
            self._cond.release()

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._cond.acquire()
        try:
            self._ready = True
            self._cond.notify()
        finally:
            self._cond.release()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805
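
# Typical use of ApplyResult (a sketch; assumes a picklable function f):
#
#     res = pool.apply_async(f, (10,))
#     res.wait(5)                      # block for up to 5 seconds
#     if res.ready() and res.successful():
#         value = res.get()            # get() re-raises on failure
#
# get() with a timeout raises multiprocessing.TimeoutError if the
# result has not arrived in time.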

#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._ready = True
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._cond.acquire()
                try:
                    self._ready = True
                    self._cond.notify()
                finally:
                    self._cond.release()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._cond.acquire()
            try:
                self._ready = True
                self._cond.notify()
            finally:
                self._cond.release()

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        self._cond.acquire()
        try:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
        finally:
            self._cond.release()

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

    def _set_length(self, length):
        self._cond.acquire()
        try:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]
        finally:
            self._cond.release()

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

#
#
#

class ThreadPool(Pool):

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = Queue.Queue()
        self._outqueue = Queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        inqueue.not_empty.acquire()
        try:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
        finally:
            inqueue.not_empty.release()
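
#
# Usage sketch (illustrative comment only -- this module is normally
# reached via `from multiprocessing import Pool`, and worker functions
# must be picklable, i.e. defined at module level):
#
#     from multiprocessing import Pool
#
#     def square(x):
#         return x * x
#
#     if __name__ == '__main__':
#         pool = Pool(4)
#         try:
#             print pool.map(square, range(10))
#         finally:
#             pool.close()
#             pool.join()
#
# ThreadPool offers the same interface backed by threads, trading true
# parallelism for cheap dispatch and freedom from pickling constraints.
#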