diff --git a/aiida/calculations/arithmetic/add.py b/aiida/calculations/arithmetic/add.py index 1f983f0227..be3e5e8c61 100644 --- a/aiida/calculations/arithmetic/add.py +++ b/aiida/calculations/arithmetic/add.py @@ -53,10 +53,12 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: handle.write(f'echo $(({self.inputs.x.value} + {self.inputs.y.value}))\n') codeinfo = CodeInfo() - codeinfo.code_uuid = self.inputs.code.uuid codeinfo.stdin_name = self.options.input_filename codeinfo.stdout_name = self.options.output_filename + if 'code' in self.inputs: + codeinfo.code_uuid = self.inputs.code.uuid + calcinfo = CalcInfo() calcinfo.codes_info = [codeinfo] calcinfo.retrieve_list = [self.options.output_filename] diff --git a/aiida/calculations/importers/__init__.py b/aiida/calculations/importers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiida/calculations/importers/arithmetic/__init__.py b/aiida/calculations/importers/arithmetic/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiida/calculations/importers/arithmetic/add.py b/aiida/calculations/importers/arithmetic/add.py new file mode 100644 index 0000000000..5890ae8109 --- /dev/null +++ b/aiida/calculations/importers/arithmetic/add.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +"""Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" +from pathlib import Path +from re import match +from typing import Dict, Union +from tempfile import NamedTemporaryFile + +from aiida.engine import CalcJobImporter +from aiida.orm import Node, Int, RemoteData + + +class ArithmeticAddCalculationImporter(CalcJobImporter): + """Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" + + @staticmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process. + :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. + """ + with NamedTemporaryFile('w+') as handle: + with remote_data.get_authinfo().get_transport() as transport: + filepath = Path(remote_data.get_remote_path()) / 'aiida.in' + transport.getfile(filepath, handle.name) + + handle.seek(0) + data = handle.read() + + matches = match(r'echo \$\(\(([0-9]+) \+ ([0-9]+)\)\).*', data.strip()) + + if matches is None: + raise ValueError(f'failed to parse the integers `x` and `y` from the input content: {data}') + + return { + 'x': Int(matches.group(1)), + 'y': Int(matches.group(2)), + } diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 86f738c7c6..c3a2a2cdb6 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -29,6 +29,7 @@ 'AwaitableTarget', 'BaseRestartWorkChain', 'CalcJob', + 'CalcJobImporter', 'CalcJobOutputPort', 'CalcJobProcessSpec', 'DaemonClient', diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index a2f81d49b3..20668be208 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -30,6 +30,7 @@ 'AwaitableTarget', 'BaseRestartWorkChain', 'CalcJob', + 'CalcJobImporter', 'CalcJobOutputPort', 'CalcJobProcessSpec', 'ExitCode', diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index a91782d092..77686c9969 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -15,10 +15,12 @@ # pylint: disable=wildcard-import from .calcjob import * +from .importer import * from .manager import * __all__ = ( 'CalcJob', + 'CalcJobImporter', 'JobManager', 'JobsList', ) diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 25beb9d817..77f6aa79da 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -28,6 +28,7 @@ from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec from .tasks import Waiting, UPLOAD_COMMAND +from .importer import CalcJobImporter __all__ = ('CalcJob',) @@ -51,9 +52,6 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation return None - code = inputs.get('code', None) - computer_from_code = code.computer - computer_from_metadata = inputs.get('metadata', {}).get('computer', None) remote_folder = inputs.get('remote_folder', None) if remote_folder is not None: @@ -62,6 +60,10 @@ def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pyli # checked for consistency. return None + code = inputs.get('code', None) + computer_from_code = code.computer + computer_from_metadata = inputs.get('metadata', {}).get('computer', None) + if not computer_from_code and not computer_from_metadata: return 'no computer has been specified in `metadata.computer` nor via `code`.' @@ -294,6 +296,27 @@ def spec_options(cls): # pylint: disable=no-self-argument """ return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object + @classmethod + def get_importer(cls, entry_point_name: str = None) -> CalcJobImporter: + """Load the `CalcJobImporter` associated with this `CalcJob` if it exists. + + By default an importer with the same entry point as the ``CalcJob`` will be loaded, however, this can be + overridden using the ``entry_point_name`` argument. + + :param entry_point_name: optional entry point name of a ``CalcJobImporter`` to override the default. + :return: the loaded ``CalcJobImporter``. + :raises: if no importer class could be loaded. + """ + from aiida.plugins import CalcJobImporterFactory + from aiida.plugins.entry_point import get_entry_point_from_class + + if entry_point_name is None: + _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) + if entry_point is not None: + entry_point_name = entry_point.name # type: ignore[attr-defined] + + return CalcJobImporterFactory(entry_point_name)() + @property def options(self) -> AttributeDict: """Return the options of the metadata that were specified when this process instance was launched. @@ -412,6 +435,7 @@ def _perform_import(self): ) retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) self.node.set_state(CalcJobState.PARSING) + self.node.set_attribute(orm.CalcJobNode.IMMIGRATED_KEY, True) return self.parse(retrieved_temporary_folder.abspath) def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: @@ -465,7 +489,16 @@ def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" - scheduler = self.node.computer.get_scheduler() + computer = self.node.computer + + if computer is None: + self.logger.info( + 'no computer is defined for this calculation job which suggest that it is an imported job and so ' + 'scheduler output probably is not available or not in a format that can be reliably parsed, skipping..' + ) + return None + + scheduler = computer.get_scheduler() filename_stderr = self.node.get_option('scheduler_stderr') filename_stdout = self.node.get_option('scheduler_stdout') @@ -553,12 +586,12 @@ def presubmit(self, folder: Folder) -> CalcInfo: from aiida.orm import load_node, Code, Computer from aiida.schedulers.datastructures import JobTemplate - computer = self.node.computer inputs = self.node.get_incoming(link_type=LinkType.INPUT_CALC) if not self.inputs.metadata.dry_run and self.node.has_cached_links(): # type: ignore[union-attr] raise InvalidOperation('calculation node has unstored links in cache') + computer = self.node.computer codes = [_ for _ in inputs.all_nodes() if isinstance(_, Code)] for code in codes: @@ -576,17 +609,17 @@ def presubmit(self, folder: Folder) -> CalcInfo: calc_info = self.prepare_for_submission(folder) calc_info.uuid = str(self.node.uuid) - scheduler = computer.get_scheduler() # I create the job template to pass to the scheduler job_tmpl = JobTemplate() - job_tmpl.shebang = computer.get_shebang() job_tmpl.submit_as_hold = False job_tmpl.rerunnable = self.options.get('rerunnable', False) job_tmpl.job_environment = {} # 'email', 'email_on_started', 'email_on_terminated', job_tmpl.job_name = f'aiida-{self.node.pk}' job_tmpl.sched_output_path = self.options.scheduler_stdout + if computer is not None: + job_tmpl.shebang = computer.get_shebang() if self.options.scheduler_stderr == self.options.scheduler_stdout: job_tmpl.sched_join_files = True else: @@ -607,6 +640,13 @@ def presubmit(self, folder: Folder) -> CalcInfo: retrieve_temporary_list = calc_info.retrieve_temporary_list or [] self.node.set_retrieve_temporary_list(retrieve_temporary_list) + # If the inputs contain a ``remote_folder`` input node, we are in an import scenario and can skip the rest + if 'remote_folder' in inputs.all_link_labels(): + return + + # The remaining code is only necessary for actual runs, for example, creating the submission script + scheduler = computer.get_scheduler() + # the if is done so that if the method returns None, this is # not added. This has two advantages: # - it does not add too many \n\n if most of the prepend_text are empty diff --git a/aiida/engine/processes/calcjobs/importer.py b/aiida/engine/processes/calcjobs/importer.py new file mode 100644 index 0000000000..1e6b333f20 --- /dev/null +++ b/aiida/engine/processes/calcjobs/importer.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +"""Abstract utility class that helps to import calculation jobs completed outside of AiiDA.""" +from abc import abstractmethod +from typing import Dict, Union + +from aiida.orm import Node, RemoteData + +__all__ = ('CalcJobImporter',) + + +class CalcJobImporter: + + @staticmethod + @abstractmethod + def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: + """Parse the input nodes from the files in the provided ``RemoteData``. + + :param remote_data: the remote data node containing the raw input files. + :param kwargs: additional keyword arguments to control the parsing process. + :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. + """ diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index 6ae1f91184..5014c70ef4 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -38,6 +38,7 @@ class CalcJobNode(CalculationNode): # pylint: disable=too-many-public-methods CALC_JOB_STATE_KEY = 'state' + IMMIGRATED_KEY = 'imported' REMOTE_WORKDIR_KEY = 'remote_workdir' RETRIEVE_LIST_KEY = 'retrieve_list' RETRIEVE_TEMPORARY_LIST_KEY = 'retrieve_temporary_list' @@ -89,6 +90,7 @@ def tools(self) -> 'CalculationTools': def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument return super()._updatable_attributes + ( cls.CALC_JOB_STATE_KEY, + cls.IMMIGRATED_KEY, cls.REMOTE_WORKDIR_KEY, cls.RETRIEVE_LIST_KEY, cls.RETRIEVE_TEMPORARY_LIST_KEY, @@ -151,6 +153,11 @@ def get_builder_restart(self) -> 'ProcessBuilder': builder.metadata.options = self.get_options() # type: ignore[attr-defined] return builder + @property + def is_imported(self) -> bool: + """Return whether the calculation job was imported instead of being an actual run.""" + return self.get_attribute(self.IMMIGRATED_KEY, None) is True + def get_option(self, name: str) -> Optional[Any]: """ Retun the value of an option that was set for this CalcJobNode diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py index 63c4419cd1..66bd8fb14d 100644 --- a/aiida/plugins/__init__.py +++ b/aiida/plugins/__init__.py @@ -20,6 +20,7 @@ __all__ = ( 'BaseFactory', + 'CalcJobImporterFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', diff --git a/aiida/plugins/factories.py b/aiida/plugins/factories.py index 2670f31b84..3b078069ec 100644 --- a/aiida/plugins/factories.py +++ b/aiida/plugins/factories.py @@ -17,12 +17,12 @@ from aiida.common.exceptions import InvalidEntryPointTypeError __all__ = ( - 'BaseFactory', 'CalculationFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', 'OrbitalFactory', - 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' + 'BaseFactory', 'CalculationFactory', 'CalcJobImporterFactory', 'DataFactory', 'DbImporterFactory', 'GroupFactory', + 'OrbitalFactory', 'ParserFactory', 'SchedulerFactory', 'TransportFactory', 'WorkflowFactory' ) if TYPE_CHECKING: - from aiida.engine import CalcJob, WorkChain + from aiida.engine import CalcJob, CalcJobImporter, WorkChain from aiida.orm import Data, Group from aiida.parsers import Parser from aiida.schedulers import Scheduler @@ -105,6 +105,25 @@ def CalculationFactory(entry_point_name: str, load: bool = True) -> Optional[Uni raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) +def CalcJobImporterFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'CalcJobImporter']]: + """Return the plugin registered under the given entry point. + + :param entry_point_name: the entry point name. + :return: the loaded :class:`~aiida.engine.processes.calcjobs.importer.CalcJobImporter` plugin. + :raises ``aiida.common.InvalidEntryPointTypeError``: if the type of the loaded entry point is invalid. + """ + from aiida.engine import CalcJobImporter + + entry_point_group = 'aiida.calculations.importers' + entry_point = BaseFactory(entry_point_group, entry_point_name, load=load) + valid_classes = (CalcJobImporter,) + + if isclass(entry_point) and issubclass(entry_point, CalcJobImporter): # type: ignore[arg-type] + return entry_point + + raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes) + + def DataFactory(entry_point_name: str, load: bool = True) -> Optional[Union[EntryPoint, 'Data']]: """Return the `Data` sub class registered under the given entry point. diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 6536484ef2..b70c05da1f 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -48,6 +48,7 @@ py:class BackendEntity py:class BackendNode py:class AuthInfo py:class CalcJob +py:class CalcJobImporter py:class CalcJobNode py:class Data py:class DbImporter diff --git a/setup.json b/setup.json index 821428bfc0..a8f2093014 100644 --- a/setup.json +++ b/setup.json @@ -132,6 +132,9 @@ "core.arithmetic.add = aiida.calculations.arithmetic.add:ArithmeticAddCalculation", "core.templatereplacer = aiida.calculations.templatereplacer:TemplatereplacerCalculation" ], + "aiida.calculations.importers": [ + "core.arithmetic.add = aiida.calculations.importers.arithmetic.add:ArithmeticAddCalculationImporter" + ], "aiida.cmdline.computer.configure": [ "core.local = aiida.transports.plugins.local:CONFIGURE_LOCAL_CMD", "core.ssh = aiida.transports.plugins.ssh:CONFIGURE_SSH_CMD" diff --git a/tests/calculations/importers/arithmetic/test_add.py b/tests/calculations/importers/arithmetic/test_add.py new file mode 100644 index 0000000000..dcbc47728c --- /dev/null +++ b/tests/calculations/importers/arithmetic/test_add.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +"""Tests for the :mod:`aiida.calculations.importers.arithmetic.add` module.""" +from aiida.calculations.importers.arithmetic.add import ArithmeticAddCalculationImporter +from aiida.orm import Int, RemoteData + + +def test_parse_remote_data(tmp_path, aiida_localhost): + """Test the ``ArithmeticAddCalculationImporter.parse_remote_data`` method.""" + with (tmp_path / 'aiida.in').open('w+') as handle: + handle.write('echo $((4 + 12))') + handle.flush() + + remote_data = RemoteData(tmp_path, computer=aiida_localhost) + inputs = ArithmeticAddCalculationImporter.parse_remote_data(remote_data) + + assert list(inputs.keys()) == ['x', 'y'] + assert isinstance(inputs['x'], Int) + assert isinstance(inputs['y'], Int) + assert inputs['x'].value == 4 + assert inputs['y'].value == 12 diff --git a/tests/engine/processes/calcjobs/test_calc_job.py b/tests/engine/processes/calcjobs/test_calc_job.py index cdf6872692..4a50b5fec7 100644 --- a/tests/engine/processes/calcjobs/test_calc_job.py +++ b/tests/engine/processes/calcjobs/test_calc_job.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-public-methods,redefined-outer-name +# pylint: disable=too-many-public-methods,redefined-outer-name,no-self-use """Test for the `CalcJob` process sub class.""" import json from copy import deepcopy @@ -22,7 +22,7 @@ from aiida import orm from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions, LinkType, CalcJobState, StashMode -from aiida.engine import launch, CalcJob, Process, ExitCode +from aiida.engine import launch, CalcJob, Process, ExitCode, CalcJobImporter from aiida.engine.processes.ports import PortNamespace from aiida.engine.processes.calcjobs.calcjob import validate_stash_options from aiida.plugins import CalculationFactory @@ -471,6 +471,16 @@ def test_parse_retrieved_folder(self): # because the retrieved folder does not contain the output file it expects assert exit_code == process.exit_codes.ERROR_READING_OUTPUT_FILE + def test_get_importer(self): + """Test the ``CalcJob.get_importer`` method.""" + assert isinstance(ArithmeticAddCalculation.get_importer(), CalcJobImporter) + assert isinstance( + ArithmeticAddCalculation.get_importer(entry_point_name='core.arithmetic.add'), CalcJobImporter + ) + + with pytest.raises(exceptions.MissingEntryPointError): + ArithmeticAddCalculation.get_importer(entry_point_name='non-existing') + @pytest.fixture def generate_process(aiida_local_code_factory): @@ -782,7 +792,6 @@ def setUpClass(cls, *args, **kwargs): super().setUpClass(*args, **kwargs) cls.computer.configure() # pylint: disable=no-member cls.inputs = { - 'code': orm.Code(remote_computer_exec=(cls.computer, '/bin/true')).store(), 'x': orm.Int(1), 'y': orm.Int(2), 'metadata': { @@ -814,6 +823,7 @@ def test_import_from_valid(self): assert isinstance(node, orm.CalcJobNode) assert node.is_finished_ok assert node.is_sealed + assert node.is_imported # Verify the expected outputs are there assert 'retrieved' in results @@ -843,6 +853,7 @@ def test_import_from_invalid(self): assert isinstance(node, orm.CalcJobNode) assert node.is_failed assert node.is_sealed + assert node.is_imported assert node.exit_status == ArithmeticAddCalculation.exit_codes.ERROR_INVALID_OUTPUT.status # Verify the expected outputs are there @@ -874,6 +885,7 @@ def test_import_non_default_input_file(self): assert isinstance(node, orm.CalcJobNode) assert node.is_finished_ok assert node.is_sealed + assert node.is_imported # Verify the expected outputs are there assert 'retrieved' in results diff --git a/tests/plugins/test_factories.py b/tests/plugins/test_factories.py index 8c131f2181..25214c0bd7 100644 --- a/tests/plugins/test_factories.py +++ b/tests/plugins/test_factories.py @@ -7,15 +7,15 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +# pylint: disable=no-self-use """Tests for the :py:mod:`~aiida.plugins.factories` module.""" -from unittest.mock import patch +import pytest -from aiida.backends.testbase import AiidaTestCase from aiida.common.exceptions import InvalidEntryPointTypeError -from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain +from aiida.engine import calcfunction, workfunction, CalcJob, WorkChain, CalcJobImporter from aiida.orm import Data, Node, CalcFunctionNode, WorkFunctionNode from aiida.parsers import Parser -from aiida.plugins import factories +from aiida.plugins import entry_point, factories from aiida.schedulers import Scheduler from aiida.transports import Transport from aiida.tools.data.orbital import Orbital @@ -23,7 +23,7 @@ def custom_load_entry_point(group, name): - """Function that mocks `aiida.plugins.entry_point.load_entry_point` that is called by factories.""" + """Function that mocks :meth:`aiida.plugins.entry_point.load_entry_point` that is called by factories.""" @calcfunction def calc_function(): @@ -40,6 +40,10 @@ def work_function(): 'work_function': work_function, 'work_chain': WorkChain }, + 'aiida.calculations.importers': { + 'importer': CalcJobImporter, + 'invalid': CalcJob, + }, 'aiida.data': { 'valid': Data, 'invalid': Node, @@ -74,91 +78,107 @@ def work_function(): return entry_points[group][name] -class TestFactories(AiidaTestCase): +@pytest.fixture +def mock_load_entry_point(monkeypatch): + """Monkeypatch the :meth:`aiida.plugins.entry_point.load_entry_point` method.""" + monkeypatch.setattr(entry_point, 'load_entry_point', custom_load_entry_point) + yield + + +class TestFactories: """Tests for the :py:mod:`~aiida.plugins.factories` factory classes.""" - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_calculation_factory(self): - """Test the `CalculationFactory`.""" + """Test the ```CalculationFactory```.""" plugin = factories.CalculationFactory('calc_function') - self.assertEqual(plugin.is_process_function, True) - self.assertEqual(plugin.node_class, CalcFunctionNode) + assert plugin.is_process_function + assert plugin.node_class is CalcFunctionNode plugin = factories.CalculationFactory('calc_job') - self.assertEqual(plugin, CalcJob) + assert plugin is CalcJob - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.CalculationFactory('work_function') - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.CalculationFactory('work_chain') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') + def test_calc_job_importer_factory(self): + """Test the ``CalcJobImporterFactory``.""" + plugin = factories.CalcJobImporterFactory('importer') + assert plugin is CalcJobImporter + + with pytest.raises(InvalidEntryPointTypeError): + factories.CalcJobImporterFactory('invalid') + + @pytest.mark.usefixtures('mock_load_entry_point') def test_workflow_factory(self): - """Test the `WorkflowFactory`.""" + """Test the ``WorkflowFactory``.""" plugin = factories.WorkflowFactory('work_function') - self.assertEqual(plugin.is_process_function, True) - self.assertEqual(plugin.node_class, WorkFunctionNode) + assert plugin.is_process_function + assert plugin.node_class is WorkFunctionNode plugin = factories.WorkflowFactory('work_chain') - self.assertEqual(plugin, WorkChain) + assert plugin is WorkChain - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.WorkflowFactory('calc_function') - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.WorkflowFactory('calc_job') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_data_factory(self): - """Test the `DataFactory`.""" + """Test the ``DataFactory``.""" plugin = factories.DataFactory('valid') - self.assertEqual(plugin, Data) + assert plugin is Data - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.DataFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_db_importer_factory(self): - """Test the `DbImporterFactory`.""" + """Test the ``DbImporterFactory``.""" plugin = factories.DbImporterFactory('valid') - self.assertEqual(plugin, DbImporter) + assert plugin is DbImporter - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.DbImporterFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_orbital_factory(self): - """Test the `OrbitalFactory`.""" + """Test the ``OrbitalFactory``.""" plugin = factories.OrbitalFactory('valid') - self.assertEqual(plugin, Orbital) + assert plugin is Orbital - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.OrbitalFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_parser_factory(self): - """Test the `ParserFactory`.""" + """Test the ``ParserFactory``.""" plugin = factories.ParserFactory('valid') - self.assertEqual(plugin, Parser) + assert plugin is Parser - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.ParserFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_scheduler_factory(self): - """Test the `SchedulerFactory`.""" + """Test the ``SchedulerFactory``.""" plugin = factories.SchedulerFactory('valid') - self.assertEqual(plugin, Scheduler) + assert plugin is Scheduler - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.SchedulerFactory('invalid') - @patch('aiida.plugins.entry_point.load_entry_point', custom_load_entry_point) + @pytest.mark.usefixtures('mock_load_entry_point') def test_transport_factory(self): - """Test the `TransportFactory`.""" + """Test the ``TransportFactory``.""" plugin = factories.TransportFactory('valid') - self.assertEqual(plugin, Transport) + assert plugin is Transport - with self.assertRaises(InvalidEntryPointTypeError): + with pytest.raises(InvalidEntryPointTypeError): factories.TransportFactory('invalid')