Skip to content

Commit

Permalink
Update prediction function for atomic tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
mjwen committed Aug 30, 2024
1 parent a97fe87 commit 5ebb313
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
29 changes: 29 additions & 0 deletions notebooks/predict_atomic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
An example script make predictions of any tensor.
"""

from pymatgen.core import Structure

from matten.predict import predict


def get_structure():
a = 5.46
lattice = [[0, a / 2, a / 2], [a / 2, 0, a / 2], [a / 2, a / 2, 0]]
basis = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]
Si = Structure(lattice, ["Si", "Si"], basis)

return Si


if __name__ == "__main__":
structure = get_structure()

# predict for one structure
tensors = predict(
structure,
model_identifier="/Users/mjwen.admin/Downloads/trained",
checkpoint="epoch=9-step=100.ckpt",
is_atomic_tensor=True,
)
print("value:", tensors)
28 changes: 22 additions & 6 deletions src/matten/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from matten.dataset.structure_scalar_tensor import TensorDatasetPrediction
from matten.log import set_logger
from matten.model_factory.tfn_atomic_tensor import AtomicTensorModel
from matten.model_factory.tfn_scalar_tensor import ScalarTensorModel
from matten.utils import CartesianTensorWrapper, yaml_load

Expand All @@ -31,9 +32,11 @@ def get_pretrained_model_dir(identifier: str) -> Path:
return Path(__file__).parent.parent.parent / "pretrained" / identifier


def get_pretrained_model(identifier: str, checkpoint: str = "model_final.ckpt"):
def get_pretrained_model(
identifier: str, checkpoint: str = "model_final.ckpt", model_class=ScalarTensorModel
):
directory = get_pretrained_model_dir(identifier)
model = ScalarTensorModel.load_from_checkpoint(
model = model_class.load_from_checkpoint(
checkpoint_path=directory.joinpath(checkpoint).as_posix(),
map_location="cpu",
)
Expand Down Expand Up @@ -62,6 +65,7 @@ def get_data_loader(
"valset_filename",
"testset_filename",
"compute_dataset_statistics",
"atom_selector",
]:
try:
config.pop(k)
Expand Down Expand Up @@ -151,6 +155,7 @@ def predict(
batch_size: int = 200,
logger_level: str = "ERROR",
is_elasticity_tensor: bool = True,
is_atomic_tensor: bool = False,
) -> Union[ElasticTensor, List[ElasticTensor]]:
f"""
Predict the property of a structure or a list of structures.
Expand All @@ -174,6 +179,8 @@ def predict(
is_elasticity_tensor: whether the target property is an elasticity tensor. If
`True`, the returned value will be a pymargen `ElasticTensor` object.
Otherwise, it will be numpy array.
is_atomic_tensor: whether the target property is an atomic tensor. If `True`,
we predict a tensor value for each atom in the structure.
Returns:
Predicted tensor(s). `None` if the model cannot make prediction for a structure.
Expand All @@ -186,7 +193,16 @@ def predict(
else:
single_struct = False

model = get_pretrained_model(identifier=model_identifier, checkpoint=checkpoint)
if is_atomic_tensor:
model_class = AtomicTensorModel
is_elasticity_tensor = False
else:
model_class = ScalarTensorModel
model = get_pretrained_model(
identifier=model_identifier,
checkpoint=checkpoint,
model_class=model_class,
)
check_species(model, structure)
loader = get_data_loader(structure, model_identifier, batch_size=batch_size)

Expand Down Expand Up @@ -223,10 +239,10 @@ def predict(
else:
pred_tensors = predictions

if single_struct:
if single_struct and not is_atomic_tensor:
return pred_tensors[0]
else:
return pred_tensors

return pred_tensors


if __name__ == "__main__":
Expand Down

0 comments on commit 5ebb313

Please sign in to comment.