bpo-22393: Fix multiprocessing.Pool hangs if a worker process dies unexpectedly by oesteban · Pull Request #10441 · python/cpython
#
# Miscellaneous
#
# Module-level iterator yielding consecutive integers (0, 1, 2, ...).
# Presumably the source of per-job ids -- its consumers are outside this
# view, TODO confirm.
job_counter = itertools.count()
def mapstar(args):
    """Call ``map`` with the unpacked *args* and return the results as a list.

    *args* is the full argument sequence for ``map``: the callable followed
    by one or more iterables.
    """
    results = map(*args)
    return [*results]
def starmapstar(args):
    """Apply ``args[0]`` to each argument tuple in ``args[1]``; return a list.

    Only the first two items of *args* are used: the callable and an
    iterable of argument tuples that are unpacked for each call.
    """
    func = args[0]
    arg_tuples = args[1]
    return [*itertools.starmap(func, arg_tuples)]
class BrokenProcessPool(RuntimeError):
    """Raised when a worker process of the pool terminated abruptly.

    The pool maintenance code creates this error once it detects a broken
    pool and hands it, as a failure, to every result still held in the
    pool's cache.
    """
#
# Hack to embed stringification of remote traceback in local traceback
#
util.debug('worker started') completed = 0 while maxtasks is None or (maxtasks and completed < maxtasks): try:
def _join_exited_workers(self):
    """Clean up worker processes which have exited.

    Returns the number of workers that were joined and removed from
    ``self._pool``.  Returns None if the process pool is broken, i.e. at
    least one worker exited with a non-zero exit code; in that case every
    remaining worker is terminated and joined and ``self._pool`` is emptied.
    """
    cleaned = 0
    broken = False
    # Walk backwards so deleting index i never shifts an index still to visit.
    for i in reversed(range(len(self._pool))):
        p = self._pool[i]
        # Read the exit code exactly once per worker: a worker that dies
        # between two separate reads would otherwise be reaped as a normal
        # exit without the pool ever being flagged as broken.
        exitcode = p.exitcode
        if exitcode is None:
            # still running
            continue
        if exitcode != 0:
            broken = True
        # worker exited
        util.debug('cleaning up worker %d' % i)
        p.join()
        cleaned += 1
        del self._pool[i]

    if not broken:
        return cleaned

    # Stop all workers
    util.info('worker handler: process pool is broken, terminating workers...')
    for p in self._pool:
        if p.exitcode is None:
            p.terminate()
    for p in self._pool:
        p.join()
    del self._pool[:]
    return None
def _repopulate_pool(self):
def _maintain_pool(self):
    """Clean up any exited workers and start replacements for them.

    When ``_join_exited_workers`` returns None (the pool is broken), no
    replacements are started.  Instead the worker handler is marked BROKEN
    and every entry in ``self._cache`` is failed with a BrokenProcessPool
    error.
    """
    need_repopulate = self._join_exited_workers()
    if need_repopulate is not None:
        # Healthy pool: only respawn when at least one worker was reaped.
        if need_repopulate:
            self._repopulate_pool()
        return

    with self._worker_state_lock:
        self._worker_handler._state = BROKEN

    err = BrokenProcessPool(
        'A worker in the pool terminated abruptly.')
    # Exhaust the cached results with errors.  Iterate over a snapshot --
    # presumably completing an entry can remove it from the cache (confirm).
    for cache_ent in list(self._cache.values()):
        cache_ent._set_all((False, err))
def _setup_queues(self): self._inqueue = self._ctx.SimpleQueue() self._outqueue = self._ctx.SimpleQueue()
# Keep maintaining workers until the cache gets drained, unless the pool # is terminated.
for taskseq, set_length in iter(taskqueue.get, None): task = None
@staticmethod def _handle_results(outqueue, get, cache): util.debug('result handler entering') thread = threading.current_thread()
while 1:
def terminate(self): util.debug('terminating pool')
worker_handler._state = TERMINATE task_handler._state = TERMINATE
util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) # Skip _help_finish_stuff if the pool is broken, because # the broken process may have been holding the inqueue lock. if not is_broken: util.debug('helping task handler/workers to finish') cls._help_stuff_finish(inqueue, task_handler, len(pool)) else: util.debug('finishing BROKEN process pool')
if (not result_handler.is_alive()) and (len(cache) != 0): raise AssertionError(
# We must wait for the worker handler to exit before terminating # workers because we don't want workers to be restarted behind our back. util.debug('joining worker handler') if threading.current_thread() is not worker_handler: util.debug('joining worker handler') worker_handler.join()
# Terminate workers which haven't already finished.
util.debug('joining task handler') if threading.current_thread() is not task_handler: util.debug('joining task handler') task_handler.join()
util.debug('joining result handler') if threading.current_thread() is not result_handler: util.debug('joining result handler') result_handler.join()
if pool and hasattr(pool[0], 'terminate'):
def __enter__(self): self._check_running()
def _set_all(self, obj):
    """Forward *obj* to ``self._set`` under index 0."""
    first_slot = 0
    self._set(first_slot, obj)
AsyncResult = ApplyResult # create alias -- see #17805
#
def _set_all(self, obj):
    """Pass *obj* to ``self._set`` for indices 0, 1, 2, ... until done.

    The loop stops as soon as ``self._number_left`` is no longer positive.
    NOTE(review): relies on ``self._set`` decreasing ``self._number_left``
    -- confirm against its definition.
    """
    index = 0
    while True:
        if self._number_left <= 0:
            return
        self._set(index, obj)
        index += 1
#
# Class whose instances are returned by `Pool.imap()`
#
def _set_all(self, obj):
    """Pass *obj* to ``self._set`` at the current index until it hits the length.

    Each round hands ``self._index`` and *obj* to ``self._set`` and stops
    once ``self._index`` equals ``self._length``.
    NOTE(review): terminates only if ``self._set`` advances ``self._index``
    and ``self._length`` is already set to a reachable value -- confirm.
    """
    while True:
        position = self._index
        if position == self._length:
            break
        self._set(position, obj)
def _set_length(self, length): with self._cond: self._length = length