diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index fef23ab0d5..cdbf487582 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -13,6 +13,7 @@ import asyncio import copy +import inspect import pathlib import shutil import tempfile @@ -288,7 +289,8 @@ def submit_and_await(started_daemon_client): def _factory( submittable: Process | ProcessBuilder | ProcessNode, state: plumpy.ProcessState = plumpy.ProcessState.FINISHED, - timeout: int = 20 + timeout: int = 20, + **kwargs ): """Submit a process and wait for it to achieve the given state. @@ -296,12 +298,17 @@ def _factory( submitted first before awaiting the desired state. :param state: The process state to wait for, by default it waits for the submittable to be ``FINISHED``. :param timeout: The time to wait for the process to achieve the state. + :param kwargs: If the ``submittable`` is a process class, it is instantiated with the ``kwargs`` as inputs. :raises RuntimeError: If the process fails to achieve the specified state before the timeout expires. """ - if not isinstance(submittable, ProcessNode): + if inspect.isclass(submittable) and issubclass(submittable, Process): + node = submit(submittable, **kwargs) + elif isinstance(submittable, ProcessBuilder): node = submit(submittable) - else: + elif isinstance(submittable, ProcessNode): node = submittable + else: + raise ValueError(f'type of submittable `{type(submittable)}` is not supported.') start_time = time.time()