From 0529e221ec4971f714cccd878f618047b0005a38 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 5 Aug 2024 16:29:40 -0700 Subject: [PATCH 1/4] feat: added materials trajectory subclass --- .../datasets/materials_project/dataset.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/matsciml/datasets/materials_project/dataset.py b/matsciml/datasets/materials_project/dataset.py index 309f1e86..2aaf446d 100644 --- a/matsciml/datasets/materials_project/dataset.py +++ b/matsciml/datasets/materials_project/dataset.py @@ -347,6 +347,74 @@ def collate_fn(batch: list[DataDict]) -> BatchDict: ) +@registry.register_dataset("MaterialsTrajectoryDataset") +class MaterialsTrajectoryDataset(MaterialsProjectDataset): + """ + This is predominantly a variant of the base class, + but specialized towards energy, stress, and force prediction. + """ + + def data_from_key( + self, + lmdb_index: int, + subindex: int, + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """ + Variation of the base class method. A lot of the complexity + is removed, as this class exclusively focuses on performing + energy/stress/force regression tasks. + + Parameters + ---------- + lmdb_index : int + Index corresponding to which LMDB environment to parse from. + subindex : int + Index within an LMDB file that maps to a sample. + + Returns + ------- + Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] + A single sample/material from Materials Project. + """ + data: dict[str, Any] = super(PointCloudDataset, self).data_from_key( + lmdb_index, subindex + ) + return_dict = {} + # parse out relevant structure/lattice data + self._parse_structure(data, return_dict) + self._parse_symmetry(data, return_dict) + target_keys = [ + "energy", + "force", + "stress", + "uncorrected_total_energy", + "ef_per_atom_relaxed", + "ef_per_atom", + "e_per_atom_relaxed", + "energy_per_atom", + "corrected_total_energy", + ] + # pack stuff + targets = {} + for key in target_keys: + item = data.get(key, None) + if item is not None: + # in the case we have something like a list, we'll try + # and convert it into a tensor in this pass + if isinstance(item, Iterable): + item = torch.Tensor(item) + targets[key] = item + # copy as nominal energy value + if "energy" not in targets: + targets["energy"] = targets["corrected_total_energy"] + return_dict["targets"] = targets + # compress all the targets into a single tensor for convenience + target_types = {"classification": [], "regression": target_keys} + return_dict["target_types"] = target_types + self.target_keys = target_types + return return_dict + + if _has_pyg: from torch_geometric.data import Batch, Data From fc7437901fe7eb34104a91fce23d5af4b140be0c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 5 Aug 2024 16:50:49 -0700 Subject: [PATCH 2/4] feat: adding mptraj to datasets namespace Signed-off-by: Lee, Kin Long Kelvin --- matsciml/datasets/__init__.py | 6 +++++- matsciml/datasets/materials_project/__init__.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/matsciml/datasets/__init__.py b/matsciml/datasets/__init__.py index 8fba3281..3b4e87a3 100644 --- a/matsciml/datasets/__init__.py +++ b/matsciml/datasets/__init__.py @@ -11,7 +11,10 @@ from matsciml.datasets.carolina_db import CMDataset from matsciml.datasets.colabfit import ColabFitDataset from matsciml.datasets.lips import LiPSDataset -from matsciml.datasets.materials_project import MaterialsProjectDataset +from matsciml.datasets.materials_project import ( + MaterialsProjectDataset, + MaterialsTrajectoryDataset, +) from matsciml.datasets.multi_dataset import MultiDataset from matsciml.datasets.nomad import NomadDataset from matsciml.datasets.ocp_datasets import IS2REDataset, S2EFDataset @@ -26,6 +29,7 @@ "NomadDataset", "OQMDDataset", "MaterialsProjectDataset", + "MaterialsTrajectoryDataset", "LiPSDataset", "SyntheticPointGroupDataset", "MultiDataset", diff --git a/matsciml/datasets/materials_project/__init__.py b/matsciml/datasets/materials_project/__init__.py index f6f54594..c8e2e07c 100644 --- a/matsciml/datasets/materials_project/__init__.py +++ b/matsciml/datasets/materials_project/__init__.py @@ -3,7 +3,16 @@ from importlib.util import find_spec from matsciml.datasets.materials_project.api import MaterialsProjectRequest -from matsciml.datasets.materials_project.dataset import MaterialsProjectDataset +from matsciml.datasets.materials_project.dataset import ( + MaterialsProjectDataset, + MaterialsTrajectoryDataset, +) + +__all__ = [ + "MaterialsProjectDataset", + "MaterialsTrajectoryDataset", + "MaterialsProjectRequest", +] if find_spec("torch_geometric") is not None: from matsciml.datasets.materials_project.dataset import ( @@ -12,3 +21,6 @@ PyGMaterialsProjectDataset, ) + __all__.extend( + ["CdvaeLMDBDataset", "PyGCdvaeDataset", "PyGMaterialsProjectDataset"] + ) From b5e9bb6c0083a2ce729e9bb517728a39581c1597 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 5 Aug 2024 16:55:59 -0700 Subject: [PATCH 3/4] feat: added temporary exception handling for devset --- matsciml/datasets/materials_project/dataset.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/matsciml/datasets/materials_project/dataset.py b/matsciml/datasets/materials_project/dataset.py index 2aaf446d..9a30713d 100644 --- a/matsciml/datasets/materials_project/dataset.py +++ b/matsciml/datasets/materials_project/dataset.py @@ -414,6 +414,13 @@ def data_from_key( self.target_keys = target_types return return_dict + @classmethod + def from_devset(cls, transforms: list[Callable] | None = None, **kwargs): + """Override the devset call, at least temporarily.""" + raise NotImplementedError( + "Materials Trajectory does not have a devset available." + ) + if _has_pyg: from torch_geometric.data import Batch, Data From 99a90d9a523b81cf0f4e48cf8ced203293e03ed6 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Tue, 6 Aug 2024 07:36:27 -0700 Subject: [PATCH 4/4] refactor: added exception handling for missing mptraj keys --- matsciml/datasets/materials_project/dataset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/matsciml/datasets/materials_project/dataset.py b/matsciml/datasets/materials_project/dataset.py index 9a30713d..763a4308 100644 --- a/matsciml/datasets/materials_project/dataset.py +++ b/matsciml/datasets/materials_project/dataset.py @@ -404,6 +404,17 @@ def data_from_key( if isinstance(item, Iterable): item = torch.Tensor(item) targets[key] = item + # check for missing essential keys + missing_keys = [ + key + for key in ["force", "stress", "corrected_total_energy"] + if key not in targets + ] + if len(missing_keys) != 0: + raise KeyError( + f"Missing the following keys: {missing_keys} needed for MPTraj." + " Please ensure you're using the right dataset/LMDB files!" + ) # copy as nominal energy value if "energy" not in targets: targets["energy"] = targets["corrected_total_energy"]