diff --git a/.pylintrc b/.pylintrc
index 089a60ca07..5816e54bd0 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -59,7 +59,8 @@ disable=bad-continuation,
         no-else-raise,
         import-outside-toplevel,
         cyclic-import,
-        duplicate-code
+        duplicate-code,
+        too-few-public-methods

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py
index bfc95ff053..6f686355fd 100644
--- a/aiida/engine/processes/calcjobs/calcjob.py
+++ b/aiida/engine/processes/calcjobs/calcjob.py
@@ -26,7 +26,7 @@
 __all__ = ('CalcJob',)


-def validate_calc_job(inputs, ctx):
+def validate_calc_job(inputs, ctx):  # pylint: disable=inconsistent-return-statements,too-many-return-statements
     """Validate the entire set of inputs passed to the `CalcJob` constructor.

     Reasons that will cause this validation to raise an `InputValidationError`:
@@ -35,7 +35,7 @@ def validate_calc_job(inputs, ctx):
     * The specified computer is not stored
     * The `Computer` specified in `metadata.computer` is not the same as that of the specified `Code`

-    :raises `~aiida.common.exceptions.InputValidationError`: if inputs are invalid
+    :return: string with error message in case the inputs are invalid
     """
     try:
         ctx.get_port('code')
@@ -49,27 +49,49 @@ def validate_calc_job(inputs, ctx):
     computer_from_metadata = inputs.get('metadata', {}).get('computer', None)

     if not computer_from_code and not computer_from_metadata:
-        raise exceptions.InputValidationError('no computer has been specified in `metadata.computer` nor via `code`.')
+        return 'no computer has been specified in `metadata.computer` nor via `code`.'

     if computer_from_code and not computer_from_code.is_stored:
-        raise exceptions.InputValidationError('the Computer<{}> is not stored'.format(computer_from_code))
+        return 'the Computer<{}> is not stored'.format(computer_from_code)

     if computer_from_metadata and not computer_from_metadata.is_stored:
-        raise exceptions.InputValidationError('the Computer<{}> is not stored'.format(computer_from_metadata))
+        return 'the Computer<{}> is not stored'.format(computer_from_metadata)

     if computer_from_code and computer_from_metadata and computer_from_code.uuid != computer_from_metadata.uuid:
-        raise exceptions.InputValidationError(
-            'Computer<{}> explicitly defined in `metadata.computer is different from '
-            'Computer<{}> which is the computer of Code<{}> defined as the `code` input.'.format(
+        return (
+            'Computer<{}> explicitly defined in `metadata.computer` is different from Computer<{}> which is the '
+            'computer of Code<{}> defined as the `code` input.'.format(
                 computer_from_metadata, computer_from_code, code
             )
         )

+    try:
+        resources_port = ctx.get_port('metadata.options.resources')
+    except ValueError:
+        return
+
+    # If the resources port exists but is not required, we don't need to validate it against the computer's scheduler
+    if not resources_port.required:
+        return
+
+    computer = computer_from_code or computer_from_metadata
+    scheduler = computer.get_scheduler()
+
+    try:
+        resources = inputs['metadata']['options']['resources']
+    except KeyError:
+        return 'input `metadata.options.resources` is required but is not specified'

-def validate_parser(parser_name, ctx):  # pylint: disable=unused-argument
+    try:
+        scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine())
+        scheduler.validate_resources(**resources)
+    except (ValueError, TypeError) as exception:
+        return 'input `metadata.options.resources` is not valid for the {} scheduler: {}'.format(scheduler, exception)
+
+
+def validate_parser(parser_name, _):  # pylint: disable=inconsistent-return-statements
     """Validate the parser.

-    :raises InputValidationError: if the parser name does not correspond to a loadable `Parser` class.
+    :return: string with error message in case the inputs are invalid
     """
     from aiida.plugins import ParserFactory

@@ -77,20 +99,7 @@ def validate_parser(parser_name, ctx):  # pylint: disable=unused-argument
     try:
         ParserFactory(parser_name)
     except exceptions.EntryPointError as exception:
-        raise exceptions.InputValidationError('invalid parser specified: {}'.format(exception))
-
-
-def validate_resources(resources, ctx):  # pylint: disable=unused-argument
-    """Validate the resources.
-
-    :raises InputValidationError: if `num_machines` is not specified or is not an integer.
-    """
-    if resources is not plumpy.UNSPECIFIED:
-        if 'num_machines' not in resources:
-            raise exceptions.InputValidationError('the `resources` input has to at least include `num_machines`.')
-
-        if not isinstance(resources['num_machines'], int):
-            raise exceptions.InputValidationError('the input `resources.num_machines` shoud be an integer.')
+        return 'invalid parser specified: {}'.format(exception)


 class CalcJob(Process):
@@ -137,7 +146,7 @@ def define(cls, spec: CalcJobProcessSpec):
             help='Filename to which the content of stdout of the scheduler is written.')
         spec.input('metadata.options.scheduler_stderr', valid_type=str, default='_scheduler-stderr.txt',
             help='Filename to which the content of stderr of the scheduler is written.')
-        spec.input('metadata.options.resources', valid_type=dict, required=True, validator=validate_resources,
+        spec.input('metadata.options.resources', valid_type=dict, required=True,
             help='Set the dictionary of resources to be used by the scheduler plugin, like the number of nodes, '
                  'cpus etc. This dictionary is scheduler-plugin dependent. Look at the documentation of the '
                  'scheduler for more details.')
@@ -389,9 +398,7 @@ def presubmit(self, folder):

         # Set resources, also with get_default_mpiprocs_per_machine
         resources = self.node.get_option('resources')
-        def_cpus_machine = computer.get_default_mpiprocs_per_machine()
-        if def_cpus_machine is not None:
-            resources['default_mpiprocs_per_machine'] = def_cpus_machine
+        scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine())
         job_tmpl.job_resource = scheduler.create_job_resource(**resources)

         subst_dict = {'tot_num_mpiprocs': job_tmpl.job_resource.get_tot_num_mpiprocs()}
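The validators above follow the plumpy convention that a port validator returns `None` for valid input and an error-message string for invalid input, instead of raising. A minimal sketch of that convention; the validator and port names here are hypothetical:

```python
# Sketch of the validator convention used by `validate_calc_job` and `validate_parser`:
# return None for valid input, an error-message string otherwise.
# `validate_positive` and `metadata.options.num_iterations` are hypothetical names.

def validate_positive(value, ctx):  # pylint: disable=unused-argument
    """Return an error message if `value` is not a positive integer."""
    if not isinstance(value, int) or value < 1:
        return 'expected a positive integer, got {}'.format(value)
    # An implicit return of None signals that the value is valid.


# Hypothetical attachment to a process spec:
# spec.input('metadata.options.num_iterations', valid_type=int, validator=validate_positive)
```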
""" - -from enum import Enum +import abc +import enum from aiida.common import AIIDA_LOGGER -from aiida.common.extendeddicts import DefaultFieldsAttributeDict +from aiida.common.extendeddicts import AttributeDict, DefaultFieldsAttributeDict SCHEDULER_LOGGER = AIIDA_LOGGER.getChild('scheduler') @@ -28,7 +27,7 @@ ) -class JobState(Enum): +class JobState(enum.Enum): """Enumeration of possible scheduler states of a CalcJob. There is no FAILED state as every completed job is put in DONE, regardless of success. @@ -42,14 +41,11 @@ class JobState(Enum): DONE = 'done' -class JobResource(DefaultFieldsAttributeDict): - """ - A class to store the job resources. It must be inherited and redefined by the specific - plugin, that should contain a ``_job_resource_class`` attribute pointing to the correct - JobResource subclass. +class JobResource(DefaultFieldsAttributeDict, metaclass=abc.ABCMeta): + """Data structure to store job resources. - It should at least define the get_tot_num_mpiprocs() method, plus an __init__ to accept - its set of variables. + Each `Scheduler` implementation must define the `_job_resource_class` attribute to be a subclass of this class. + It should at least define the `get_tot_num_mpiprocs` method, plus a constructor to accept its set of variables. Typical attributes are: @@ -61,40 +57,37 @@ class JobResource(DefaultFieldsAttributeDict): * ``tot_num_mpiprocs`` * ``parallel_env`` - The __init__ should take care of checking the values. + The constructor should take care of checking the values. The init should raise only ValueError or TypeError on invalid parameters. """ _default_fields = tuple() - @classmethod - def accepts_default_mpiprocs_per_machine(cls): - """ - Return True if this JobResource accepts a 'default_mpiprocs_per_machine' - key, False otherwise. + @abc.abstractclassmethod + def validate_resources(cls, **kwargs): + """Validate the resources against the job resource class of this scheduler. - Should be implemented in each subclass. + :param kwargs: dictionary of values to define the job resources + :raises ValueError: if the resources are invalid or incomplete + :return: optional tuple of parsed resource settings """ - raise NotImplementedError @classmethod def get_valid_keys(cls): - """ - Return a list of valid keys to be passed to the __init__ - """ + """Return a list of valid keys to be passed to the constructor.""" return list(cls._default_fields) + @abc.abstractclassmethod + def accepts_default_mpiprocs_per_machine(cls): + """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise.""" + + @abc.abstractmethod def get_tot_num_mpiprocs(self): - """ - Return the total number of cpus of this job resource. 
- """ - raise NotImplementedError + """Return the total number of cpus of this job resource.""" class NodeNumberJobResource(JobResource): - """ - An implementation of JobResource for schedulers that support - the specification of a number of nodes and a number of cpus per node - """ + """`JobResource` for schedulers that support the specification of a number of nodes and cpus per node.""" + _default_fields = ( 'num_machines', 'num_mpiprocs_per_machine', @@ -103,205 +96,145 @@ class NodeNumberJobResource(JobResource): ) @classmethod - def get_valid_keys(cls): - """ - Return a list of valid keys to be passed to the __init__ - """ - return super().get_valid_keys() + ['tot_num_mpiprocs', 'default_mpiprocs_per_machine'] + def validate_resources(cls, **kwargs): + """Validate the resources against the job resource class of this scheduler. - @classmethod - def accepts_default_mpiprocs_per_machine(cls): - """ - Return True if this JobResource accepts a 'default_mpiprocs_per_machine' - key, False otherwise. + :param kwargs: dictionary of values to define the job resources + :return: attribute dictionary with the parsed parameters populated + :raises ValueError: if the resources are invalid or incomplete """ - return True + resources = AttributeDict() - def __init__(self, **kwargs): # pylint: disable=too-many-branches,too-many-statements - """ - Initialize the job resources from the passed arguments (the valid keys can be - obtained with the function self.get_valid_keys()). + def is_greater_equal_one(parameter): + value = getattr(resources, parameter, None) + if value is not None and value < 1: + raise ValueError('`{}` must be greater than or equal to one.'.format(parameter)) - Should raise only ValueError or TypeError on invalid parameters. - """ - super().__init__() + # Validate that all fields are valid integers if they are specified, otherwise initialize them to `None` + for parameter in list(cls._default_fields) + ['tot_num_mpiprocs']: + try: + setattr(resources, parameter, int(kwargs.pop(parameter))) + except KeyError: + setattr(resources, parameter, None) + except ValueError: + raise ValueError('`{}` must be an integer when specified'.format(parameter)) - try: - num_machines = int(kwargs.pop('num_machines')) - except KeyError: - num_machines = None - except ValueError: - raise ValueError('num_machines must an integer') + if kwargs: + raise ValueError('these parameters were not recognized: {}'.format(', '.join(list(kwargs.keys())))) - try: - default_mpiprocs_per_machine = kwargs.pop('default_mpiprocs_per_machine') - if default_mpiprocs_per_machine is not None: - default_mpiprocs_per_machine = int(default_mpiprocs_per_machine) - except KeyError: - default_mpiprocs_per_machine = None - except ValueError: - raise ValueError('default_mpiprocs_per_machine must an integer') + # At least two of the following parameters need to be defined as non-zero + if [resources.num_machines, resources.num_mpiprocs_per_machine, resources.tot_num_mpiprocs].count(None) > 1: + raise ValueError( + 'At least two among `num_machines`, `num_mpiprocs_per_machine` or `tot_num_mpiprocs` must be specified.' 
+            )

-        try:
-            num_mpiprocs_per_machine = int(kwargs.pop('num_mpiprocs_per_machine'))
-        except KeyError:
-            num_mpiprocs_per_machine = None
-        except ValueError:
-            raise ValueError('num_mpiprocs_per_machine must an integer')
+        for parameter in ['num_machines', 'num_mpiprocs_per_machine']:
+            is_greater_equal_one(parameter)

-        try:
-            tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs'))
-        except KeyError:
-            tot_num_mpiprocs = None
-        except ValueError:
-            raise ValueError('tot_num_mpiprocs must an integer')
+        # Here we know that at least two of the three required variables are defined and greater than or equal to one.
+        if resources.num_machines is None:
+            resources.num_machines = resources.tot_num_mpiprocs // resources.num_mpiprocs_per_machine
+        elif resources.num_mpiprocs_per_machine is None:
+            resources.num_mpiprocs_per_machine = resources.tot_num_mpiprocs // resources.num_machines
+        elif resources.tot_num_mpiprocs is None:
+            resources.tot_num_mpiprocs = resources.num_mpiprocs_per_machine * resources.num_machines

-        try:
-            self.num_cores_per_machine = int(kwargs.pop('num_cores_per_machine'))
-        except KeyError:
-            self.num_cores_per_machine = None
-        except ValueError:
-            raise ValueError('num_cores_per_machine must an integer')
+        if resources.tot_num_mpiprocs != resources.num_mpiprocs_per_machine * resources.num_machines:
+            raise ValueError('`tot_num_mpiprocs` is not equal to `num_mpiprocs_per_machine * num_machines`.')

-        try:
-            self.num_cores_per_mpiproc = int(kwargs.pop('num_cores_per_mpiproc'))
-        except KeyError:
-            self.num_cores_per_mpiproc = None
-        except ValueError:
-            raise ValueError('num_cores_per_mpiproc must an integer')
+        is_greater_equal_one('num_mpiprocs_per_machine')
+        is_greater_equal_one('num_machines')

-        if kwargs:
-            raise TypeError(
-                'The following parameters were not recognized for '
-                'the JobResource: {}'.format(kwargs.keys())
-            )
+        return resources

-        if num_machines is None:
-            # Use default value, if not provided
-            if num_mpiprocs_per_machine is None:
-                num_mpiprocs_per_machine = default_mpiprocs_per_machine
-
-            if num_mpiprocs_per_machine is None or tot_num_mpiprocs is None:
-                raise TypeError(
-                    'At least two among num_machines, '
-                    'num_mpiprocs_per_machine or tot_num_mpiprocs must be specified'
-                )
-            else:
-                # To avoid divisions by zero
-                if num_mpiprocs_per_machine <= 0:
-                    raise ValueError('num_mpiprocs_per_machine must be >= 1')
-                num_machines = tot_num_mpiprocs // num_mpiprocs_per_machine
-        else:
-            if tot_num_mpiprocs is None:
-                # Only set the default value if tot_num_mpiprocs is not provided.
-                # Otherwise, it means that the user provided both
-                # num_machines and tot_num_mpiprocs, and we have to ignore
-                # the default value of tot_num_mpiprocs
-                if num_mpiprocs_per_machine is None:
-                    num_mpiprocs_per_machine = default_mpiprocs_per_machine
-
-            if num_mpiprocs_per_machine is None:
-                if tot_num_mpiprocs is None:
-                    raise TypeError(
-                        'At least two among num_machines, '
-                        'num_mpiprocs_per_machine or tot_num_mpiprocs must be specified'
-                    )
-                else:
-                    # To avoid divisions by zero
-                    if num_machines <= 0:
-                        raise ValueError('num_machines must be >= 1')
-                    num_mpiprocs_per_machine = tot_num_mpiprocs // num_machines
-
-        self.num_machines = num_machines
-        self.num_mpiprocs_per_machine = num_mpiprocs_per_machine
-
-        if tot_num_mpiprocs is not None:
-            if tot_num_mpiprocs != self.num_mpiprocs_per_machine * self.num_machines:
-                raise ValueError(
-                    'tot_num_mpiprocs must be equal to '
-                    'num_mpiprocs_per_machine * num_machines, and in particular it '
-                    'should be a multiple of num_mpiprocs_per_machine and/or '
-                    'num_machines'
-                )
-
-        if self.num_mpiprocs_per_machine <= 0:
-            raise ValueError('num_mpiprocs_per_machine must be >= 1')
-        if self.num_machines <= 0:
-            raise ValueError('num_machine must be >= 1')
+    def __init__(self, **kwargs):
+        """Initialize the job resources from the passed arguments.

-    def get_tot_num_mpiprocs(self):
-        """
-        Return the total number of cpus of this job resource.
+        :raises ValueError: if the resources are invalid or incomplete
         """
+        resources = self.validate_resources(**kwargs)
+        super().__init__(resources)
+
+    @classmethod
+    def get_valid_keys(cls):
+        """Return a list of valid keys to be passed to the constructor."""
+        return super().get_valid_keys() + ['tot_num_mpiprocs']
+
+    @classmethod
+    def accepts_default_mpiprocs_per_machine(cls):
+        """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise."""
+        return True
+
+    def get_tot_num_mpiprocs(self):
+        """Return the total number of cpus of this job resource."""
         return self.num_machines * self.num_mpiprocs_per_machine


 class ParEnvJobResource(JobResource):
-    """
-    An implementation of JobResource for schedulers that support
-    the specification of a parallel environment (a string) + the total number of nodes
-    """
+    """`JobResource` for schedulers that support the specification of a parallel environment and number of MPI procs."""
+
     _default_fields = (
         'parallel_env',
         'tot_num_mpiprocs',
-        'default_mpiprocs_per_machine',
     )

-    def __init__(self, **kwargs):
-        """
-        Initialize the job resources from the passed arguments (the valid keys can be
-        obtained with the function self.get_valid_keys()).
+    @classmethod
+    def validate_resources(cls, **kwargs):
+        """Validate the resources against the job resource class of this scheduler.

-        :raise ValueError: on invalid parameters.
-        :raise TypeError: on invalid parameters.
-        :raise aiida.common.ConfigurationError: if default_mpiprocs_per_machine was set for this
-            computer, since ParEnvJobResource cannot accept this parameter.
+        :param kwargs: dictionary of values to define the job resources
+        :return: attribute dictionary with the parsed parameters populated
+        :raises ValueError: if the resources are invalid or incomplete
         """
-        from aiida.common.exceptions import ConfigurationError
-        super().__init__()
+        resources = AttributeDict()

         try:
-            self.parallel_env = str(kwargs.pop('parallel_env'))
-        except (KeyError, TypeError, ValueError):
-            raise TypeError("'parallel_env' must be specified and must be a string")
+            resources.parallel_env = kwargs.pop('parallel_env')
+        except KeyError:
+            raise ValueError('`parallel_env` must be specified and must be a string')
+        else:
+            if not isinstance(resources.parallel_env, str):
+                raise ValueError('`parallel_env` must be specified and must be a string')

         try:
-            self.tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs'))
+            resources.tot_num_mpiprocs = int(kwargs.pop('tot_num_mpiprocs'))
         except (KeyError, ValueError):
-            raise TypeError('tot_num_mpiprocs must be specified and must be an integer')
+            raise ValueError('`tot_num_mpiprocs` must be specified and must be an integer')

-        default_mpiprocs_per_machine = kwargs.pop('default_mpiprocs_per_machine', None)
-        if default_mpiprocs_per_machine is not None:
-            raise ConfigurationError(
-                'default_mpiprocs_per_machine cannot be set '
-                'for schedulers that use ParEnvJobResource'
-            )
+        if resources.tot_num_mpiprocs < 1:
+            raise ValueError('`tot_num_mpiprocs` must be greater than or equal to one.')

-        if self.tot_num_mpiprocs <= 0:
-            raise ValueError('tot_num_mpiprocs must be >= 1')
+        if kwargs:
+            raise ValueError('these parameters were not recognized: {}'.format(', '.join(list(kwargs.keys()))))

-    def get_tot_num_mpiprocs(self):
+        return resources
+
+    def __init__(self, **kwargs):
         """
-        Return the total number of cpus of this job resource.
+        Initialize the job resources from the passed arguments (the valid keys can be
+        obtained with the function self.get_valid_keys()).
+
+        :raises ValueError: if the resources are invalid or incomplete
         """
-        return self.tot_num_mpiprocs
+        resources = self.validate_resources(**kwargs)
+        super().__init__(resources)

     @classmethod
     def accepts_default_mpiprocs_per_machine(cls):
-        """
-        Return True if this JobResource accepts a 'default_mpiprocs_per_machine'
-        key, False otherwise.
-        """
+        """Return True if this subclass accepts a `default_mpiprocs_per_machine` key, False otherwise."""
         return False

+    def get_tot_num_mpiprocs(self):
+        """Return the total number of cpus of this job resource."""
+        return self.tot_num_mpiprocs
+

 class JobTemplate(DefaultFieldsAttributeDict):  # pylint: disable=too-many-instance-attributes
-    """
-    A template for submitting jobs. This contains all required information
-    to create the job header.
+    """A template for submitting jobs to a scheduler.
+
+    This contains all required information to create the job header.

-    The required fields are: working_directory, job_name, num_machines,
-    num_mpiprocs_per_machine, argv.
+    The required fields are: working_directory, job_name, num_machines, num_mpiprocs_per_machine, argv.

     Fields:
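The `validate_resources` classmethods above make the resource arithmetic explicit: any one of `num_machines`, `num_mpiprocs_per_machine` and `tot_num_mpiprocs` can be deduced from the other two, and inconsistent triples are rejected. A short sketch of the expected behaviour, using only calls introduced in this diff:

```python
from aiida.schedulers.datastructures import NodeNumberJobResource, ParEnvJobResource

# Two of the three node-count parameters suffice; the third is deduced.
resources = NodeNumberJobResource.validate_resources(num_machines=2, tot_num_mpiprocs=16)
assert resources.num_mpiprocs_per_machine == 8  # deduced as 16 // 2

# An inconsistent triple is rejected, since 2 * 8 != 32.
try:
    NodeNumberJobResource.validate_resources(num_machines=2, num_mpiprocs_per_machine=8, tot_num_mpiprocs=32)
except ValueError as exception:
    print(exception)  # `tot_num_mpiprocs` is not equal to `num_mpiprocs_per_machine * num_machines`.

# `ParEnvJobResource` only knows a parallel environment and a total number of MPI processes.
resources = ParEnvJobResource.validate_resources(parallel_env='mpi', tot_num_mpiprocs=4)
assert resources.tot_num_mpiprocs == 4
```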
diff --git a/aiida/schedulers/plugins/pbsbaseclasses.py b/aiida/schedulers/plugins/pbsbaseclasses.py
index 38c135afc2..9fe8f10926 100644
--- a/aiida/schedulers/plugins/pbsbaseclasses.py
+++ b/aiida/schedulers/plugins/pbsbaseclasses.py
@@ -64,47 +64,37 @@
 class PbsJobResource(NodeNumberJobResource):
-    """
-    Base class for PBS job resources
-    """
-
-    def __init__(self, **kwargs):
-        """
-        It extends the base class init method and calculates the
-        num_cores_per_machine fields to pass to PBSlike schedulers.
+    """Class for PBS job resources."""

-        Checks that num_cores_per_machine is a multiple of
-        num_cores_per_mpiproc and/or num_mpiprocs_per_machine
+    @classmethod
+    def validate_resources(cls, **kwargs):
+        """Validate the resources against the job resource class of this scheduler.

-        Check sequence
+        This extends the base class validator and calculates the `num_cores_per_machine` fields to pass to PBSlike
+        schedulers. Checks that `num_cores_per_machine` is a multiple of `num_cores_per_mpiproc` and/or
+        `num_mpiprocs_per_machine`.

-        1. If num_cores_per_mpiproc and num_cores_per_machine both are
-           specified check whether it satisfies the check
-        2. If only num_cores_per_mpiproc is passed, calculate
-           num_cores_per_machine
-        3. If only num_cores_per_machine is passed, use it
+        :param kwargs: dictionary of values to define the job resources
+        :return: attribute dictionary with the parsed parameters populated
+        :raises ValueError: if the resources are invalid or incomplete
         """
-        super().__init__(**kwargs)
+        resources = super().validate_resources(**kwargs)

-        value_error = (
-            'num_cores_per_machine must be equal to '
-            'num_cores_per_mpiproc * num_mpiprocs_per_machine, '
-            'and in perticular it should be a multiple of '
-            'num_cores_per_mpiproc and/or num_mpiprocs_per_machine'
-        )
+        if resources.num_cores_per_machine is not None and resources.num_cores_per_mpiproc is not None:
+            if resources.num_cores_per_machine != resources.num_cores_per_mpiproc * resources.num_mpiprocs_per_machine:
+                raise ValueError(
+                    '`num_cores_per_machine` must be equal to `num_cores_per_mpiproc * num_mpiprocs_per_machine` and in'
+                    ' particular it should be a multiple of `num_cores_per_mpiproc` and/or `num_mpiprocs_per_machine`'
+                )
+
+        elif resources.num_cores_per_mpiproc is not None:
+            if resources.num_cores_per_mpiproc < 1:
+                raise ValueError('num_cores_per_mpiproc must be greater than or equal to one.')
+
+            # In this plugin we never used num_cores_per_mpiproc so if it is not defined it is OK.
+            resources.num_cores_per_machine = (resources.num_cores_per_mpiproc * resources.num_mpiprocs_per_machine)

-        if self.num_cores_per_machine is not None and self.num_cores_per_mpiproc is not None:
-            if self.num_cores_per_machine != (self.num_cores_per_mpiproc * self.num_mpiprocs_per_machine):
-                # If user specify both values, check if specified
-                # values are correct
-                raise ValueError(value_error)
-        elif self.num_cores_per_mpiproc is not None:
-            if self.num_cores_per_mpiproc <= 0:
-                raise ValueError('num_cores_per_mpiproc must be >=1')
-            # calculate num_cores_per_machine
-            # In this plugin we never used num_cores_per_mpiproc so if it
-            # is not defined it is OK.
-            self.num_cores_per_machine = (self.num_cores_per_mpiproc * self.num_mpiprocs_per_machine)
+        return resources


 class PbsBaseClass(Scheduler):
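The PBS validator treats `num_cores_per_machine` as the product `num_cores_per_mpiproc * num_mpiprocs_per_machine`: if only `num_cores_per_mpiproc` is given, the per-machine count is computed; if both are given they must be consistent. A worked sketch of the rule, with illustrative numbers:

```python
from aiida.schedulers.plugins.pbsbaseclasses import PbsJobResource

# Only `num_cores_per_mpiproc` given: `num_cores_per_machine` is computed as 2 * 8 = 16.
resources = PbsJobResource.validate_resources(num_machines=1, num_mpiprocs_per_machine=8, num_cores_per_mpiproc=2)
assert resources.num_cores_per_machine == 16

# Both given but inconsistent (24 != 2 * 8): rejected with a ValueError.
try:
    PbsJobResource.validate_resources(
        num_machines=1, num_mpiprocs_per_machine=8, num_cores_per_mpiproc=2, num_cores_per_machine=24
    )
except ValueError as exception:
    print(exception)
```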
diff --git a/aiida/schedulers/plugins/slurm.py b/aiida/schedulers/plugins/slurm.py
index 888e2d24bb..70645b0bbc 100644
--- a/aiida/schedulers/plugins/slurm.py
+++ b/aiida/schedulers/plugins/slurm.py
@@ -103,49 +103,41 @@
 class SlurmJobResource(NodeNumberJobResource):
-    """
-    Slurm job resources object
-    """
-
-    def __init__(self, *args, **kwargs):
-        """
-        It extends the base class init method and calculates the
-        num_cores_per_mpiproc fields to pass to Slurm schedulers.
+    """Class for SLURM job resources."""

-        Checks that num_cores_per_machine should be a multiple of
-        num_cores_per_mpiproc and/or num_mpiprocs_per_machine
+    @classmethod
+    def validate_resources(cls, **kwargs):
+        """Validate the resources against the job resource class of this scheduler.

-        Check sequence
+        This extends the base class validator to check that the `num_cores_per_machine` are a multiple of
+        `num_cores_per_mpiproc` and/or `num_mpiprocs_per_machine`.

-        1. If num_cores_per_mpiproc and num_cores_per_machine both are
-           specified check whether it satisfies the check
-        2. If only num_cores_per_machine is passed, calculate
-           num_cores_per_mpiproc which should always be an integer value
-        3. If only num_cores_per_mpiproc is passed, use it
+        :param kwargs: dictionary of values to define the job resources
+        :return: attribute dictionary with the parsed parameters populated
+        :raises ValueError: if the resources are invalid or incomplete
         """
-        super().__init__(*args, **kwargs)
+        resources = super().validate_resources(**kwargs)

-        value_error = (
-            'num_cores_per_machine must be equal to '
-            'num_cores_per_mpiproc * num_mpiprocs_per_machine, '
-            'and in perticular it should be a multiple of '
-            'num_cores_per_mpiproc and/or num_mpiprocs_per_machine'
-        )
+        if resources.num_cores_per_machine is not None and resources.num_cores_per_mpiproc is not None:
+            if resources.num_cores_per_machine != resources.num_cores_per_mpiproc * resources.num_mpiprocs_per_machine:
+                raise ValueError(
+                    '`num_cores_per_machine` must be equal to `num_cores_per_mpiproc * num_mpiprocs_per_machine` and in'
+                    ' particular it should be a multiple of `num_cores_per_mpiproc` and/or `num_mpiprocs_per_machine`'
+                )
+
+        elif resources.num_cores_per_machine is not None:
+            if resources.num_cores_per_machine < 1:
+                raise ValueError('num_cores_per_machine must be greater than or equal to one.')
+
+            # In this plugin we never used num_cores_per_machine so if it is not defined it is OK.
+            resources.num_cores_per_mpiproc = (resources.num_cores_per_machine / resources.num_mpiprocs_per_machine)
+            if not resources.num_cores_per_mpiproc.is_integer():
+                raise ValueError(
+                    '`num_cores_per_machine` must be equal to `num_cores_per_mpiproc * num_mpiprocs_per_machine` and in'
+                    ' particular it should be a multiple of `num_cores_per_mpiproc` and/or `num_mpiprocs_per_machine`'
+                )
+            resources.num_cores_per_mpiproc = int(resources.num_cores_per_mpiproc)

-        if self.num_cores_per_machine is not None and self.num_cores_per_mpiproc is not None:
-            if self.num_cores_per_machine != (self.num_cores_per_mpiproc * self.num_mpiprocs_per_machine):
-                # If user specify both values, check if specified
-                # values are correct
-                raise ValueError(value_error)
-        elif self.num_cores_per_machine is not None:
-            if self.num_cores_per_machine <= 0:
-                raise ValueError('num_cores_per_machine must be >=1')
-            # calculate num_cores_per_mpiproc
-            # In this plugin we never used num_cores_per_machine so if it
-            # is not defined it is OK.
-            self.num_cores_per_mpiproc = (self.num_cores_per_machine / self.num_mpiprocs_per_machine)
-            if isinstance(self.num_cores_per_mpiproc, int):
-                raise ValueError(value_error)
+        return resources


 class SlurmScheduler(Scheduler):
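The `scheduler.py` changes below move the resource pre-processing onto the `Scheduler` class itself: `preprocess_resources` injects the computer's default MPI proc count only when the job resource class accepts it and the value cannot be deduced, which is what makes the SGE/LSF case in the new test pass. A sketch of the intended behaviour, assuming the standard `slurm` and `sge` entry points:

```python
from aiida.plugins import SchedulerFactory

slurm = SchedulerFactory('slurm')  # uses a NodeNumberJobResource subclass: accepts the default
sge = SchedulerFactory('sge')      # uses ParEnvJobResource: does not accept the default

resources = {'num_machines': 2}
slurm.preprocess_resources(resources, default_mpiprocs_per_machine=8)
assert resources['num_mpiprocs_per_machine'] == 8  # injected in place

resources = {'parallel_env': 'mpi', 'tot_num_mpiprocs': 16}
sge.preprocess_resources(resources, default_mpiprocs_per_machine=8)
assert 'num_mpiprocs_per_machine' not in resources  # the default is silently ignored
```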
diff --git a/aiida/schedulers/scheduler.py b/aiida/schedulers/scheduler.py
index 14ddfe418e..a22ffeea04 100644
--- a/aiida/schedulers/scheduler.py
+++ b/aiida/schedulers/scheduler.py
@@ -8,18 +8,17 @@
 # For further information please visit http://www.aiida.net               #
 ###########################################################################
 """Implementation of `Scheduler` base class."""
-from abc import abstractmethod
+import abc

-import aiida.common
-from aiida.common.lang import classproperty
+from aiida.common import exceptions, log
 from aiida.common.escaping import escape_for_bash
-from aiida.common.exceptions import AiidaException, FeatureNotAvailable
-from aiida.schedulers.datastructures import JobTemplate
+from aiida.common.lang import classproperty
+from aiida.schedulers.datastructures import JobResource, JobTemplate

 __all__ = ('Scheduler', 'SchedulerError', 'SchedulerParsingError')


-class SchedulerError(AiidaException):
+class SchedulerError(exceptions.AiidaException):
     pass


@@ -27,12 +26,10 @@ class SchedulerParsingError(SchedulerError):
     pass


-class Scheduler:
-    """
-    Base class for all schedulers.
-    """
+class Scheduler(metaclass=abc.ABCMeta):
+    """Base class for a job scheduler."""

-    _logger = aiida.common.AIIDA_LOGGER.getChild('scheduler')
+    _logger = log.AIIDA_LOGGER.getChild('scheduler')

     # A list of features
     # Features that should be defined in the plugins:
@@ -44,27 +41,58 @@ class Scheduler:
     # The class to be used for the job resource.
     _job_resource_class = None

-    def __init__(self):
-        self._transport = None
+    @classmethod
+    def preprocess_resources(cls, resources, default_mpiprocs_per_machine=None):
+        """Preprocess the resources.

-    def set_transport(self, transport):
+        Add the `num_mpiprocs_per_machine` key to the `resources` if it is not already defined and it cannot be deduced
+        from the `num_machines` and `tot_num_mpiprocs` being defined. The value is also not added if the job resource
+        class of this scheduler does not accept the `num_mpiprocs_per_machine` keyword. Note that the changes are made
+        in place to the `resources` argument passed.
         """
-        Set the transport to be used to query the machine or to submit scripts.
-        This class assumes that the transport is open and active.
+        num_machines = resources.get('num_machines', None)
+        tot_num_mpiprocs = resources.get('tot_num_mpiprocs', None)
+        num_mpiprocs_per_machine = resources.get('num_mpiprocs_per_machine', None)
+
+        if (
+            num_mpiprocs_per_machine is None and cls.job_resource_class.accepts_default_mpiprocs_per_machine()  # pylint: disable=no-member
+            and (num_machines is None or tot_num_mpiprocs is None)
+        ):
+            resources['num_mpiprocs_per_machine'] = default_mpiprocs_per_machine
+
+    @classmethod
+    def validate_resources(cls, **resources):
+        """Validate the resources against the job resource class of this scheduler.
+
+        :param resources: keyword arguments to define the job resources
+        :raises ValueError: if the resources are invalid or incomplete
         """
-        self._transport = transport
+        cls._job_resource_class.validate_resources(**resources)
+
+    def __init__(self):
+        self._transport = None
+
+        if self._job_resource_class is None or not issubclass(self._job_resource_class, JobResource):
+            raise RuntimeError('the class attribute `_job_resource_class` is not a subclass of `JobResource`.')

     @classmethod
     def get_valid_schedulers(cls):
-        from aiida.plugins.entry_point import get_entry_point_names
+        """Return all available scheduler plugins.

+        .. deprecated:: 1.3.0
+
+            Will be removed in `2.0.0`, use `aiida.plugins.entry_point.get_entry_point_names` instead
+        """
+        import warnings
+        from aiida.common.warnings import AiidaDeprecationWarning
+        from aiida.plugins.entry_point import get_entry_point_names
+        message = 'method is deprecated, use `aiida.plugins.entry_point.get_entry_point_names` instead'
+        warnings.warn(message, AiidaDeprecationWarning)  # pylint: disable=no-member
         return get_entry_point_names('aiida.schedulers')

     @classmethod
     def get_short_doc(cls):
-        """
-        Return the first non-empty line of the class docstring, if available
-        """
+        """Return the first non-empty line of the class docstring, if available."""
         # Remove empty lines
         docstring = cls.__doc__
         if not docstring:
@@ -84,36 +112,26 @@ def get_feature(self, feature_name):

     @property
     def logger(self):
-        """
-        Return the internal logger.
-        """
+        """Return the internal logger."""
         try:
             return self._logger
         except AttributeError:
-            from aiida.common.exceptions import InternalError
-
-            raise InternalError('No self._logger configured for {}!')
+            raise exceptions.InternalError('No self._logger configured for {}!')

     @classproperty
-    def job_resource_class(self):
-        return self._job_resource_class
+    def job_resource_class(cls):  # pylint: disable=no-self-argument
+        return cls._job_resource_class

     @classmethod
     def create_job_resource(cls, **kwargs):
-        """
-        Create a suitable job resource from the kwargs specified
-        """
+        """Create a suitable job resource from the kwargs specified."""
         # pylint: disable=not-callable
-
-        if cls._job_resource_class is None:
-            raise NotImplementedError
-
         return cls._job_resource_class(**kwargs)

     def get_submit_script(self, job_tmpl):
-        """
-        Return the submit script as a string.
-        :parameter job_tmpl: a aiida.schedulers.datastrutures.JobTemplate object.
+        """Return the submit script as a string.
+
+        :parameter job_tmpl: an `aiida.schedulers.datastructures.JobTemplate` instance.

         The plugin returns something like

@@ -125,11 +143,8 @@ def get_submit_script(self, job_tmpl):
             postpend_code
             postpend_computer
         """
-
-        from aiida.common.exceptions import InternalError
-
         if not isinstance(job_tmpl, JobTemplate):
-            raise InternalError('job_tmpl should be of type JobTemplate')
+            raise exceptions.InternalError('job_tmpl should be of type JobTemplate')

         empty_line = ''

@@ -167,55 +182,30 @@ def get_submit_script(self, job_tmpl):

         return '\n'.join(script_lines)

-    @abstractmethod
+    @abc.abstractmethod
     def _get_submit_script_header(self, job_tmpl):
-        """
-        Return the submit script header, using the parameters from the
-        job_tmpl.
+        """Return the submit script header, using the parameters from the job template.

-        :param job_tmpl: a JobTemplate instance with relevant parameters set.
+        :param job_tmpl: a `JobTemplate` instance with relevant parameters set.
""" - raise NotImplementedError def _get_submit_script_footer(self, job_tmpl): - """ - Return the submit script final part, using the parameters from the - job_tmpl. + """Return the submit script final part, using the parameters from the job template. - :param job_tmpl: a JobTemplate instance with relevant parameters set. + :param job_tmpl: a `JobTemplate` instance with relevant parameters set. """ - # pylint: disable=no-self-use, unused-argument + # pylint: disable=no-self-use,unused-argument return None def _get_run_line(self, codes_info, codes_run_mode): - """ - Return a string with the line to execute a specific code with - specific arguments. - - :parameter codes_info: a list of aiida.common.datastructures.CodeInfo - objects. Each contains the information needed to run the code. I.e. - cmdline_params, stdin_name, stdout_name, stderr_name, join_files. - See the documentation of JobTemplate and CodeInfo - :parameter codes_run_mode: contains the information on how to launch the - multiple codes. As described in aiida.common.datastructures.CodeRunMode - - - argv: an array with the executable and the command line arguments. - The first argument is the executable. This should contain - everything, including the mpirun command etc. - stdin_name: the filename to be used as stdin, relative to the - working dir, or None if no stdin redirection is required. - stdout_name: the filename to be used to store the standard output, - relative to the working dir, - or None if no stdout redirection is required. - stderr_name: the filename to be used to store the standard error, - relative to the working dir, - or None if no stderr redirection is required. - join_files: if True, stderr is redirected to stdout; the value of - stderr_name is ignored. - - Return a string with the following format: - [executable] [args] {[ < stdin ]} {[ < stdout ]} {[2>&1 | 2> stderr]} + """Return a string with the line to execute a specific code with specific arguments. + + :parameter codes_info: a list of `aiida.common.datastructures.CodeInfo` objects. Each contains the information + needed to run the code. I.e. `cmdline_params`, `stdin_name`, `stdout_name`, `stderr_name`, `join_files`. See + the documentation of `JobTemplate` and `CodeInfo`. + :parameter codes_run_mode: instance of `aiida.common.datastructures.CodeRunMode` contains the information on how + to launch the multiple codes. + :return: string with format: [executable] [args] {[ < stdin ]} {[ < stdout ]} {[2>&1 | 2> stderr]} """ from aiida.common.datastructures import CodeRunMode @@ -251,40 +241,30 @@ def _get_run_line(self, codes_info, codes_run_mode): raise NotImplementedError('Unrecognized code run mode') - @abstractmethod + @abc.abstractmethod def _get_joblist_command(self, jobs=None, user=None): - """ - Return the qstat (or equivalent) command to run with the required - command-line parameters to get the most complete description possible; - also specifies the output format of qsub to be the one to be used - by the parse_queue_output method. + """Return the command to get the most complete description possible of currently active jobs. - Must be implemented in the plugin. + .. note:: - :param jobs: either None to get a list of all jobs in the machine, - or a list of jobs. - :param user: either None, or a string with the username (to show only - jobs of the specific user). + Typically one can pass only either jobs or user, depending on the specific plugin. 
+            according to the value returned by `self.get_feature('can_query_by_user')`

-        Note: typically one can pass only either jobs or user, depending on the
-        specific plugin. The choice can be done according to the value
-        returned by self.get_feature('can_query_by_user')
+        :param jobs: either None to get a list of all jobs in the machine, or a list of jobs.
+        :param user: either None, or a string with the username (to show only jobs of the specific user).
         """
-        raise NotImplementedError

     def _get_detailed_job_info_command(self, job_id):
-        """
-        Return the command to run to get the detailed information on a job.
-        This is typically called after the job has finished, to retrieve
-        the most detailed information possible about the job. This is done
-        because most schedulers just make finished jobs disappear from the
-        'qstat' command, and instead sometimes it is useful to know some
-        more detailed information about the job exit status, etc.
+        """Return the command to run to get detailed information for a given job.
+
+        This is typically called after the job has finished, to retrieve the most detailed information possible about
+        the job. This is done because most schedulers just make finished jobs disappear from the `qstat` command, and
+        instead sometimes it is useful to know some more detailed information about the job exit status, etc.

         :raises: :class:`aiida.common.exceptions.FeatureNotAvailable`
         """
         # pylint: disable=no-self-use,not-callable,unused-argument
-        raise FeatureNotAvailable('Cannot get detailed job info')
+        raise exceptions.FeatureNotAvailable('Cannot get detailed job info')

     def get_detailed_job_info(self, job_id):
         """Return the detailed job info.
@@ -327,7 +307,7 @@ def get_detailed_jobinfo(self, jobid):
         with self.transport:
             retval, stdout, stderr = self.transport.exec_command_wait(command)

-        return u"""Detailed jobinfo obtained with command '{}'
+        return """Detailed jobinfo obtained with command '{}'
 Return Code: {}
 -------------------------------------------------------------
 stdout:
@@ -336,33 +316,23 @@ def get_detailed_jobinfo(self, jobid):
 {}
 """.format(command, retval, stdout, stderr)

-    @abstractmethod
+    @abc.abstractmethod
     def _parse_joblist_output(self, retval, stdout, stderr):
-        """
-        Parse the joblist output ('qstat'), as returned by executing the
-        command returned by _get_joblist_command method.
-
-        To be implemented by the plugin.
+        """Parse the joblist output as returned by executing the command returned by `_get_joblist_command` method.

-        Return a list of JobInfo objects, one of each job,
-        each with at least its default params implemented.
+        :return: list of `JobInfo` objects, one of each job each with at least its default params implemented.
         """
-        raise NotImplementedError

     def get_jobs(self, jobs=None, user=None, as_dict=False):
-        """
-        Get the list of jobs and return it.
+        """Return the list of currently active jobs.

-        Typically, this function does not need to be modified by the plugins.
+        .. note:: typically, only either jobs or user can be specified. See also comments in `_get_joblist_command`.

         :param list jobs: a list of jobs to check; only these are checked
         :param str user: a string with a user: only jobs of this user are checked
-        :param list as_dict: if False (default), a list of JobInfo objects is
-            returned. If True, a dictionary is returned, having as key the
-            job_id and as value the JobInfo object.
-
-        Note: typically, only either jobs or user can be specified. See also
-        comments in _get_joblist_command.
+        :param bool as_dict: if False (default), a list of JobInfo objects is returned. If True, a dictionary is
+            returned, having as key the job_id and as value the JobInfo object.
+        :return: list of active jobs
         """
         with self.transport:
             retval, stdout, stderr = self.transport.exec_command_wait(self._get_joblist_command(jobs=jobs, user=user))
@@ -378,85 +348,66 @@ def get_jobs(self, jobs=None, user=None, as_dict=False):

     @property
     def transport(self):
-        """
-        Return the transport set for this scheduler.
-        """
+        """Return the transport set for this scheduler."""
         if self._transport is None:
             raise SchedulerError('Use the set_transport function to set the transport for the scheduler first.')

         return self._transport

-    @abstractmethod
-    def _get_submit_command(self, submit_script):
+    def set_transport(self, transport):
+        """Set the transport to be used to query the machine or to submit scripts.
+
+        This class assumes that the transport is open and active.
         """
-        Return the string to execute to submit a given script.
+        self._transport = transport
+
+    @abc.abstractmethod
+    def _get_submit_command(self, submit_script):
+        """Return the string to execute to submit a given script.

-        To be implemented by the plugin.
+        .. warning:: the `submit_script` should already have been bash-escaped

-        :param str submit_script: the path of the submit script relative to the
-            working directory.
-            IMPORTANT: submit_script should be already escaped.
+        :param submit_script: the path of the submit script relative to the working directory.
         :return: the string to execute to submit a given script.
         """
-        raise NotImplementedError

-    @abstractmethod
+    @abc.abstractmethod
     def _parse_submit_output(self, retval, stdout, stderr):
-        """
-        Parse the output of the submit command, as returned by executing the
-        command returned by _get_submit_command command.
+        """Parse the output of the submit command returned by calling the `_get_submit_command` command.

-        To be implemented by the plugin.
-
-        :return: a string with the JobID.
+        :return: a string with the job ID.
         """
-        raise NotImplementedError

     def submit_from_script(self, working_directory, submit_script):
-        """
-        Goes in the working directory and submits the submit_script.
+        """Submit the submission script to the scheduler.

-        Return a string with the JobID in a valid format to be used for
-        querying.
-
-        Typically, this function does not need to be modified by the plugins.
+        :return: return a string with the job ID in a valid format to be used for querying.
         """
         self.transport.chdir(working_directory)
-        retval, stdout, stderr = self.transport.exec_command_wait(
-            self._get_submit_command(escape_for_bash(submit_script))
-        )
-        return self._parse_submit_output(retval, stdout, stderr)
+        result = self.transport.exec_command_wait(self._get_submit_command(escape_for_bash(submit_script)))
+        return self._parse_submit_output(*result)

     def kill(self, jobid):
-        """
-        Kill a remote job, and try to parse the output message of the scheduler
-        to check if the scheduler accepted the command.
+        """Kill a remote job and parse the return value of the scheduler to check if the command succeeded.

-        ..note:: On some schedulers, even if the command is accepted, it may
-            take some seconds for the job to actually disappear from the queue.
+        .. note::

-        :param str jobid: the job id to be killed
+            On some schedulers, even if the command is accepted, it may take some seconds for the job to actually
+            disappear from the queue.
+
+        :param jobid: the job ID to be killed
         :return: True if everything seems ok, False otherwise.
         """
         retval, stdout, stderr = self.transport.exec_command_wait(self._get_kill_command(jobid))
         return self._parse_kill_output(retval, stdout, stderr)

+    @abc.abstractmethod
     def _get_kill_command(self, jobid):
-        """
-        Return the command to kill the job with specified jobid.
-
-        To be implemented by the plugin.
-        """
-        raise NotImplementedError
+        """Return the command to kill the job with specified jobid."""

+    @abc.abstractmethod
     def _parse_kill_output(self, retval, stdout, stderr):
-        """
-        Parse the output of the kill command.
-
-        To be implemented by the plugin.
+        """Parse the output of the kill command.

         :return: True if everything seems ok, False otherwise.
         """
-        raise NotImplementedError
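With the abstract methods now declared through `abc`, a scheduler plugin only needs to set `_job_resource_class` and implement the command and parsing hooks; an incomplete subclass now fails at instantiation instead of at first call. A minimal, hypothetical plugin sketch (all names and commands are illustrative, not part of this diff):

```python
from aiida.schedulers import Scheduler
from aiida.schedulers.datastructures import NodeNumberJobResource


class MyScheduler(Scheduler):
    """Hypothetical plugin illustrating the hooks required by the `Scheduler` base class."""

    _job_resource_class = NodeNumberJobResource

    def _get_submit_script_header(self, job_tmpl):
        return '#MYSCHED nodes={}'.format(job_tmpl.job_resource.num_machines)

    def _get_submit_command(self, submit_script):
        return 'mysubmit {}'.format(submit_script)

    def _parse_submit_output(self, retval, stdout, stderr):
        return stdout.strip()  # the job ID

    def _get_joblist_command(self, jobs=None, user=None):
        return 'myqueue --list'

    def _parse_joblist_output(self, retval, stdout, stderr):
        return []  # would return a list of `JobInfo` objects

    def _get_kill_command(self, jobid):
        return 'mykill {}'.format(jobid)

    def _parse_kill_output(self, retval, stdout, stderr):
        return retval == 0
```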
diff --git a/tests/engine/test_calc_job.py b/tests/engine/test_calc_job.py
index a2874294fc..38ef789484 100644
--- a/tests/engine/test_calc_job.py
+++ b/tests/engine/test_calc_job.py
@@ -163,7 +163,7 @@ def test_remote_code_unstored_computer(self):
         inputs['code'] = self.remote_code
         inputs['metadata']['computer'] = orm.Computer('different', 'localhost', 'desc', 'local', 'direct')

-        with self.assertRaises(exceptions.InputValidationError):
+        with self.assertRaises(ValueError):
             ArithmeticAddCalculation(inputs=inputs)

     def test_remote_code_set_computer_explicit(self):
@@ -176,7 +176,7 @@ def test_remote_code_set_computer_explicit(self):
         inputs['code'] = self.remote_code

         # Setting explicitly a computer that is not the same as that of the `code` should raise
-        with self.assertRaises(exceptions.InputValidationError):
+        with self.assertRaises(ValueError):
             inputs['metadata']['computer'] = orm.Computer('different', 'localhost', 'desc', 'local', 'direct').store()
             process = ArithmeticAddCalculation(inputs=inputs)

@@ -201,7 +201,7 @@ def test_local_code_no_computer(self):
         inputs = deepcopy(self.inputs)
         inputs['code'] = self.local_code

-        with self.assertRaises(exceptions.InputValidationError):
+        with self.assertRaises(ValueError):
             ArithmeticAddCalculation(inputs=inputs)

     def test_invalid_parser_name(self):
@@ -210,7 +210,7 @@ def test_invalid_parser_name(self):
         inputs['code'] = self.remote_code
         inputs['metadata']['options']['parser_name'] = 'invalid_parser'

-        with self.assertRaises(exceptions.InputValidationError):
+        with self.assertRaises(ValueError):
             ArithmeticAddCalculation(inputs=inputs)

     def test_invalid_resources(self):
@@ -219,9 +219,25 @@ def test_invalid_resources(self):
         inputs['code'] = self.remote_code
         inputs['metadata']['options']['resources'] = {'num_machines': 'invalid_type'}

-        with self.assertRaises(exceptions.InputValidationError):
+        with self.assertRaises(ValueError):
             ArithmeticAddCalculation(inputs=inputs)

+    def test_par_env_resources_computer(self):
+        """Test launching a `CalcJob` on a computer with a scheduler using `ParEnvJobResource` as resources.
+
+        Even though the computer defines a default number of MPI procs per machine, it should not raise when the
+        scheduler that is defined does not actually support it, for example SGE or LSF.
+        """
+        inputs = deepcopy(self.inputs)
+        computer = orm.Computer('sge_computer', 'localhost', 'desc', 'local', 'sge').store()
+        computer.set_default_mpiprocs_per_machine(1)
+
+        inputs['code'] = orm.Code(remote_computer_exec=(computer, '/bin/bash')).store()
+        inputs['metadata']['options']['resources'] = {'parallel_env': 'environment', 'tot_num_mpiprocs': 10}
+
+        # Just checking that instantiating does not raise, meaning the inputs were valid
+        ArithmeticAddCalculation(inputs=inputs)
+
     @pytest.mark.timeout(5)
     @patch.object(CalcJob, 'presubmit', partial(raise_exception, exceptions.InputValidationError))
     def test_exception_presubmit(self):
+ """ + inputs = deepcopy(self.inputs) + computer = orm.Computer('sge_computer', 'localhost', 'desc', 'local', 'sge').store() + computer.set_default_mpiprocs_per_machine(1) + + inputs['code'] = orm.Code(remote_computer_exec=(computer, '/bin/bash')).store() + inputs['metadata']['options']['resources'] = {'parallel_env': 'environment', 'tot_num_mpiprocs': 10} + + # Just checking that instantiating does not raise, meaning the inputs were valid + ArithmeticAddCalculation(inputs=inputs) + @pytest.mark.timeout(5) @patch.object(CalcJob, 'presubmit', partial(raise_exception, exceptions.InputValidationError)) def test_exception_presubmit(self): diff --git a/tests/schedulers/test_datastructures.py b/tests/schedulers/test_datastructures.py index 5764ce1494..f6745ad172 100644 --- a/tests/schedulers/test_datastructures.py +++ b/tests/schedulers/test_datastructures.py @@ -7,105 +7,147 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" -Datastructures test -""" -import unittest +"""Tests for the :mod:`aiida.schedulers.test_datastructures` module.""" +import pytest +from aiida.schedulers.datastructures import NodeNumberJobResource, ParEnvJobResource -class TestNodeNumberJobResource(unittest.TestCase): - """Unit tests for the NodeNumberJobResource class.""" - def test_init(self): - """ - Test the __init__ of the NodeNumberJobResource class - """ - from aiida.schedulers.datastructures import NodeNumberJobResource +class TestNodeNumberJobResource: + """Tests for the :class:`~aiida.schedulers.datastructures.NodeNumberJobResource`.""" - # No empty initialization - with self.assertRaises(TypeError): - _ = NodeNumberJobResource() + @staticmethod + def test_validate_resources(): + """Test the `validate_resources` method.""" + cls = NodeNumberJobResource + + with pytest.raises(ValueError): + cls.validate_resources() # Missing required field - with self.assertRaises(TypeError): - _ = NodeNumberJobResource(num_machines=1) - with self.assertRaises(TypeError): - _ = NodeNumberJobResource(num_mpiprocs_per_machine=1) - with self.assertRaises(TypeError): - _ = NodeNumberJobResource(tot_num_mpiprocs=1) + with pytest.raises(ValueError): + cls.validate_resources(num_machines=1) + with pytest.raises(ValueError): + cls.validate_resources(num_mpiprocs_per_machine=1) + with pytest.raises(ValueError): + cls.validate_resources(tot_num_mpiprocs=1) + + # Wrong field name + with pytest.raises(ValueError): + cls.validate_resources(num_machines=2, num_mpiprocs_per_machine=8, wrong_name=16) + + # Examples of wrong information (e.g., number of machines or of nodes < 0 + with pytest.raises(ValueError): + cls.validate_resources(num_machines=0, num_mpiprocs_per_machine=8) + with pytest.raises(ValueError): + cls.validate_resources(num_machines=1, num_mpiprocs_per_machine=0) + with pytest.raises(ValueError): + cls.validate_resources(num_machines=1, tot_num_mpiprocs=0) + with pytest.raises(ValueError): + cls.validate_resources(num_mpiprocs_per_machine=1, tot_num_mpiprocs=0) + + # Examples of inconsistent information + with pytest.raises(ValueError): + cls.validate_resources(num_mpiprocs_per_machine=8, num_machines=2, tot_num_mpiprocs=32) + + with pytest.raises(ValueError): + cls.validate_resources(num_mpiprocs_per_machine=8, tot_num_mpiprocs=15) + @staticmethod + def test_constructor(): + """Test that constructor defines all valid keys even if not all defined explicitly.""" # 
         # Standard info
         job_resource = NodeNumberJobResource(num_machines=2, num_mpiprocs_per_machine=8)
-        self.assertEqual(job_resource.num_machines, 2)
-        self.assertEqual(job_resource.num_mpiprocs_per_machine, 8)
-        self.assertEqual(job_resource.get_tot_num_mpiprocs(), 16)
-        # redundant but consistent information
+        assert job_resource.num_machines == 2
+        assert job_resource.num_mpiprocs_per_machine == 8
+        assert job_resource.get_tot_num_mpiprocs() == 16
+
+        # Redundant but consistent information
         job_resource = NodeNumberJobResource(num_machines=2, num_mpiprocs_per_machine=8, tot_num_mpiprocs=16)
-        self.assertEqual(job_resource.num_machines, 2)
-        self.assertEqual(job_resource.num_mpiprocs_per_machine, 8)
-        self.assertEqual(job_resource.get_tot_num_mpiprocs(), 16)
-        # other equivalent ways of specifying the information
+        assert job_resource.num_machines == 2
+        assert job_resource.num_mpiprocs_per_machine == 8
+        assert job_resource.get_tot_num_mpiprocs() == 16
+
+        # Other equivalent ways of specifying the information
         job_resource = NodeNumberJobResource(num_mpiprocs_per_machine=8, tot_num_mpiprocs=16)
-        self.assertEqual(job_resource.num_machines, 2)
-        self.assertEqual(job_resource.num_mpiprocs_per_machine, 8)
-        self.assertEqual(job_resource.get_tot_num_mpiprocs(), 16)
-        # other equivalent ways of specifying the information
+        assert job_resource.num_machines == 2
+        assert job_resource.num_mpiprocs_per_machine == 8
+        assert job_resource.get_tot_num_mpiprocs() == 16
+
+        # Other equivalent ways of specifying the information
         job_resource = NodeNumberJobResource(num_machines=2, tot_num_mpiprocs=16)
-        self.assertEqual(job_resource.num_machines, 2)
-        self.assertEqual(job_resource.num_mpiprocs_per_machine, 8)
-        self.assertEqual(job_resource.get_tot_num_mpiprocs(), 16)
-
-        # wrong field name
-        with self.assertRaises(TypeError):
-            _ = NodeNumberJobResource(num_machines=2, num_mpiprocs_per_machine=8, wrong_name=16)
-
-        # Examples of wrong informaton (e.g., number of machines or of nodes < 0
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_machines=0, num_mpiprocs_per_machine=8)
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_machines=1, num_mpiprocs_per_machine=0)
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_machines=1, tot_num_mpiprocs=0)
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_mpiprocs_per_machine=1, tot_num_mpiprocs=0)
+        assert job_resource.num_machines == 2
+        assert job_resource.num_mpiprocs_per_machine == 8
+        assert job_resource.get_tot_num_mpiprocs() == 16

-        # Examples of inconsistent information
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_mpiprocs_per_machine=8, num_machines=2, tot_num_mpiprocs=32)
-
-        with self.assertRaises(ValueError):
-            _ = NodeNumberJobResource(num_mpiprocs_per_machine=8, tot_num_mpiprocs=15)
-
-    def test_serialization(self):
-        """Test the serialization/deserialization of JobInfo classes."""
-        from aiida.schedulers.datastructures import JobInfo, JobState
-        from datetime import datetime
-
-        dict_serialized_content = {
-            'job_id': '12723',
-            'title': 'some title',
-            'queue_name': 'some_queue',
-            'account': 'my_account'
-        }
-
-        to_serialize = {'job_state': (JobState.QUEUED, 'job_state'), 'submission_time': (datetime.now(), 'date')}
-
-        job_info = JobInfo()
-        for key, val in dict_serialized_content.items():
-            setattr(job_info, key, val)
-
-        for key, (val, field_type) in to_serialize.items():
-            setattr(job_info, key, val)
-            # Also append to the dictionary for easier comparison later
-            dict_serialized_content[key] = JobInfo.serialize_field(value=val, field_type=field_type)
-
-        self.assertEqual(job_info.get_dict(), dict_serialized_content)
-        # Full loop via JSON, moving data from job_info to job_info2;
-        # we check that the content is fully preserved
-        job_info2 = JobInfo.load_from_serialized(job_info.serialize())
-        self.assertEqual(job_info2.get_dict(), dict_serialized_content)
-
-        # Check that fields are properly re-serialized with the correct type
-        self.assertEqual(job_info2.job_state, to_serialize['job_state'][0])
-        # Check that fields are properly re-serialized with the correct type
-        self.assertEqual(job_info2.submission_time, to_serialize['submission_time'][0])
+
+class TestParEnvJobResource:
+    """Tests for the :class:`~aiida.schedulers.datastructures.ParEnvJobResource`."""
+
+    @staticmethod
+    def test_validate_resources():
+        """Test the `validate_resources` method."""
+        cls = ParEnvJobResource
+
+        with pytest.raises(ValueError):
+            cls.validate_resources()
+
+        # Missing required field
+        with pytest.raises(ValueError):
+            cls.validate_resources(parallel_env='env')
+        with pytest.raises(ValueError):
+            cls.validate_resources(tot_num_mpiprocs=1)
+
+        # Wrong types
+        with pytest.raises(ValueError):
+            cls.validate_resources(parallel_env={}, tot_num_mpiprocs=1)
+        with pytest.raises(ValueError):
+            cls.validate_resources(parallel_env='env', tot_num_mpiprocs='test')
+        with pytest.raises(ValueError):
+            cls.validate_resources(parallel_env='env', tot_num_mpiprocs=0)
+
+        # Wrong field name
+        with pytest.raises(ValueError):
+            cls.validate_resources(parallel_env='env', tot_num_mpiprocs=1, wrong_name=16)
+
+    @staticmethod
+    def test_constructor():
+        """Test that constructor defines all valid keys even if not all defined explicitly."""
+        job_resource = ParEnvJobResource(parallel_env='env', tot_num_mpiprocs=1)
+        assert job_resource.parallel_env == 'env'
+        assert job_resource.tot_num_mpiprocs == 1
+
+
+def test_serialization():
+    """Test the serialization/deserialization of JobInfo classes."""
+    from aiida.schedulers.datastructures import JobInfo, JobState
+    from datetime import datetime
+
+    dict_serialized_content = {
+        'job_id': '12723',
+        'title': 'some title',
+        'queue_name': 'some_queue',
+        'account': 'my_account'
+    }
+
+    to_serialize = {'job_state': (JobState.QUEUED, 'job_state'), 'submission_time': (datetime.now(), 'date')}
+
+    job_info = JobInfo()
+    for key, val in dict_serialized_content.items():
+        setattr(job_info, key, val)
+
+    for key, (val, field_type) in to_serialize.items():
+        setattr(job_info, key, val)
+        # Also append to the dictionary for easier comparison later
+        dict_serialized_content[key] = JobInfo.serialize_field(value=val, field_type=field_type)
+
+    assert job_info.get_dict() == dict_serialized_content
+
+    # Full loop via JSON, moving data from job_info to job_info2;
+    # we check that the content is fully preserved
+    job_info2 = JobInfo.load_from_serialized(job_info.serialize())
+    assert job_info2.get_dict() == dict_serialized_content
+
+    # Check that fields are properly re-serialized with the correct type
+    assert job_info2.job_state == to_serialize['job_state'][0]
+    assert job_info2.submission_time == to_serialize['submission_time'][0]