-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add descriptors * Refactor descriptors * Add descriptor options to CLI * Add descriptor tests * Update docs * Apply suggestions from code review Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> * Fix test --------- Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>
- Loading branch information
1 parent
2fad7f1
commit 6a03e15
Showing
9 changed files
with
787 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Class to run descriptors calculations.""" | ||
|
||
from aiida.common import datastructures | ||
import aiida.common.folders | ||
from aiida.engine import CalcJobProcessSpec | ||
import aiida.engine.processes | ||
from aiida.orm import Bool | ||
|
||
from aiida_mlip.calculations.singlepoint import Singlepoint | ||
|
||
|
||
class Descriptors(Singlepoint): # numpydoc ignore=PR01 | ||
""" | ||
Calcjob implementation to calculate MLIP descriptors. | ||
Methods | ||
------- | ||
define(spec: CalcJobProcessSpec) -> None: | ||
Define the process specification, its inputs, outputs and exit codes. | ||
prepare_for_submission(folder: Folder) -> CalcInfo: | ||
Create the input files for the `CalcJob`. | ||
""" | ||
|
||
@classmethod | ||
def define(cls, spec: CalcJobProcessSpec) -> None: | ||
""" | ||
Define the process specification, its inputs, outputs and exit codes. | ||
Parameters | ||
---------- | ||
spec : aiida.engine.CalcJobProcessSpec | ||
The calculation job process spec to define. | ||
""" | ||
super().define(spec) | ||
|
||
# Define inputs | ||
|
||
# Remove unused singlepoint input | ||
del spec.inputs["properties"] | ||
|
||
spec.input( | ||
"invariants_only", | ||
valid_type=Bool, | ||
required=False, | ||
help="Only calculate invariant descriptors.", | ||
) | ||
|
||
spec.input( | ||
"calc_per_element", | ||
valid_type=Bool, | ||
required=False, | ||
help="Calculate mean descriptors for each element.", | ||
) | ||
|
||
spec.input( | ||
"calc_per_atom", | ||
valid_type=Bool, | ||
required=False, | ||
help="Calculate descriptors for each atom.", | ||
) | ||
|
||
spec.inputs["metadata"]["options"][ | ||
"parser_name" | ||
].default = "mlip.descriptors_parser" | ||
|
||
# pylint: disable=too-many-locals | ||
def prepare_for_submission( | ||
self, folder: aiida.common.folders.Folder | ||
) -> datastructures.CalcInfo: | ||
""" | ||
Create the input files for the `Calcjob`. | ||
Parameters | ||
---------- | ||
folder : aiida.common.folders.Folder | ||
Folder where the calculation is run. | ||
Returns | ||
------- | ||
aiida.common.datastructures.CalcInfo | ||
An instance of `aiida.common.datastructures.CalcInfo`. | ||
""" | ||
# Call the parent class method to prepare common inputs | ||
calcinfo = super().prepare_for_submission(folder) | ||
codeinfo = calcinfo.codes_info[0] | ||
|
||
# Adding command line params for when we run janus | ||
# descriptors is overwriting the placeholder "calculation" from the base.py file | ||
codeinfo.cmdline_params[0] = "descriptors" | ||
|
||
cmdline_options = { | ||
key.replace("_", "-"): getattr(self.inputs, key).value | ||
for key in ("invariants_only", "calc_per_element", "calc_per_atom") | ||
if key in self.inputs | ||
} | ||
|
||
for flag, value in cmdline_options.items(): | ||
if isinstance(value, bool): | ||
# Add boolean flags without value if True | ||
if value: | ||
codeinfo.cmdline_params.append(f"--{flag}") | ||
else: | ||
codeinfo.cmdline_params += [f"--{flag}", value] | ||
|
||
return calcinfo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
"""Parsers provided by aiida_mlip.""" | ||
|
||
from aiida.common import exceptions | ||
from aiida.orm.nodes.process.process import ProcessNode | ||
from aiida.plugins import CalculationFactory | ||
|
||
from aiida_mlip.parsers.sp_parser import SPParser | ||
|
||
DescriptorsCalc = CalculationFactory("mlip.descriptors") | ||
|
||
|
||
class DescriptorsParser(SPParser): | ||
""" | ||
Parser class for parsing output of descriptors calculation. | ||
Inherits from SPParser. | ||
Parameters | ||
---------- | ||
node : aiida.orm.nodes.process.process.ProcessNode | ||
ProcessNode of calculation. | ||
Raises | ||
------ | ||
exceptions.ParsingError | ||
If the ProcessNode being passed was not produced by a DescriptorsCalc. | ||
""" | ||
|
||
def __init__(self, node: ProcessNode): | ||
""" | ||
Check that the ProcessNode being passed was produced by a `Descriptors`. | ||
Parameters | ||
---------- | ||
node : aiida.orm.nodes.process.process.ProcessNode | ||
ProcessNode of calculation. | ||
""" | ||
super().__init__(node) | ||
|
||
if not issubclass(node.process_class, DescriptorsCalc): | ||
raise exceptions.ParsingError("Can only parse `Descriptors` calculations") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Example code for submitting descriptors calculation.""" | ||
|
||
import click | ||
|
||
from aiida.common import NotExistent | ||
from aiida.engine import run_get_node | ||
from aiida.orm import Bool, Str, load_code | ||
from aiida.plugins import CalculationFactory | ||
|
||
from aiida_mlip.helpers.help_load import load_model, load_structure | ||
|
||
|
||
def descriptors(params: dict) -> None: | ||
""" | ||
Prepare inputs and run a descriptors calculation. | ||
Parameters | ||
---------- | ||
params : dict | ||
A dictionary containing the input parameters for the calculations | ||
Returns | ||
------- | ||
None | ||
""" | ||
structure = load_structure(params["struct"]) | ||
|
||
# Select model to use | ||
model = load_model(params["model"], params["arch"]) | ||
|
||
# Select calculation to use | ||
DescriptorsCalc = CalculationFactory("mlip.descriptors") | ||
|
||
# Define inputs | ||
inputs = { | ||
"metadata": {"options": {"resources": {"num_machines": 1}}}, | ||
"code": params["code"], | ||
"arch": Str(params["arch"]), | ||
"struct": structure, | ||
"model": model, | ||
"precision": Str(params["precision"]), | ||
"device": Str(params["device"]), | ||
"invariants_only": Bool(params["invariants_only"]), | ||
"calc_per_element": Bool(params["calc_per_element"]), | ||
"calc_per_atom": Bool(params["calc_per_atom"]), | ||
} | ||
|
||
# Run calculation | ||
result, node = run_get_node(DescriptorsCalc, **inputs) | ||
print(f"Printing results from calculation: {result}") | ||
print(f"Printing node of calculation: {node}") | ||
|
||
|
||
# Arguments and options to give to the cli when running the script | ||
@click.command("cli") | ||
@click.argument("codelabel", type=str) | ||
@click.option( | ||
"--struct", | ||
default=None, | ||
type=str, | ||
help="Specify the structure (aiida node or path to a structure file)", | ||
) | ||
@click.option( | ||
"--model", | ||
default=None, | ||
type=str, | ||
help="Specify path or URI of the model to use", | ||
) | ||
@click.option( | ||
"--arch", | ||
default="mace_mp", | ||
type=str, | ||
help="MLIP architecture to use for calculations.", | ||
) | ||
@click.option( | ||
"--device", default="cpu", type=str, help="Device to run calculations on." | ||
) | ||
@click.option( | ||
"--precision", default="float64", type=str, help="Chosen level of precision." | ||
) | ||
@click.option( | ||
"--invariants-only", | ||
default=False, | ||
type=bool, | ||
help="Only calculate invariant descriptors.", | ||
) | ||
@click.option( | ||
"--calc-per-element", | ||
default=False, | ||
type=bool, | ||
help="Calculate mean descriptors for each element.", | ||
) | ||
@click.option( | ||
"--calc-per-atom", | ||
default=False, | ||
type=bool, | ||
help="Calculate descriptors for each atom.", | ||
) | ||
def cli( | ||
codelabel, | ||
struct, | ||
model, | ||
arch, | ||
device, | ||
precision, | ||
invariants_only, | ||
calc_per_element, | ||
calc_per_atom, | ||
) -> None: | ||
# pylint: disable=too-many-arguments | ||
"""Click interface.""" | ||
try: | ||
code = load_code(codelabel) | ||
except NotExistent as exc: | ||
print(f"The code '{codelabel}' does not exist.") | ||
raise SystemExit from exc | ||
|
||
params = { | ||
"code": code, | ||
"struct": struct, | ||
"model": model, | ||
"arch": arch, | ||
"device": device, | ||
"precision": precision, | ||
"invariants_only": invariants_only, | ||
"calc_per_element": calc_per_element, | ||
"calc_per_atom": calc_per_atom, | ||
} | ||
|
||
# Submit descriptors | ||
descriptors(params) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli() # pylint: disable=no-value-for-parameter |
Oops, something went wrong.