From 72505e968a4aa38e7d0e33d5be29b7aa5a1b3121 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 2 Mar 2021 21:23:52 +0100 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20NEW:=20Add=20`ProcessLauncher.p?= =?UTF-8?q?rocess=5Fcache`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plumpy/exceptions.py | 4 ++++ plumpy/process_comms.py | 32 ++++++++++++++++++++++++++++++-- test/rmq/test_communicator.py | 14 +++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/plumpy/exceptions.py b/plumpy/exceptions.py index 40d3e12d..d7976b3a 100644 --- a/plumpy/exceptions.py +++ b/plumpy/exceptions.py @@ -33,3 +33,7 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + + +class DuplicateProcess(Exception): + """Raised when a ProcessLauncher is asked to launch a process it is already running.""" diff --git a/plumpy/process_comms.py b/plumpy/process_comms.py index 3d9b6308..57e02755 100644 --- a/plumpy/process_comms.py +++ b/plumpy/process_comms.py @@ -4,12 +4,14 @@ import copy import logging from typing import Any, cast, Dict, Optional, Sequence, TYPE_CHECKING, Union +from weakref import WeakValueDictionary import kiwipy -from . import loaders +from . import exceptions from . import communications from . import futures +from . import loaders from . 
import persistence from .utils import PID_TYPE @@ -27,6 +29,7 @@ if TYPE_CHECKING: from .processes import Process # pylint: disable=cyclic-import + ProcessCacheType = WeakValueDictionary[PID_TYPE, Process] # pylint: disable=unsubscriptable-object ProcessResult = Any ProcessStatus = Any @@ -527,6 +530,20 @@ def __init__( else: self._loader = loaders.get_object_loader() + # using a weak reference ensures the processes can be garbage collected on completion + self._process_cache: 'ProcessCacheType' = WeakValueDictionary() + + @property + def process_cache(self) -> 'ProcessCacheType': + """Return a dictionary mapping PIDs to launched processes that are still in memory. + + The mapping uses a `WeakValueDictionary`, meaning that processes can be removed, + once they are no longer referenced anywhere else. + This means the dictionary will always contain all processes still running, + but potentially also processes that have terminated but have not yet been garbage collected. + """ + return copy.copy(self._process_cache) + async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]: """ Receive a task. 
@@ -571,10 +588,16 @@ async def _launch( init_kwargs = {} proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) + proc: Process = proc_class(*init_args, **init_kwargs) + + if proc.pid in self._process_cache and not self._process_cache[proc.pid].has_terminated(): + raise exceptions.DuplicateProcess(f'Process<{proc.pid}> is already running') + if persist and self._persister is not None: self._persister.save_checkpoint(proc) + self._process_cache[proc.pid] = proc + if nowait: asyncio.ensure_future(proc.step_until_terminated()) return proc.pid @@ -602,10 +625,15 @@ async def _continue( LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) raise communications.TaskRejected('Cannot continue process, no persister') + if pid in self._process_cache and not self._process_cache[pid].has_terminated(): + raise exceptions.DuplicateProcess(f'Process<{pid}> is already running') + # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up saved_state = self._persister.load_checkpoint(pid, tag) proc = cast('Process', saved_state.unbundle(self._load_context)) + self._process_cache[proc.pid] = proc + if nowait: asyncio.ensure_future(proc.step_until_terminated()) return proc.pid diff --git a/test/rmq/test_communicator.py b/test/rmq/test_communicator.py index 5e50380e..29df3177 100644 --- a/test/rmq/test_communicator.py +++ b/test/rmq/test_communicator.py @@ -7,7 +7,7 @@ import shortuuid import pytest -from kiwipy import rmq +from kiwipy import RemoteException, rmq import plumpy from plumpy import communications, process_comms @@ -177,3 +177,15 @@ async def test_continue(self, loop_communicator, async_controller, persister): # Let the process run to the end result = await async_controller.continue_process(pid) assert result, utils.DummyProcessWithOutput.EXPECTED_OUTPUTS + + @pytest.mark.asyncio + async def test_duplicate_process(self, 
loop_communicator, async_controller, persister): + loop = asyncio.get_event_loop() + launcher = plumpy.ProcessLauncher(loop, persister=persister) + loop_communicator.add_task_subscriber(launcher) + process = utils.DummyProcessWithOutput() + persister.save_checkpoint(process) + launcher._process_cache[process.pid] = process + assert process.pid in launcher.process_cache + with pytest.raises(RemoteException, match='already running'): + await async_controller.continue_process(process.pid) From 2b6ef68e60f4209a10efd7dd92a371781cbf36f0 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 4 Mar 2021 21:39:40 +0100 Subject: [PATCH 2/2] remove process from cache after step_until_terminated --- plumpy/process_comms.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/plumpy/process_comms.py b/plumpy/process_comms.py index 57e02755..ecc75b95 100644 --- a/plumpy/process_comms.py +++ b/plumpy/process_comms.py @@ -602,7 +602,10 @@ async def _launch( asyncio.ensure_future(proc.step_until_terminated()) return proc.pid - await proc.step_until_terminated() + try: + await proc.step_until_terminated() + finally: + self._process_cache.pop(proc.pid, None) return proc.future().result() @@ -638,7 +641,10 @@ async def _continue( asyncio.ensure_future(proc.step_until_terminated()) return proc.pid - await proc.step_until_terminated() + try: + await proc.step_until_terminated() + finally: + self._process_cache.pop(proc.pid, None) return proc.future().result()