Skip to content

Commit

Permalink
TMP: example of switching to pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Apr 22, 2023
1 parent 6b915bd commit de5c04c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 60 deletions.
46 changes: 19 additions & 27 deletions aiida_submission_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,39 @@
"""A prototype class to submit processes in batches, avoiding to submit too many."""
import abc
import logging
from pydantic import BaseModel, validator

from aiida import engine, orm
from aiida.common import NotExistent

CMDLINE_LOGGER = logging.getLogger('verdi')


class BaseSubmissionController:
def validate_group_exists(value: str) -> str:
try:
orm.Group.collection.get(label=value)
except NotExistent as exc:
raise ValueError(f'Group with label `{value}` does not exist.') from exc
else:
return value


class BaseSubmissionController(BaseModel):
"""Controller to submit a maximum number of processes (workflows or calculations) at a given time.
This is an abstract base class: you need to subclass it and define the abstract methods.
"""
def __init__(self, group_label, max_concurrent):
"""Create a new controller to manage (and limit) concurrent submissions.
:param group_label: a group label: the group will be created at instantiation (if not existing already,
and it will be used to manage the calculations)
:param extra_unique_keys: a tuple or list of keys of extras that are used to uniquely identify
a process in the group. E.g. ('value1', 'value2').
:note: try to use actual values that allow for an equality comparison (strings, bools, integers), and avoid
floats, because of truncation errors.
"""
self._group_label = group_label
self._max_concurrent = max_concurrent

# Create the group if needed
self._group, _ = orm.Group.objects.get_or_create(self.group_label)
group_label: str
"""Label of the group to store the process nodes in."""
max_concurrent: int
"""Maximum concurrent active processes."""

@property
def group_label(self):
"""Return the label of the group that is managed by this class."""
return self._group_label
_validate_group_exists = validator('group_label', allow_reuse=True)(validate_group_exists)

@property
def group(self):
"""Return the AiiDA ORM Group instance that is managed by this class."""
return self._group

@property
def max_concurrent(self):
"""Value of the maximum number of concurrent processes that can be run."""
return self._max_concurrent
return orm.Group.objects.get(label=self.group_label)

def get_query(self, process_projections, only_active=False):
"""Return a QueryBuilder object to get all processes in the group associated to this.
Expand Down
28 changes: 6 additions & 22 deletions aiida_submission_controller/from_group.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,24 @@
# -*- coding: utf-8 -*-
"""A prototype class to submit processes in batches, avoiding to submit too many."""
from aiida import orm
from .base import BaseSubmissionController

from .base import BaseSubmissionController, validate_group_exists
from pydantic import validator

class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method
"""SubmissionController implementation getting data to submit from a parent group.
This is (still) an abstract base class: you need to subclass it
and define the abstract methods.
"""
def __init__(self, parent_group_label, *args, **kwargs):
"""Create a new controller to manage (and limit) concurrent submissions.
:param parent_group_label: a group label: the group will be used to decide
which submissions to use. The group must already exist. Extras (in the method
`get_all_extras_to_submit`) will be returned from all extras in that group
(you need to make sure they are unique).
parent_group_label: str
"""Label of the parent group from which to construct the process inputs."""

For all other parameters, see the docstring of ``BaseSubmissionController.__init__``.
"""
super().__init__(*args, **kwargs)
self._parent_group_label = parent_group_label
# Load the group (this also ensures it exists)
self._parent_group = orm.Group.objects.get(
label=self.parent_group_label)

@property
def parent_group_label(self):
"""Return the label of the parent group that is used as a reference."""
return self._parent_group_label
_validate_group_exists = validator('parent_group_label', allow_reuse=True)(validate_group_exists)

@property
def parent_group(self):
"""Return the AiiDA ORM Group instance of the parent group."""
return self._parent_group
return orm.Group.objects.get(label=self.parent_group_label)

def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the (unique) extras identifying it."""
Expand Down
30 changes: 19 additions & 11 deletions examples/add_in_batches.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
# -*- coding: utf-8 -*-
"""An example of a SubmissionController implementation to compute a 12x12 table of additions."""

from aiida import orm, plugins
from pydantic import validator
from aiida import orm
from aiida_submission_controller import BaseSubmissionController

from aiida.plugins import CalculationFactory


class AdditionTableSubmissionController(BaseSubmissionController):
"""The implementation of a SubmissionController to compute a 12x12 table of additions."""
def __init__(self, code_name, *args, **kwargs):
"""Pass also a code name, that should be a code associated to an `arithmetic.add` plugin."""
super().__init__(*args, **kwargs)
self._code = orm.load_code(code_name)
self._process_class = plugins.CalculationFactory('arithmetic.add')
code_label: str
"""Label of the `code.arithmetic.add` `Code`."""

@validator('code_label')
def _check_code_plugin(cls, value):
plugin_type = orm.load_code(value).default_calc_job_plugin
if plugin_type == 'core.arithmetic.add':
return value
raise ValueError(f'Code with label `{value}` has incorrect plugin type: `{plugin_type}`')

def get_extra_unique_keys(self):
"""Return a tuple of the keys of the unique extras that will be used to uniquely identify your workchains.
Expand All @@ -37,12 +43,13 @@ def get_inputs_and_processclass_from_extras(self, extras_values):
I just submit an ArithmeticAdd calculation summing the two values stored in the extras:
``left_operand + right_operand``.
"""
code = orm.load_code(self.code_label)
inputs = {
'code': self._code,
'code': code,
'x': orm.Int(extras_values[0]),
'y': orm.Int(extras_values[1])
}
return inputs, self._process_class
return inputs, CalculationFactory(code.get_input_plugin_name())


def main():
Expand All @@ -55,9 +62,10 @@ def main():
## verdi code setup -L add --on-computer --computer=localhost -P arithmetic.add --remote-abs-path=/bin/bash -n
# Create a controller
controller = AdditionTableSubmissionController(
code_name='add@localhost',
code_label='add@localhost',
group_label='tests/addition_table',
max_concurrent=10)
max_concurrent=10
)

print('Max concurrent :', controller.max_concurrent)
print('Active slots :', controller.num_active_slots)
Expand Down

0 comments on commit de5c04c

Please sign in to comment.