Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

134 model and architecture bug #138

Merged
merged 11 commits into from
Jun 7, 2024
110 changes: 75 additions & 35 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def validate_inputs(
The inputs dictionary.

port_namespace : `aiida.engine.processes.ports.PortNamespace`
An instance of aiida's `PortNameSpace`.
An instance of aiida's `PortNamespace`.

Raises
------
Expand All @@ -44,6 +44,30 @@ def validate_inputs(
raise InputValidationError(
"Structure must be specified through 'struct' or 'config'"
)
if (
"arch" not in inputs
and "model" not in inputs
and ("config" not in inputs or "arch" not in inputs["config"])
):
raise InputValidationError(
"'arch' must be specified in inputs, config file or ModelData"
)

if "model" not in inputs and (
"config" not in inputs or "model" not in inputs["config"]
):
raise InputValidationError(
"'model' must be specified either in the inputs or in the config file"
)

if (
"arch" in inputs
and "model" in inputs
and inputs["arch"].value is not inputs["model"].architecture
):
raise InputValidationError(
"'arch' in ModelData and in 'arch' input must be the same"
)


class BaseJanus(CalcJob): # numpydoc ignore=PR01
Expand Down Expand Up @@ -206,42 +230,10 @@ def prepare_for_submission(

# Define architecture from model if model is given,
# otherwise get architecture from inputs and download default model
architecture = None
architecture = (
str((self.inputs.model).architecture)
if "model" in self.inputs and hasattr(self.inputs.model, "architecture")
else str(self.inputs.arch.value) if "arch" in self.inputs else None
)

if architecture:
cmd_line["arch"] = architecture

model_path = None
if "model" in self.inputs:
# Raise error if model is None
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
else:
if "config" in self.inputs and "model" in self.inputs.config:
model_path = None
else:
if "arch" in self.inputs:
# if model is not given (which is different than it being None)
model_path = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
).filepath
if model_path:
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
self._add_arch_to_cmdline(cmd_line)
self._add_model_to_cmdline(cmd_line)

if "config" in self.inputs:
# Check if there are values in the config file that are also in the command
# line and do not store them as only the cmd line parameters will be used
config_dict = self.inputs.config.as_dictionary
overlapping_params = cmd_line.keys() & config_dict.keys()
# Store the other parameters
self.inputs.config.store_content(skip=overlapping_params)
# Add config file to command line
cmd_line["config"] = "config.yaml"
config_parse = self.inputs.config.get_content()
Expand Down Expand Up @@ -274,3 +266,51 @@ def prepare_for_submission(
]

return calcinfo

def _add_arch_to_cmdline(self, cmd_line: dict) -> dict:
"""
Find architecture in inputs or config file and add to command line if needed.

Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.

Returns
-------
dict
Dictionary containing the cmd line keys updated with the architecture.
"""
architecture = None
if "model" in self.inputs and hasattr(self.inputs.model, "architecture"):
architecture = str((self.inputs.model).architecture)
elif "arch" in self.inputs:
architecture = str(self.inputs.arch.value)
if architecture:
cmd_line["arch"] = architecture

def _add_model_to_cmdline(
self,
cmd_line: dict,
) -> dict:
"""
Find model in inputs or config file and add to command line if needed.

Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.

Returns
-------
dict
Dictionary containing the cmd line keys updated with the model.
"""
model_path = None
if "model" in self.inputs:
# Raise error if model is None (different than model not given as input)
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
if model_path:
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
3 changes: 3 additions & 0 deletions tests/calculations/configs/config_noarch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
properties:
- "energy"
model: "small"
3 changes: 3 additions & 0 deletions tests/calculations/configs/config_nomodel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
properties:
- "energy"
arch: "mace_mp"
67 changes: 43 additions & 24 deletions tests/calculations/test_singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiida.orm import Str, StructureData
from aiida.plugins import CalculationFactory

from aiida_mlip.data.config import JanusConfigfile
from aiida_mlip.data.model import ModelData


Expand Down Expand Up @@ -61,48 +62,66 @@ def test_singlepoint(fixture_sandbox, generate_calc_job, janus_code, model_folde
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)


def test_singlepoint_model_download(fixture_sandbox, generate_calc_job, janus_code):
"""Test generating singlepoint calculation job."""

def test_sp_nostruct(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with error input"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"
# pylint:disable=line-too-long
inputs = {
"metadata": {"options": {"resources": {"num_machines": 1}}},
"code": janus_code,
"arch": Str("mace"),
federicazanca marked this conversation as resolved.
Show resolved Hide resolved
"precision": Str("float64"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
"model": ModelData.local_file(model_file, architecture="mace"),
"device": Str("cpu"),
}
with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

calc_info = generate_calc_job(fixture_sandbox, entry_point_name, inputs)

retrieve_list = [
calc_info.uuid,
"aiida.log",
"aiida-results.xyz",
"aiida-stdout.txt",
]
def test_sp_nomodel(fixture_sandbox, generate_calc_job, config_folder, janus_code):
"""Test singlepoint calculation with missing model"""
entry_point_name = "mlip.sp"

# Check the attributes of the returned `CalcInfo`
assert fixture_sandbox.get_content_list() == ["aiida.xyz"]
assert isinstance(calc_info, datastructures.CalcInfo)
assert isinstance(calc_info.codes_info[0], datastructures.CodeInfo)
assert sorted(calc_info.retrieve_list) == sorted(retrieve_list)
inputs = {
"code": janus_code,
"metadata": {"options": {"resources": {"num_machines": 1}}},
"config": JanusConfigfile(config_folder / "config_nomodel.yml"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

def test_sp_nostruct(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with error input"""

def test_sp_noarch(fixture_sandbox, generate_calc_job, config_folder, janus_code):
"""Test singlepoint calculation with missing architecture"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"
# pylint:disable=line-too-long

inputs = {
"code": janus_code,
"metadata": {"options": {"resources": {"num_machines": 1}}},
"config": JanusConfigfile(config_folder / "config_noarch.yml"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)


def test_two_arch(fixture_sandbox, generate_calc_job, model_folder, janus_code):
"""Test singlepoint calculation with two defined architectures"""
entry_point_name = "mlip.sp"
model_file = model_folder / "mace_mp_small.model"

inputs = {
"code": janus_code,
"arch": Str("mace"),
"precision": Str("float64"),
"model": ModelData.local_file(model_file, architecture="mace"),
"device": Str("cpu"),
"metadata": {"options": {"resources": {"num_machines": 1}}},
"model": ModelData.local_file(model_file, architecture="mace_mp"),
"arch": Str("chgnet"),
"struct": StructureData(ase=bulk("NaCl", "rocksalt", 5.63)),
}

with pytest.raises(InputValidationError):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

Expand Down