#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#

__all__ = ['Pool', 'ThreadPool']

#
# Imports
#

import threading
import queue
import itertools
import collections
import os
import time
import traceback

# If threading is available then ThreadPool should be provided.  Therefore
# we avoid top-level imports which are liable to fail on some systems.
from . import util
from . import get_context, TimeoutError

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return list(map(*args))

def starmapstar(args):
    return list(itertools.starmap(args[0], args[1]))

#
# Hack to embed stringification of remote traceback in local traceback
#

class RemoteTraceback(Exception):
    def __init__(self, tb):
        self.tb = tb
    def __str__(self):
        return self.tb

class ExceptionWithTraceback:
    def __init__(self, exc, tb):
        tb = traceback.format_exception(type(exc), exc, tb)
        tb = ''.join(tb)
        self.exc = exc
        self.tb = '\n"""\n%s"""' % tb
    def __reduce__(self):
        return rebuild_exc, (self.exc, self.tb)

def rebuild_exc(exc, tb):
    exc.__cause__ = RemoteTraceback(tb)
    return exc

#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)


def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
           wrap_exception=False):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, OSError):
            util.debug('worker got EOFError or OSError -- exiting')
            break

        if task is None:
            util.debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            if wrap_exception:
                e = ExceptionWithTraceback(e, e.__traceback__)
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            util.debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))
        completed += 1
    util.debug('worker exiting after %d tasks' % completed)
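
# Example (illustration only, kept as a comment so nothing runs at import
# time): when wrap_exception is true, an exception raised in a worker is
# shipped back as an ExceptionWithTraceback.  Unpickling it in the parent
# calls rebuild_exc(), which chains a RemoteTraceback onto the original
# exception so the remote traceback text appears in the local one:
#
#     import pickle
#
#     try:
#         1 / 0
#     except Exception as e:
#         wrapped = ExceptionWithTraceback(e, e.__traceback__)
#     exc = pickle.loads(pickle.dumps(wrapped))   # invokes rebuild_exc()
#     assert isinstance(exc.__cause__, RemoteTraceback)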

#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of applying functions to arguments.
    '''
    _wrap_exception = True

    def Process(self, *args, **kwds):
        return self._ctx.Process(*args, **kwds)

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None, context=None):
        self._ctx = context or get_context()
        self._setup_queues()
        self._taskqueue = queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            processes = os.cpu_count() or 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not callable(initializer):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue,
                  self._pool, self._cache)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = util.Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )

    def _join_exited_workers(self):
        """Cleanup after any worker processes which have exited due to
        reaching their specified lifetime.  Returns True if any workers were
        cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                util.debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild,
                                   self._wrap_exception)
                             )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            util.debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        self._inqueue = self._ctx.SimpleQueue()
        self._outqueue = self._ctx.SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv
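
    # Example (illustration only): a typical construction.  Each worker runs
    # `initializer(*initargs)` once at startup, and `maxtasksperchild`
    # recycles a worker after it completes that many tasks, which bounds the
    # damage from slow resource leaks in user code:
    #
    #     from multiprocessing import Pool
    #
    #     def init(level):
    #         import logging
    #         logging.basicConfig(level=level)
    #
    #     pool = Pool(processes=4, initializer=init, initargs=(20,),
    #                 maxtasksperchild=100)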

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `func(*args, **kwds)`.
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Apply `func` to each element in `iterable`, collecting the results
        in a list that is returned.
        '''
        return self._map_async(func, iterable, mapstar, chunksize).get()

    def starmap(self, func, iterable, chunksize=None):
        '''
        Like `map()` method but the elements of the `iterable` are expected
        to be iterables as well and will be unpacked as arguments.  Hence an
        element (a, b) of the iterable becomes func(a, b).
        '''
        return self._map_async(func, iterable, starmapstar, chunksize).get()

    def starmap_async(self, func, iterable, chunksize=None, callback=None,
                      error_callback=None):
        '''
        Asynchronous version of `starmap()` method.
        '''
        return self._map_async(func, iterable, starmapstar, chunksize,
                               callback, error_callback)

    def imap(self, func, iterable, chunksize=1):
        '''
        Lazy equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None,
                    error_callback=None):
        '''
        Asynchronous version of `apply()` method.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        result = ApplyResult(self._cache, callback, error_callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None,
                  error_callback=None):
        '''
        Asynchronous version of `map()` method.
        '''
        return self._map_async(func, iterable, mapstar, chunksize, callback,
                               error_callback)
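
    # Example (illustration only): the call styles offered above, assuming
    # `pool` is a running Pool instance:
    #
    #     pool.map(abs, [-1, -2, -3])              # -> [1, 2, 3]
    #     pool.starmap(divmod, [(7, 3), (9, 4)])   # -> [(2, 1), (2, 1)]
    #     for r in pool.imap_unordered(abs, range(1000), chunksize=50):
    #         pass                                 # results as they complete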

    def _map_async(self, func, iterable, mapper, chunksize=None,
                   callback=None, error_callback=None):
        '''
        Helper function to implement map, starmap and their async counterparts.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback,
                           error_callback=error_callback)
        self._taskqueue.put((((result._job, i, mapper, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result

    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the
        # pool is terminated.
        while thread._state == RUN or (pool._cache and
                                       thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        util.debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool, cache):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            task = None
            i = -1
            try:
                for i, task in enumerate(taskseq):
                    if thread._state:
                        util.debug('task handler found thread._state != RUN')
                        break
                    try:
                        put(task)
                    except Exception as e:
                        job, ind = task[:2]
                        try:
                            cache[job]._set(ind, (False, e))
                        except KeyError:
                            pass
                else:
                    if set_length:
                        util.debug('doing set_length()')
                        set_length(i+1)
                    continue
                break
            except Exception as ex:
                job, ind = task[:2] if task else (0, 0)
                if job in cache:
                    cache[job]._set(ind + 1, (False, ex))
                if set_length:
                    util.debug('doing set_length()')
                    set_length(i+1)
        else:
            util.debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            util.debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            util.debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except OSError:
            util.debug('task handler got OSError when sending sentinels')

        util.debug('task handler exiting')
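
    # Note on `iter(taskqueue.get, None)` above: the two-argument form of
    # iter() keeps calling taskqueue.get() until the sentinel None appears,
    # which is how the task handler is unblocked at shutdown.  A minimal
    # standalone illustration:
    #
    #     q = queue.Queue()
    #     q.put('a'); q.put(None)
    #     assert list(iter(q.get, None)) == ['a']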

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (OSError, EOFError):
                util.debug('result handler got EOFError/OSError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                util.debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                util.debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (OSError, EOFError):
                util.debug('result handler got EOFError/OSError -- exiting')
                return

            if task is None:
                util.debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            util.debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (OSError, EOFError):
                pass

        util.debug('result handler exiting: len(cache)=%s, thread._state=%s',
                   len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)

    def __reduce__(self):
        raise NotImplementedError(
            'pool objects cannot be passed between processes or pickled'
            )

    def close(self):
        util.debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        util.debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        util.debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()
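
    # Example (illustration only): the two shutdown paths.  close() + join()
    # lets queued tasks drain; terminate() (also used by __exit__) kills
    # workers that have not finished:
    #
    #     pool.close()   # refuse new tasks; workers exit when queue drains
    #     pool.join()    # wait for workers and handler threads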

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        util.debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        util.debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        util.debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our
        # back.
        util.debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join()

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            util.debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        util.debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join()

        util.debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join()

        if pool and hasattr(pool[0], 'terminate'):
            util.debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    util.debug('cleaning up worker %d' % p.pid)
                    p.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.terminate()

#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback, error_callback):
        self._event = threading.Event()
        self._job = next(job_counter)
        self._cache = cache
        self._callback = callback
        self._error_callback = error_callback
        cache[self._job] = self

    def ready(self):
        return self._event.is_set()

    def successful(self):
        assert self.ready()
        return self._success

    def wait(self, timeout=None):
        self._event.wait(timeout)

    def get(self, timeout=None):
        self.wait(timeout)
        if not self.ready():
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        if self._error_callback and not self._success:
            self._error_callback(self._value)
        self._event.set()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805

#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback, error_callback):
        ApplyResult.__init__(self, cache, callback,
                             error_callback=error_callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._event.set()
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        self._number_left -= 1
        success, result = success_result
        if success and self._success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._event.set()
        else:
            if not success and self._success:
                # only store first exception
                self._success = False
                self._value = result
            if self._number_left == 0:
                # only consider the result ready once all jobs are done
                if self._error_callback:
                    self._error_callback(self._value)
                del self._cache[self._job]
                self._event.set()
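
# Example (illustration only): consuming an ApplyResult/AsyncResult with a
# timeout, assuming `pool` is a running Pool (TimeoutError here is
# multiprocessing.TimeoutError):
#
#     res = pool.apply_async(pow, (2, 10), callback=print)
#     try:
#         value = res.get(timeout=5)    # raises TimeoutError if not ready
#     except TimeoutError:
#         value = None
#     if res.ready() and res.successful():
#         assert value == 1024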

#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = next(job_counter)
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        with self._cond:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        with self._cond:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]

    def _set_length(self, length):
        with self._cond:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]

#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        with self._cond:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]

#
#
#

class ThreadPool(Pool):
    _wrap_exception = False

    @staticmethod
    def Process(*args, **kwds):
        from .dummy import Process
        return Process(*args, **kwds)

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = queue.Queue()
        self._outqueue = queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        with inqueue.not_empty:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
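
# A minimal end-to-end sketch (illustration only; run it as a standalone
# script, since this module itself uses relative imports):
#
#     from multiprocessing import Pool
#     from multiprocessing.pool import ThreadPool
#
#     def square(x):
#         return x * x
#
#     if __name__ == '__main__':
#         with Pool(4) as p:                 # __exit__ calls terminate()
#             print(p.map(square, range(10)))
#         with ThreadPool(4) as p:           # same API, backed by threads
#             print(list(p.imap(square, range(10))))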