Skip to content

Commit

Permalink
split base in more functions
Browse files Browse the repository at this point in the history
  • Loading branch information
federicazanca committed Jun 4, 2024
1 parent 3179fbe commit 83dde1e
Showing 1 changed file with 78 additions and 36 deletions.
114 changes: 78 additions & 36 deletions aiida_mlip/calculations/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base class for features common to most calculations."""

from typing import Union

from ase.io import read, write

from aiida.common import InputValidationError, datastructures
Expand Down Expand Up @@ -217,42 +219,8 @@ def prepare_for_submission(

# Define architecture from model if model is given,
# otherwise get architecture from inputs and download default model
architecture = None
if "model" in self.inputs and hasattr(self.inputs.model, "architecture"):
architecture = str((self.inputs.model).architecture)
cmd_line["arch"] = architecture
elif "arch" in self.inputs:
architecture = str(self.inputs.arch.value)
cmd_line["arch"] = architecture
# At this point we must have the model in the config and the arch in the config
# or nowhere, so we don't need to write in the cmd line if in config
elif "config" in self.inputs and "arch" in self.inputs.config:
architecture = self.inputs.config.as_dictionary["arch"]
# And we need a default value if it's nowhere
else:
architecture = "mace_mp"
cmd_line["arch"] = architecture

model_path = None
if "model" in self.inputs:
# Raise error if model is None
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
else:
if "config" in self.inputs and "model" in self.inputs.config:
# No need for command line
model_path = None
# If we have not found the model anywhere let's use a default
else:
# if model is not given (which is different than it being None)
model_path = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
).filepath

if model_path:
cmd_line.setdefault("calc-kwargs", {})["model"] = model_path
cmd_line, architecture = self._define_architecture(cmd_line)
cmd_line = self._add_model_to_cmdline(cmd_line, architecture)

if "config" in self.inputs:
# Add config file to command line
Expand Down Expand Up @@ -287,3 +255,77 @@ def prepare_for_submission(
]

return calcinfo

def _define_architecture(self, cmd_line: dict) -> Union[dict, str]:
"""
Find architecture in inputs or config file and add to command line if needed.
Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.
Returns
-------
cmd_line_updated: dict
Dictionary containing the cmd line keys updated with the architecture.
architecture: str
Architecture type, either from inputs/config or a default value ("mace_mp").
"""
architecture = None
cmd_line_updated = cmd_line
if "model" in self.inputs and hasattr(self.inputs.model, "architecture"):
architecture = str((self.inputs.model).architecture)
cmd_line_updated["arch"] = architecture
elif "arch" in self.inputs:
architecture = str(self.inputs.arch.value)
cmd_line_updated["arch"] = architecture
# At this point we must have the model in the config and the arch in the config
# or nowhere, so we don't need to write in the cmd line if in config
elif "config" in self.inputs and "arch" in self.inputs.config:
architecture = self.inputs.config.as_dictionary["arch"]
# And we need a default value if it's nowhere
else:
architecture = "mace_mp"
cmd_line_updated["arch"] = architecture
return cmd_line_updated, architecture

def _add_model_to_cmdline(self, cmd_line: dict, architecture: str) -> dict:
"""
Find model in inputs or config file and add to command line if needed.
Parameters
----------
cmd_line : dict
Dictionary containing the cmd line keys.
architecture : str
Architecture type.
Returns
-------
dict
Dictionary containing the cmd line keys updated with the model.
"""
model_path = None
cmd_line_updated = cmd_line
if "model" in self.inputs:
# Raise error if model is None
if self.inputs.model is None:
raise ValueError("Model cannot be None")
model_path = self.inputs.model.filepath
else:
if "config" in self.inputs and "model" in self.inputs.config:
# No need for command line
model_path = None
# If we have not found the model anywhere let's use a default
else:
# if model is not given (which is different than it being None)
model_path = ModelData.download(
"https://github.com/stfc/janus-core/raw/main/tests/models/mace_mp_small.model", # pylint: disable=line-too-long
architecture,
).filepath

if model_path:
cmd_line_updated.setdefault("calc-kwargs", {})["model"] = model_path

return cmd_line_updated

0 comments on commit 83dde1e

Please sign in to comment.