◐ Shell
clean mode source ↗

bpo-22393: Fix multiprocessing.Pool hangs if a worker process dies unexpectedly by oesteban · Pull Request #10441 · python/cpython

Expand Up @@ -33,19 +33,29 @@ RUN = "RUN" CLOSE = "CLOSE" TERMINATE = "TERMINATE" BROKEN = "BROKEN"
#
# Miscellaneous
#
# Module-level counter: successive next() calls on itertools.count() yield
# 0, 1, 2, ...
# NOTE(review): presumably the source of per-job identifiers -- confirm against
# the code that consumes it (not visible in this excerpt).
job_counter = itertools.count()

def mapstar(args):
    """Run a packed ``map`` call and return its results as a list.

    *args* is unpacked directly into ``map``: its first item is the callable
    and every remaining item is an iterable consumed in lockstep.
    """
    return [*map(*args)]

def starmapstar(args):
    """Run a packed ``itertools.starmap`` call and return a list.

    ``args[0]`` is the callable; ``args[1]`` is an iterable whose items are
    each unpacked into positional arguments for one call.
    """
    function = args[0]
    argument_tuples = args[1]
    return [*itertools.starmap(function, argument_tuples)]

class BrokenProcessPool(RuntimeError):
    """
    Error signalling that a process belonging to a ProcessPoolExecutor
    terminated abruptly while a future was still in the running state.
    """
    # NOTE(review): the wording above names ProcessPoolExecutor and futures;
    # confirm it matches how this module actually raises/delivers the error.
# # Hack to embed stringification of remote traceback in local traceback # Expand Down Expand Up @@ -104,6 +114,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None, if initializer is not None: initializer(*initargs)
util.debug('worker started') completed = 0 while maxtasks is None or (maxtasks and completed < maxtasks): try: Expand Down Expand Up @@ -189,6 +200,7 @@ def __init__(self, processes=None, initializer=None, initargs=(), ) self._worker_handler.daemon = True self._worker_handler._state = RUN self._worker_state_lock = self._ctx.Lock() self._worker_handler.start()

Expand Down Expand Up @@ -225,17 +237,31 @@ def __repr__(self):
def _join_exited_workers(self):
    """Cleanup after any worker processes which have exited.

    Walks ``self._pool`` from the last entry to the first, joining and
    removing every worker whose ``exitcode`` is no longer ``None``.

    Returns the number of workers that were cleaned up, or ``None`` if the
    process pool is broken. The pool is treated as broken when any worker
    has an ``exitcode`` other than ``None`` or ``0``; in that case every
    remaining worker is terminated and joined and ``self._pool`` is emptied.
    """
    cleaned = 0
    broken = False
    # Iterate over a reversed snapshot so that deleting index ``i`` never
    # shifts the position of an entry that has not been visited yet.
    for i, p in reversed(list(enumerate(self._pool))):
        broken = broken or (p.exitcode not in (None, 0))
        if p.exitcode is not None:
            # worker exited
            util.debug('cleaning up worker %d' % i)
            p.join()
            cleaned += 1
            del self._pool[i]

    if broken:
        # Stop all workers
        util.info(
            'worker handler: process pool is broken, terminating workers...')
        for p in self._pool:
            if p.exitcode is None:
                p.terminate()
        for p in self._pool:
            p.join()
        del self._pool[:]
        return None
    return cleaned
def _repopulate_pool(self): Expand All @@ -256,11 +282,21 @@ def _repopulate_pool(self): util.debug('added worker')
def _maintain_pool(self):
    """Clean up any exited workers and start replacements for them.

    If ``_join_exited_workers()`` reports a broken pool (returns ``None``),
    the worker handler is marked ``BROKEN`` under ``_worker_state_lock`` and
    every entry in ``self._cache`` is completed with a ``BrokenProcessPool``
    error instead of a result.
    """
    need_repopulate = self._join_exited_workers()
    if need_repopulate:
        self._repopulate_pool()

    if need_repopulate is None:
        with self._worker_state_lock:
            self._worker_handler._state = BROKEN

        err = BrokenProcessPool(
            'A worker in the pool terminated abruptly.')
        # Exhaust MapResult with errors. Iterate over a snapshot of the
        # cache: the visible _set() fragments delete entries from
        # self._cache when a result completes.
        for cache_ent in list(self._cache.values()):
            cache_ent._set_all((False, err))
