Skip to content

Commit

Permalink
Merge pull request #266 from laserkelvin/mptraj-dataset
Browse files Browse the repository at this point in the history
Implementing `MaterialsTrajectoryDataset`
  • Loading branch information
laserkelvin authored Aug 6, 2024
2 parents 39dae9b + 99a90d9 commit b7bc8fd
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
6 changes: 5 additions & 1 deletion matsciml/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from matsciml.datasets.carolina_db import CMDataset
from matsciml.datasets.colabfit import ColabFitDataset
from matsciml.datasets.lips import LiPSDataset
from matsciml.datasets.materials_project import MaterialsProjectDataset
from matsciml.datasets.materials_project import (
MaterialsProjectDataset,
MaterialsTrajectoryDataset,
)
from matsciml.datasets.multi_dataset import MultiDataset
from matsciml.datasets.nomad import NomadDataset
from matsciml.datasets.ocp_datasets import IS2REDataset, S2EFDataset
Expand All @@ -26,6 +29,7 @@
"NomadDataset",
"OQMDDataset",
"MaterialsProjectDataset",
"MaterialsTrajectoryDataset",
"LiPSDataset",
"SyntheticPointGroupDataset",
"MultiDataset",
Expand Down
14 changes: 13 additions & 1 deletion matsciml/datasets/materials_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from importlib.util import find_spec

from matsciml.datasets.materials_project.api import MaterialsProjectRequest
from matsciml.datasets.materials_project.dataset import MaterialsProjectDataset
from matsciml.datasets.materials_project.dataset import (
MaterialsProjectDataset,
MaterialsTrajectoryDataset,
)

__all__ = [
"MaterialsProjectDataset",
"MaterialsTrajectoryDataset",
"MaterialsProjectRequest",
]

if find_spec("torch_geometric") is not None:
from matsciml.datasets.materials_project.dataset import (
Expand All @@ -12,3 +21,6 @@
PyGMaterialsProjectDataset,
)

__all__.extend(
["CdvaeLMDBDataset", "PyGCdvaeDataset", "PyGMaterialsProjectDataset"]
)
86 changes: 86 additions & 0 deletions matsciml/datasets/materials_project/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,92 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
)


@registry.register_dataset("MaterialsTrajectoryDataset")
class MaterialsTrajectoryDataset(MaterialsProjectDataset):
"""
This is predominantly a variant of the base class,
but specialized towards energy, stress, and force prediction.
"""

def data_from_key(
self,
lmdb_index: int,
subindex: int,
) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]:
"""
Variation of the base class method. A lot of the complexity
is removed, as this class exclusively focuses on performing
energy/stress/force regression tasks.
Parameters
----------
lmdb_index : int
Index corresponding to which LMDB environment to parse from.
subindex : int
Index within an LMDB file that maps to a sample.
Returns
-------
Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
A single sample/material from Materials Project.
"""
data: dict[str, Any] = super(PointCloudDataset, self).data_from_key(
lmdb_index, subindex
)
return_dict = {}
# parse out relevant structure/lattice data
self._parse_structure(data, return_dict)
self._parse_symmetry(data, return_dict)
target_keys = [
"energy",
"force",
"stress",
"uncorrected_total_energy",
"ef_per_atom_relaxed",
"ef_per_atom",
"e_per_atom_relaxed",
"energy_per_atom",
"corrected_total_energy",
]
# pack stuff
targets = {}
for key in target_keys:
item = data.get(key, None)
if item is not None:
# in the case we have something like a list, we'll try
# and convert it into a tensor in this pass
if isinstance(item, Iterable):
item = torch.Tensor(item)
targets[key] = item
# check for missing essential keys
missing_keys = [
key
for key in ["force", "stress", "corrected_total_energy"]
if key not in targets
]
if len(missing_keys) != 0:
raise KeyError(
f"Missing the following keys: {missing_keys} needed for MPTraj."
" Please ensure you're using the right dataset/LMDB files!"
)
# copy as nominal energy value
if "energy" not in targets:
targets["energy"] = targets["corrected_total_energy"]
return_dict["targets"] = targets
# compress all the targets into a single tensor for convenience
target_types = {"classification": [], "regression": target_keys}
return_dict["target_types"] = target_types
self.target_keys = target_types
return return_dict

@classmethod
def from_devset(cls, transforms: list[Callable] | None = None, **kwargs):
"""Override the devset call, at least temporarily."""
raise NotImplementedError(
"Materials Trajectory does not have a devset available."
)


if _has_pyg:
from torch_geometric.data import Batch, Data

Expand Down

0 comments on commit b7bc8fd

Please sign in to comment.