diff --git a/plumpy/exceptions.py b/plumpy/exceptions.py index 40d3e12d..d7976b3a 100644 --- a/plumpy/exceptions.py +++ b/plumpy/exceptions.py @@ -33,3 +33,7 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + + +class DuplicateProcess(Exception): + """Raised when a ProcessLauncher is asked to launch a process it is already running.""" diff --git a/plumpy/process_comms.py b/plumpy/process_comms.py index 3d9b6308..ecc75b95 100644 --- a/plumpy/process_comms.py +++ b/plumpy/process_comms.py @@ -4,12 +4,14 @@ import copy import logging from typing import Any, cast, Dict, Optional, Sequence, TYPE_CHECKING, Union +from weakref import WeakValueDictionary import kiwipy -from . import loaders +from . import exceptions from . import communications from . import futures +from . import loaders from . import persistence from .utils import PID_TYPE @@ -27,6 +29,7 @@ if TYPE_CHECKING: from .processes import Process # pylint: disable=cyclic-import + ProcessCacheType = WeakValueDictionary[PID_TYPE, Process] # pylint: disable=unsubscriptable-object ProcessResult = Any ProcessStatus = Any @@ -527,6 +530,20 @@ def __init__( else: self._loader = loaders.get_object_loader() + # using a weak reference ensures the processes can be garbage collected on completion + self._process_cache: 'ProcessCacheType' = WeakValueDictionary() + + @property + def process_cache(self) -> 'ProcessCacheType': + """Return a dictionary mapping PIDs to launched processes that are still in memory. + + The mapping uses a `WeakValueDictionary`, meaning that processes can be removed, + once they are no longer referenced anywhere else. + This means the dictionary will always contain all processes still running, + but potentially also processes that have terminated but have not yet been garbage collected. 
+ """ + return copy.copy(self._process_cache) + async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, ProcessResult]: """ Receive a task. @@ -571,15 +588,24 @@ async def _launch( init_kwargs = {} proc_class = self._loader.load_object(process_class) - proc = proc_class(*init_args, **init_kwargs) + proc: Process = proc_class(*init_args, **init_kwargs) + + if proc.pid in self._process_cache and not self._process_cache[proc.pid].has_terminated(): + raise exceptions.DuplicateProcess(f'Process<{proc.pid}> is already running') + if persist and self._persister is not None: self._persister.save_checkpoint(proc) + self._process_cache[proc.pid] = proc + if nowait: asyncio.ensure_future(proc.step_until_terminated()) return proc.pid - await proc.step_until_terminated() + try: + await proc.step_until_terminated() + finally: + self._process_cache.pop(proc.pid, None) return proc.future().result() @@ -602,15 +628,23 @@ async def _continue( LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) raise communications.TaskRejected('Cannot continue process, no persister') + if pid in self._process_cache and not self._process_cache[pid].has_terminated(): + raise exceptions.DuplicateProcess(f'Process<{pid}> is already running') + # Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up saved_state = self._persister.load_checkpoint(pid, tag) proc = cast('Process', saved_state.unbundle(self._load_context)) + self._process_cache[proc.pid] = proc + if nowait: asyncio.ensure_future(proc.step_until_terminated()) return proc.pid - await proc.step_until_terminated() + try: + await proc.step_until_terminated() + finally: + self._process_cache.pop(proc.pid, None) return proc.future().result() diff --git a/test/rmq/test_communicator.py b/test/rmq/test_communicator.py index a4e8f27b..2ca9b3a3 100644 --- a/test/rmq/test_communicator.py +++ 
b/test/rmq/test_communicator.py @@ -6,8 +6,8 @@ import asyncio import shortuuid +from kiwipy import BroadcastFilter, RemoteException, rmq import pytest -from kiwipy import BroadcastFilter, rmq import plumpy from plumpy import communications, process_comms @@ -204,3 +204,15 @@ async def test_continue(self, loop_communicator, async_controller, persister): # Let the process run to the end result = await async_controller.continue_process(pid) assert result, utils.DummyProcessWithOutput.EXPECTED_OUTPUTS + + @pytest.mark.asyncio + async def test_duplicate_process(self, loop_communicator, async_controller, persister): + loop = asyncio.get_event_loop() + launcher = plumpy.ProcessLauncher(loop, persister=persister) + loop_communicator.add_task_subscriber(launcher) + process = utils.DummyProcessWithOutput() + persister.save_checkpoint(process) + launcher._process_cache[process.pid] = process + assert process.pid in launcher.process_cache + with pytest.raises(RemoteException, match='already running'): + await async_controller.continue_process(process.pid)