def _setup_queues(self): self._inqueue = self._ctx.SimpleQueue() self._outqueue = self._ctx.SimpleQueue() Expand Down Expand Up @@ -419,6 +455,7 @@ def _map_async(self, func, iterable, mapper, chunksize=None, callback=None, @staticmethod def _handle_workers(pool): thread = threading.current_thread() util.debug('worker handler entering')
# Keep maintaining workers until the cache gets drained, unless the pool # is terminated. Expand All @@ -432,6 +469,7 @@ def _handle_workers(pool): @staticmethod def _handle_tasks(taskqueue, put, outqueue, pool, cache): thread = threading.current_thread() util.debug('task handler entering')
for taskseq, set_length in iter(taskqueue.get, None): task = None Expand Down Expand Up @@ -477,6 +515,7 @@ def _handle_tasks(taskqueue, put, outqueue, pool, cache):
@staticmethod def _handle_results(outqueue, get, cache): util.debug('result handler entering') thread = threading.current_thread()
while 1: Expand Down Expand Up @@ -553,7 +592,10 @@ def close(self): util.debug('closing pool') if self._state == RUN: self._state = CLOSE self._worker_handler._state = CLOSE # Avert race condition in broken pools with self._worker_state_lock: if self._worker_handler._state != BROKEN: self._worker_handler._state = CLOSE
def terminate(self): util.debug('terminating pool') Expand Down Expand Up @@ -586,13 +628,21 @@ def _help_stuff_finish(inqueue, task_handler, size): def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache): # this is guaranteed to only be called once util.debug('finalizing pool') util.debug('terminate pool entering') is_broken = BROKEN in (task_handler._state, worker_handler._state, result_handler._state)
worker_handler._state = TERMINATE task_handler._state = TERMINATE
util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) # Skip _help_finish_stuff if the pool is broken, because # the broken process may have been holding the inqueue lock. if not is_broken: util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) else: util.debug('finishing BROKEN process pool')
if (not result_handler.is_alive()) and (len(cache) != 0): raise AssertionError( Expand All @@ -603,8 +653,8 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
# We must wait for the worker handler to exit before terminating # workers because we don't want workers to be restarted behind our back. util.debug('joining worker handler') if threading.current_thread() is not worker_handler: util.debug('joining worker handler') worker_handler.join()
# Terminate workers which haven't already finished. Expand All @@ -614,12 +664,12 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, if p.exitcode is None: p.terminate()
util.debug('joining task handler') if threading.current_thread() is not task_handler: util.debug('joining task handler') task_handler.join()
util.debug('joining result handler') if threading.current_thread() is not result_handler: util.debug('joining result handler') result_handler.join()
if pool and hasattr(pool[0], 'terminate'): Expand All @@ -629,6 +679,7 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, # worker has not yet exited util.debug('cleaning up worker %d' % p.pid) p.join() util.debug('terminate pool finished')
def __enter__(self): self._check_running() Expand Down Expand Up @@ -680,6 +731,9 @@ def _set(self, i, obj): self._event.set() del self._cache[self._job]
def _set_all(self, obj):
    """Resolve this result with *obj*.

    Forwards *obj* to ``self._set`` under index 0.
    """
    first_slot = 0
    self._set(first_slot, obj)
AsyncResult = ApplyResult # create alias -- see #17805
# Expand Down Expand Up @@ -723,6 +777,12 @@ def _set(self, i, success_result): del self._cache[self._job] self._event.set()
def _set_all(self, obj):
    """Deliver *obj* to every slot that is still outstanding.

    Calls ``self._set(index, obj)`` with index 0, 1, 2, ... for as long as
    ``self._number_left`` is positive.
    """
    # NOTE(review): termination relies on _set() lowering _number_left on
    # each call -- confirm against _set(), which is not fully visible here.
    item = 0
    while self._number_left > 0:
        self._set(item, obj)
        item += 1
# # Class whose instances are returned by `Pool.imap()` # Expand Down Expand Up @@ -780,6 +840,10 @@ def _set(self, i, obj): if self._index == self._length: del self._cache[self._job]
def _set_all(self, obj):
    """Feed *obj* into every position that has not been filled yet.

    Repeatedly calls ``self._set(self._index, obj)`` until ``self._index``
    equals ``self._length``.
    """
    # NOTE(review): the loop ends only once _index equals _length. If
    # _length may still be unset (e.g. None) when this is called, confirm
    # that _set() cannot keep advancing _index forever.
    while self._index != self._length:
        self._set(self._index, obj)
def _set_length(self, length): with self._cond: self._length = length Expand Down