From 89100740d98d8d9e0a9d1a0db9b5055b15e4e4ab Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 18 Aug 2021 20:27:45 +0200 Subject: [PATCH] Engine: add `CalcJobImporter` class and associated entry point group The `CalcJobImporter` class is added, which defines a single abstract staticmethod `parse_remote_data`. The idea is that plugins can define an importer for a `CalcJob` implementation and implement this method. The method takes a `RemoteData` node that points to a path on the associated computer that contains the input and output files of a calculation that has been run outside of AiiDA, but by an executable that is normally run with this particular `CalcJob`. The `parse_remote_data` implementation should read the input files found in the remote data and parse their content into the input nodes that when used to launch the calculation job, would result in similar input files. These inputs, including the `RemoteData` as the `remote_folder` input, can then be used to run an instance of this particular `CalcJob`. The engine will recognize the `remote_folder` input, signalling an import job, and instead of running a normal job that creates the input files on the remote before submitting it to the scheduler, it passes straight to the retrieve step. This will retrieve the files from the `RemoteData` as if it would have been created by the job itself. If a parsers was defined in the inputs, the contents are parsed and the returned output nodes are attached. The `CalcJobImporter` can be loaded through its entry point name using the `CalcJobImporterFactory`, just like the entry points of all other entry point groups have their associated factory. As a shortcut, the `CalcJob` class, provides the `get_importer` class method which will attempt to load a `CalcJobImporter` class with the exact same entry point. Alternatively, the caller can specify the desired entry point name should it not correspond to that of the `CalcJob` class. To test the functionality, a `CalcJobImporter` is implemented for the `ArithmeticAddCalculation` class. --- aiida/calculations/arithmetic/add.py | 4 +- aiida/calculations/importers/__init__.py | 0 .../importers/arithmetic/__init__.py | 0 .../calculations/importers/arithmetic/add.py | 39 +++++++ aiida/engine/__init__.py | 1 + aiida/engine/processes/__init__.py | 1 + aiida/engine/processes/calcjobs/__init__.py | 2 + aiida/engine/processes/calcjobs/calcjob.py | 54 +++++++-- aiida/engine/processes/calcjobs/importer.py | 21 ++++ .../orm/nodes/process/calculation/calcjob.py | 7 ++ aiida/plugins/__init__.py | 1 + aiida/plugins/factories.py | 25 +++- docs/source/nitpick-exceptions | 1 + setup.json | 3 + .../importers/arithmetic/test_add.py | 20 ++++ .../processes/calcjobs/test_calc_job.py | 18 ++- tests/plugins/test_factories.py | 108 +++++++++++------- 17 files changed, 247 insertions(+), 58 deletions(-) create mode 100644 aiida/calculations/importers/__init__.py create mode 100644 aiida/calculations/importers/arithmetic/__init__.py create mode 100644 aiida/calculations/importers/arithmetic/add.py create mode 100644 aiida/engine/processes/calcjobs/importer.py create mode 100644 tests/calculations/importers/arithmetic/test_add.py diff --git a/aiida/calculations/arithmetic/add.py b/aiida/calculations/arithmetic/add.py index 1f983f0227..be3e5e8c61 100644 --- a/aiida/calculations/arithmetic/add.py +++ b/aiida/calculations/arithmetic/add.py @@ -53,10 +53,12 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: handle.write(f'echo $(({self.inputs.x.value} + {self.inputs.y.value}))\n') codeinfo = CodeInfo() - codeinfo.code_uuid = self.inputs.code.uuid codeinfo.stdin_name = self.options.input_filename codeinfo.stdout_name = self.options.output_filename + if 'code' in self.inputs: + codeinfo.code_uuid = self.inputs.code.uuid + calcinfo = CalcInfo() calcinfo.codes_info = [codeinfo] calcinfo.retrieve_list = [self.options.output_filename] diff --git a/aiida/calculations/importers/__init__.py b/aiida/calculations/importers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiida/calculations/importers/arithmetic/__init__.py b/aiida/calculations/importers/arithmetic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiida/calculations/importers/arithmetic/add.py b/aiida/calculations/importers/arithmetic/add.py new file mode 100644 index 0000000000..5890ae8109 --- /dev/null +++ b/aiida/calculations/importers/arithmetic/add.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" +from pathlib import Path +from re import match +from typing import Dict, Union +from tempfile import NamedTemporaryFile + +from aiida.engine import CalcJobImporter +from aiida.orm import Node, Int, RemoteData + + +class ArithmeticAddCalculationImporter(CalcJobImporter): + """Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" + + @staticmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process. + :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. + """ + with NamedTemporaryFile('w+') as handle: + with remote_data.get_authinfo().get_transport() as transport: + filepath = Path(remote_data.get_remote_path()) / 'aiida.in' + transport.getfile(filepath, handle.name) + + handle.seek(0) + data = handle.read() + + matches = match(r'echo \$\(\(([0-9]+) \+ ([0-9]+)\)\).*', data.strip()) + + if matches is None: + raise ValueError(f'failed to parse the integers `x` and `y` from the input content: {data}') + + return { + 'x': Int(matches.group(1)), + 'y': Int(matches.group(2)), + } diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 86f738c7c6..c3a2a2cdb6 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -29,6 +29,7 @@ 'AwaitableTarget', 'BaseRestartWorkChain', 'CalcJob', + 'CalcJobImporter', 'CalcJobOutputPort', 'CalcJobProcessSpec', 'DaemonClient', diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index a2f81d49b3..20668be208 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -30,6 +30,7 @@ 'AwaitableTarget', 'BaseRestartWorkChain', 'CalcJob', + 'CalcJobImporter', 'CalcJobOutputPort', 'CalcJobProcessSpec', 'ExitCode', diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index a91782d092..77686c9969 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -15,10 +15,12 @@ # pylint: disable=wildcard-import from .calcjob import * +from .importer import * from .manager import * __all__ = ( 'CalcJob', + 'CalcJobImporter', 'JobManager', 'JobsList', ) diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 25beb9d817..77f6aa79da 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -28,6 +28,7 @@ from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec from .tasks import Waiting, UPLOAD_COMMAND +from .importer import CalcJobImporter __all__ = ('CalcJob',) @@ -51,9 +52,6 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation return None - code = inputs.get('code', None) - computer_from_code = code.computer - computer_from_metadata = inputs.get('metadata', {}).get('computer', None) remote_folder = inputs.get('remote_folder', None) if remote_folder is not None: @@ -62,6 +60,10 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli # checked for consistency. return None + code = inputs.get('code', None) + computer_from_code = code.computer + computer_from_metadata = inputs.get('metadata', {}).get('computer', None) + if not computer_from_code and not computer_from_metadata: return 'no computer has been specified in `metadata.computer` nor via `code`.' @@ -294,6 +296,27 @@ def spec_options(cls): # pylint: disable=no-self-argument """ return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object + @classmethod + def get_importer(cls, entry_point_name: str = None) -> CalcJobImporter: + """Load the `CalcJobImporter` associated with this `CalcJob` if it exists. + + By default an importer with the same entry point as the ``CalcJob`` will be loaded, however, this can be + overridden using the ``entry_point_name`` argument. + + :param entry_point_name: optional entry point name of a ``CalcJobImporter`` to override the default. + :return: the loaded ``CalcJobImporter``. + :raises: if no importer class could be loaded. + """ + from aiida.plugins import CalcJobImporterFactory + from aiida.plugins.entry_point import get_entry_point_from_class + + if entry_point_name is None: + _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) + if entry_point is not None: + entry_point_name = entry_point.name # type: ignore[attr-defined] + + return CalcJobImporterFactory(entry_point_name)() + @property def options(self) -> AttributeDict: """Return the options of the metadata that were specified when this process instance was launched. @@ -412,6 +435,7 @@ def _perform_import(self): ) retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) self.node.set_state(CalcJobState.PARSING) + self.node.set_attribute(orm.CalcJobNode.IMMIGRATED_KEY, True) return self.parse(retrieved_temporary_folder.abspath) def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: @@ -465,7 +489,16 @@ def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" - scheduler = self.node.computer.get_scheduler() + computer = self.node.computer + + if computer is None: + self.logger.info( + 'no computer is defined for this calculation job which suggest that it is an imported job and so ' + 'scheduler output probably is not available or not in a format that can be reliably parsed, skipping..' + ) + return None + + scheduler = computer.get_scheduler() filename_stderr = self.node.get_option('scheduler_stderr') filename_stdout = self.node.get_option('scheduler_stdout') @@ -553,12 +586,12 @@ def presubmit(self, folder: Folder) -> CalcInfo: from aiida.orm import load_node, Code, Computer from aiida.schedulers.datastructures import JobTemplate - computer = self.node.computer inputs = self.node.get_incoming(link_type=LinkType.INPUT_CALC) if not self.inputs.metadata.dry_run and self.node.has_cached_links(): # type: ignore[union-attr] raise InvalidOperation('calculation node has unstored links in cache') + computer = self.node.computer codes = [_ for _ in inputs.all_nodes() if isinstance(_, Code)] for code in codes: @@ -576,17 +609,17 @@ def presubmit(self, folder: Folder) -> CalcInfo: calc_info = self.prepare_for_submission(folder) calc_info.uuid = str(self.node.uuid) - scheduler = computer.get_scheduler() # I create the job template to pass to the scheduler job_tmpl = JobTemplate() - job_tmpl.shebang = computer.get_shebang() job_tmpl.submit_as_hold = False job_tmpl.rerunnable = self.options.get('rerunnable', False) job_tmpl.job_environment = {} # 'email', 'email_on_started', 'email_on_terminated', job_tmpl.job_name = f'aiida-{self.node.pk}' job_tmpl.sched_output_path = self.options.scheduler_stdout + if computer is not None: + job_tmpl.shebang = computer.get_shebang() if self.options.scheduler_stderr == self.options.scheduler_stdout: job_tmpl.sched_join_files = True else: @@ -607,6 +640,13 @@ def presubmit(self, folder: Folder) -> CalcInfo: retrieve_temporary_list = calc_info.retrieve_temporary_list or [] self.node.set_retrieve_temporary_list(retrieve_temporary_list) + # If the inputs contain a ``remote_folder`` input node, we are in an import scenario and can skip the rest + if 'remote_folder' in inputs.all_link_labels(): + return + + # The remaining code is only necessary for actual runs, for example, creating the submission script + scheduler = computer.get_scheduler() + # the if is done so that if the method returns None, this is # not added. This has two advantages: # - it does not add too many \n\n if most of the prepend_text are empty diff --git a/aiida/engine/processes/calcjobs/importer.py b/aiida/engine/processes/calcjobs/importer.py new file mode 100644 index 0000000000..1e6b333f20 --- /dev/null +++ b/aiida/engine/processes/calcjobs/importer.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +"""Abstract utility class that helps to import calculation jobs completed outside of AiiDA.""" +from abc import abstractmethod +from typing import Dict, Union + +from aiida.orm import Node, RemoteData + +__all__ = ('CalcJobImporter',) + + +class CalcJobImporter: + + @staticmethod + @abstractmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process. + :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. + """ diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index 6ae1f91184..5014c70ef4 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -38,6 +38,7 @@ class CalcJobNode(CalculationNode): # pylint: disable=too-many-public-methods CALC_JOB_STATE_KEY = 'state' + IMMIGRATED_KEY = 'imported' REMOTE_WORKDIR_KEY = 'remote_workdir' RETRIEVE_LIST_KEY = 'retrieve_list' RETRIEVE_TEMPORARY_LIST_KEY = 'retrieve_temporary_list' @@ -89,6 +90,7 @@ def tools(self) -> 'CalculationTools': def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + ( cls.CALC_JOB_STATE_KEY, + cls.IMMIGRATED_KEY, cls.REMOTE_WORKDIR_KEY, cls.RETRIEVE_LIST_KEY, cls.RETRIEVE_TEMPORARY_LIST_KEY, @@ -151,6 +153,11 @@ def get_builder_restart(self) -> 'ProcessBuilder': builder.metadata.options = self.get_options() # type: ignore[attr-defined] return builder + @property + def is_imported(self) -> bool: + """Return whether the calculation job was imported instead of being an actual run.""" + return self.get_attribute(self.IMMIGRATED_KEY, None) is True + def get_option(self, name: str) -> Optional[Any]: """ Retun the value of an option that was set for this CalcJobNode diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index 63c4419cd1..66bd8fb14d 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -20,6 +20,7 @@ __all__ = ( 'BaseFactory', + 'CalcJobImporterFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 2670f31b84..3b078069ec 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -17,12 +17,12 @@ from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', - 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'CalcJobImporterFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', + 'OrbitalFactory', 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) if TYPE_CHECKING: - from aiida.engine import CalcJob, WorkChain + from aiida.engine import CalcJob, CalcJobImporter, WorkChain from aiida.orm import Data, Group from aiida.parsers import Parser from aiida.schedulers import Scheduler @@ -105,6 +105,25 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) +def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'CalcJobImporter']]: + """Return the plugin registered under the given entry point. + + :param entry_point_name: the entry point name. + :return: the loaded :class:`~aiida.engine.processes.calcjobs.importer.CalcJobImporter` plugin. + :raises ``aiida.common.InvalidEntryPointTypeError``: if the type of the loaded entry point is invalid. + """ + from aiida.engine import CalcJobImporter + + entry_point_group = 'aiida.calculations.importers' + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) + valid_classes = (CalcJobImporter,) + + if isclass(entry_point) and issubclass(entry_point, CalcJobImporter): # type: ignore[arg-type] + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Data']]: """Return the `Data` sub class registered under the given entry point. diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 6536484ef2..b70c05da1f 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -48,6 +48,7 @@ py:class BackendEntity py:class BackendNode py:class AuthInfo py:class CalcJob +py:class CalcJobImporter py:class CalcJobNode py:class Data py:class DbImporter diff --git a/setup.json b/setup.json index 821428bfc0..a8f2093014 100644 --- a/setup.json +++ b/setup.json @@ -132,6 +132,9 @@ "core.arithmetic.add = aiida.calculations.arithmetic.add:ArithmeticAddCalculation", "core.templatereplacer = aiida.calculations.templatereplacer:TemplatereplacerCalculation" ], + "aiida.calculations.importers": [ + "core.arithmetic.add = aiida.calculations.importers.arithmetic.add:ArithmeticAddCalculationImporter" + ], "aiida.cmdline.computer.configure": [ "core.local = aiida.transports.plugins.local:CONFIGURE_LOCAL_CMD", "core.ssh = aiida.transports.plugins.ssh:CONFIGURE_SSH_CMD" diff --git a/tests/calculations/importers/arithmetic/test_add.py b/tests/calculations/importers/arithmetic/test_add.py new file mode 100644 index 0000000000..dcbc47728c --- /dev/null +++ b/tests/calculations/importers/arithmetic/test_add.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +"""Tests for the :mod:`aiida.calculations.importers.arithmetic.add` module.""" +from aiida.calculations.importers.arithmetic.add import ArithmeticAddCalculationImporter +from aiida.orm import Int, RemoteData + + +def test_parse_remote_data(tmp_path, aiida_localhost): + """Test the ``ArithmeticAddCalculationImporter.parse_remote_data`` method.""" + with (tmp_path / 'aiida.in').open('w+') as handle: + handle.write('echo $((4 + 12))') + handle.flush() + + remote_data = RemoteData(tmp_path, computer=aiida_localhost) + inputs = ArithmeticAddCalculationImporter.parse_remote_data(remote_data) + + assert list(inputs.keys()) == ['x', 'y'] + assert isinstance(inputs['x'], Int) + assert isinstance(inputs['y'], Int) + assert inputs['x'].value == 4 + assert inputs['y'].value == 12 diff --git a/tests/engine/processes/calcjobs/test_calc_job.py b/tests/engine/processes/calcjobs/test_calc_job.py index cdf6872692..4a50b5fec7 100644 --- a/tests/engine/processes/calcjobs/test_calc_job.py +++ b/tests/engine/processes/calcjobs/test_calc_job.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-public-methods,redefined-outer-name +# pylint: disable=too-many-public-methods,redefined-outer-name,no-self-use """Test for the `CalcJob` process sub class.""" import json from copy import deepcopy @@ -22,7 +22,7 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions, LinkType, CalcJobState, StashMode -from aiida.engine import launch, CalcJob, Process, ExitCode +from aiida.engine import launch, CalcJob, Process, ExitCode, CalcJobImporter from aiida.engine.processes.ports import PortNamespace from aiida.engine.processes.calcjobs.calcjob import validate_stash_options from aiida.plugins import CalculationFactory @@ -471,6 +471,16 @@ def test_parse_retrieved_folder(self): # because the retrieved folder does not contain the output file it expects assert exit_code == process.exit_codes.ERROR_READING_OUTPUT_FILE + def test_get_importer(self): + """Test the ``CalcJob.get_importer`` method.""" + assert isinstance(ArithmeticAddCalculation.get_importer(), CalcJobImporter) + assert isinstance( + ArithmeticAddCalculation.get_importer(entry_point_name='core.arithmetic.add'), CalcJobImporter + ) + + with pytest.raises(exceptions.MissingEntryPointError): + ArithmeticAddCalculation.get_importer(entry_point_name='non-existing') + @pytest.fixture def generate_process(aiida_local_code_factory): @@ -782,7 +792,6 @@ def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls.computer.configure() # pylint: disable=no-member cls.inputs = { - 'code': orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store(), 'x': orm.Int(1), 'y': orm.Int(2), 'metadata': { @@ -814,6 +823,7 @@ def test_import_from_valid(self): assert isinstance(node, orm.CalcJobNode) assert node.is_finished_ok assert node.is_sealed + assert node.is_imported # Verify the expected outputs are there assert 'retrieved' in results @@ -843,6 +853,7 @@ def test_import_from_invalid(self): assert isinstance(node, orm.CalcJobNode) assert node.is_failed assert node.is_sealed + assert node.is_imported assert node.exit_status == ArithmeticAddCalculation.exit_codes.ERROR_INVALID_OUTPUT.status # Verify the expected outputs are there @@ -874,6 +885,7 @@ def test_import_non_default_input_file(self): assert isinstance(node, orm.CalcJobNode) assert node.is_finished_ok assert node.is_sealed + assert node.is_imported # Verify the expected outputs are there assert 'retrieved' in results diff --git a/tests/plugins/test_factories.py b/tests/plugins/test_factories.py index 8c131f2181..25214c0bd7 100644 --- a/tests/plugins/test_factories.py +++ b/tests/plugins/test_factories.py @@ -7,15 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the :py:mod:`~aiida.plugins.factories` module.""" -from unittest.mock import patch +import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import InvalidEntryPointTypeError -from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain +from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain, CalcJobImporter from aiida.orm import Data, Node, CalcFunctionNode, WorkFunctionNode from aiida.parsers import Parser -from aiida.plugins import factories +from aiida.plugins import entry_point, factories from aiida.schedulers import Scheduler from aiida.transports import Transport from aiida.tools.data.orbital import Orbital @@ -23,7 +23,7 @@ def custom_load_entry_point(group, name): - """Function that mocks `aiida.plugins.entry_point.load_entry_point` that is called by factories.""" + """Function that mocks :meth:`aiida.plugins.entry_point.load_entry_point` that is called by factories.""" @calcfunction def calc_function(): @@ -40,6 +40,10 @@ def work_function(): 'work_function': work_function, 'work_chain': WorkChain }, + 'aiida.calculations.importers': { + 'importer': CalcJobImporter, + 'invalid': CalcJob, + }, 'aiida.data': { 'valid': Data, 'invalid': Node, @@ -74,91 +78,107 @@ def work_function(): return entry_points[group][name] -class TestFactories(AiidaTestCase): +@pytest.fixture +def mock_load_entry_point(monkeypatch): + """Monkeypatch the :meth:`aiida.plugins.entry_point.load_entry_point` method.""" + monkeypatch.setattr(entry_point, 'load_entry_point', custom_load_entry_point) + yield + + +class TestFactories: """Tests for the :py:mod:`~aiida.plugins.factories` factory classes.""" - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_calculation_factory(self): - """Test the `CalculationFactory`.""" + """Test the ```CalculationFactory```.""" plugin = factories.CalculationFactory('calc_function') - self.assertEqual(plugin.is_process_function, True) - self.assertEqual(plugin.node_class, CalcFunctionNode) + assert plugin.is_process_function + assert plugin.node_class is CalcFunctionNode plugin = factories.CalculationFactory('calc_job') - self.assertEqual(plugin, CalcJob) + assert plugin is CalcJob - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.CalculationFactory('work_function') - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.CalculationFactory('work_chain') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') + def test_calc_job_importer_factory(self): + """Test the ``CalcJobImporterFactory``.""" + plugin = factories.CalcJobImporterFactory('importer') + assert plugin is CalcJobImporter + + with pytest.raises(InvalidEntryPointTypeError): + factories.CalcJobImporterFactory('invalid') + + @pytest.mark.usefixtures('mock_load_entry_point') def test_workflow_factory(self): - """Test the `WorkflowFactory`.""" + """Test the ``WorkflowFactory``.""" plugin = factories.WorkflowFactory('work_function') - self.assertEqual(plugin.is_process_function, True) - self.assertEqual(plugin.node_class, WorkFunctionNode) + assert plugin.is_process_function + assert plugin.node_class is WorkFunctionNode plugin = factories.WorkflowFactory('work_chain') - self.assertEqual(plugin, WorkChain) + assert plugin is WorkChain - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.WorkflowFactory('calc_function') - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.WorkflowFactory('calc_job') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_data_factory(self): - """Test the `DataFactory`.""" + """Test the ``DataFactory``.""" plugin = factories.DataFactory('valid') - self.assertEqual(plugin, Data) + assert plugin is Data - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.DataFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_db_importer_factory(self): - """Test the `DbImporterFactory`.""" + """Test the ``DbImporterFactory``.""" plugin = factories.DbImporterFactory('valid') - self.assertEqual(plugin, DbImporter) + assert plugin is DbImporter - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.DbImporterFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_orbital_factory(self): - """Test the `OrbitalFactory`.""" + """Test the ``OrbitalFactory``.""" plugin = factories.OrbitalFactory('valid') - self.assertEqual(plugin, Orbital) + assert plugin is Orbital - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.OrbitalFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_parser_factory(self): - """Test the `ParserFactory`.""" + """Test the ``ParserFactory``.""" plugin = factories.ParserFactory('valid') - self.assertEqual(plugin, Parser) + assert plugin is Parser - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.ParserFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_scheduler_factory(self): - """Test the `SchedulerFactory`.""" + """Test the ``SchedulerFactory``.""" plugin = factories.SchedulerFactory('valid') - self.assertEqual(plugin, Scheduler) + assert plugin is Scheduler - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.SchedulerFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_transport_factory(self): - """Test the `TransportFactory`.""" + """Test the ``TransportFactory``.""" plugin = factories.TransportFactory('valid') - self.assertEqual(plugin, Transport) + assert plugin is Transport - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.TransportFactory('invalid')