Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore: make type_map complusory model attribute #3410

Merged
merged 62 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
063f8e7
feat: add zbl training
anyangml Mar 3, 2024
8f06ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
2f7fa77
fix: add atom bias
anyangml Mar 3, 2024
672563c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
312973a
Merge branch 'devel' into devel
anyangml Mar 3, 2024
cf66829
Merge branch 'devel' into devel
anyangml Mar 3, 2024
52ab95f
chore: refactor
anyangml Mar 3, 2024
993efe9
fix: add pairtab stat
anyangml Mar 3, 2024
897f9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
e8320b6
Merge branch 'devel' into devel
anyangml Mar 3, 2024
701cb55
fix: add UTs
anyangml Mar 3, 2024
e27a816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
dc30bbd
fix: add UT input
anyangml Mar 3, 2024
a232cf3
fix: UTs
anyangml Mar 3, 2024
d9856e7
Merge branch 'devel' into devel
anyangml Mar 3, 2024
004b63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
ca99701
fix: UTs
anyangml Mar 3, 2024
162fc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
9c25175
fix: UTs
anyangml Mar 3, 2024
8fc3a70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
84fb816
chore: merge conflict
anyangml Mar 3, 2024
55e2b7f
fix: update numpy shape
anyangml Mar 3, 2024
0b9f7ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
6524694
fix: UTs
anyangml Mar 3, 2024
e3d9a7b
feat: add UTs
anyangml Mar 3, 2024
e648ab4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
7143aa9
Merge branch 'devel' into devel
anyangml Mar 4, 2024
6ed8fde
fix: UTs
anyangml Mar 4, 2024
aadddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f36988d
fix: UTs
anyangml Mar 4, 2024
9c9cbbe
feat: update UTs
anyangml Mar 4, 2024
d2adebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
7071608
Merge branch 'devel' into devel
anyangml Mar 4, 2024
00c877c
fix: UTs
anyangml Mar 4, 2024
eb36de2
Merge branch 'devel' into devel
anyangml Mar 4, 2024
5de7214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
8b35fa4
rix: revert abstract method
anyangml Mar 4, 2024
bbc7ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
c384b3b
fix: UTs
anyangml Mar 4, 2024
1d5fad0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
dc407e3
chore: refactor
anyangml Mar 4, 2024
a9f65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
18a4897
fix: precommit
anyangml Mar 4, 2024
94bea6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
09f9352
fix: precommit
anyangml Mar 4, 2024
a63089d
fix: UTs
anyangml Mar 4, 2024
bda547d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
e6be71b
fix: UTs
anyangml Mar 4, 2024
3482ef2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
a2afe7c
Merge branch 'devel' into devel
anyangml Mar 4, 2024
a20cecf
fix: remove optional
anyangml Mar 4, 2024
cbb0949
chore: revert all changes
anyangml Mar 4, 2024
c6f75c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
96d3f05
chore: pairtab typemap
anyangml Mar 5, 2024
b76545d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
57e6894
fix: UTs
anyangml Mar 5, 2024
5ac704f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
bb06219
fix: UTs
anyangml Mar 5, 2024
2b4c7ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
768a683
Merge branch 'devel' into fix/typemap
anyangml Mar 5, 2024
961ce6b
fix: UTs
anyangml Mar 5, 2024
940473d
chore: remove print
anyangml Mar 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]] = None,
type_map: List[str],
**kwargs,
):
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(**kwargs)

def fitting_output_def(self) -> FittingOutputDef:
Expand All @@ -65,7 +66,7 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> Optional[List[str]]:
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

Expand Down Expand Up @@ -154,7 +155,7 @@ def deserialize(cls, data) -> "DPAtomicModel":
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor"))
fitting_obj = BaseFitting.deserialize(data.pop("fitting"))
type_map = data.pop("type_map", None)
type_map = data.pop("type_map")
obj = cls(descriptor_obj, fitting_obj, type_map=type_map, **data)
return obj

Expand Down
54 changes: 44 additions & 10 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,28 @@
----------
models : list[DPAtomicModel or PairTabAtomicModel]
A list of models to be combined. PairTabAtomicModel must be used together with a DPAtomicModel.
type_map : list[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
"""

def __init__(
self,
models: List[BaseAtomicModel],
type_map: List[str],
**kwargs,
):
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
common_type_map = set(type_map)
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(

Check warning on line 65 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L65

Added line #L65 was not covered by tests
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
assert len(err_msg) == 0, "\n".join(err_msg)
self.type_map = type_map
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

Expand All @@ -72,9 +86,9 @@
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

def get_type_map(self) -> Optional[List[str]]:
def get_type_map(self) -> List[str]:
"""Get the type map."""
raise NotImplementedError("TODO: get_type_map should be implemented")
raise self.type_map

Check warning on line 91 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L91

Added line #L91 was not covered by tests

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
Expand Down Expand Up @@ -184,27 +198,29 @@
)

@staticmethod
def serialize(models) -> dict:
def serialize(models, type_map) -> dict:
return {
"@class": "Model",
"type": "linear",
"@version": 1,
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
type_map = data["type_map"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
]
return models
return models, type_map

@abstractmethod
def _compute_weight(
Expand Down Expand Up @@ -250,8 +266,20 @@

Parameters
----------
models
This linear model should take a DPAtomicModel and a PairTable model.
dp_model
The DPAtomicModel being combined.
zbl_model
The PairTable model being combined.
sw_rmin
The lower boundary of the interpolation between short-range tabulated interaction and DP.
sw_rmax
The upper boundary of the interpolation between short-range tabulated interaction and DP.
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
smin_alpha
The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor.
This distance is calculated by softmin.
"""

def __init__(
Expand All @@ -260,11 +288,12 @@
zbl_model: PairTabAtomicModel,
sw_rmin: float,
sw_rmax: float,
type_map: List[str],
smin_alpha: Optional[float] = 0.1,
**kwargs,
):
models = [dp_model, zbl_model]
super().__init__(models, **kwargs)
super().__init__(models, type_map, **kwargs)
self.dp_model = dp_model
self.zbl_model = zbl_model

Expand All @@ -279,7 +308,9 @@
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"models": LinearAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -297,13 +328,16 @@
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))
([dp_model, zbl_model], type_map) = LinearAtomicModel.deserialize(
data.pop("models")
)

return cls(
dp_model=dp_model,
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
type_map=type_map,
smin_alpha=smin_alpha,
**data,
)
Expand Down
19 changes: 15 additions & 4 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,23 @@ class PairTabAtomicModel(BaseAtomicModel):
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
type_map : list[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
"""

def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
self,
tab_file: str,
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
**kwargs,
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map

self.tab = PairTab(self.tab_file, rcut=rcut)

Expand Down Expand Up @@ -86,8 +95,8 @@ def fitting_output_def(self) -> FittingOutputDef:
def get_rcut(self) -> float:
return self.rcut

def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: get_type_map should be implemented")
def get_type_map(self) -> List[str]:
return self.type_map

def get_sel(self) -> List[int]:
return [self.sel]
Expand Down Expand Up @@ -118,6 +127,7 @@ def serialize(self) -> dict:
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
"type_map": self.type_map,
}
)
return dd
Expand All @@ -130,8 +140,9 @@ def deserialize(cls, data) -> "PairTabAtomicModel":
data.pop("type")
rcut = data.pop("rcut")
sel = data.pop("sel")
type_map = data.pop("type_map")
Fixed Show fixed Hide fixed
tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)
tab_model = cls(None, rcut, sel, type_map, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]],
type_map: List[str],
**kwargs,
):
torch.nn.Module.__init__(self)
Expand Down
52 changes: 43 additions & 9 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,29 @@
----------
models : list[DPAtomicModel or PairTabAtomicModel]
A list of models to be combined. PairTabAtomicModel must be used together with a DPAtomicModel.
type_map : list[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
"""

def __init__(
self,
models: List[BaseAtomicModel],
type_map: List[str],
**kwargs,
):
torch.nn.Module.__init__(self)
self.models = torch.nn.ModuleList(models)
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
common_type_map = set(type_map)
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(

Check warning on line 68 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L68

Added line #L68 was not covered by tests
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
assert len(err_msg) == 0, "\n".join(err_msg)
self.type_map = type_map
self.atomic_bias = None
self.mixed_types_list = [model.mixed_types() for model in self.models]
BaseAtomicModel.__init__(self, **kwargs)
Expand All @@ -80,7 +94,7 @@
@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
raise NotImplementedError("TODO: implement this method")
return self.type_map

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
Expand Down Expand Up @@ -208,25 +222,27 @@
)

@staticmethod
def serialize(models) -> dict:
def serialize(models, type_map) -> dict:
return {
"@class": "Model",
"@version": 1,
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
"type_map": type_map,
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
def deserialize(data) -> Tuple[List[BaseAtomicModel], List[str]]:
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
model_names = data["model_name"]
type_map = data["type_map"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
for name, model in zip(model_names, data["models"])
]
return models
return models, type_map

@abstractmethod
def _compute_weight(
Expand Down Expand Up @@ -281,8 +297,20 @@

Parameters
----------
models
This linear model should take a DPAtomicModel and a PairTable model.
dp_model
The DPAtomicModel being combined.
zbl_model
The PairTable model being combined.
sw_rmin
The lower boundary of the interpolation between short-range tabulated interaction and DP.
sw_rmax
The upper boundary of the interpolation between short-range tabulated interaction and DP.
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
smin_alpha
The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor.
This distance is calculated by softmin.
"""

def __init__(
Expand All @@ -291,11 +319,12 @@
zbl_model: PairTabAtomicModel,
sw_rmin: float,
sw_rmax: float,
type_map: List[str],
smin_alpha: Optional[float] = 0.1,
**kwargs,
):
models = [dp_model, zbl_model]
super().__init__(models, **kwargs)
super().__init__(models, type_map, **kwargs)
self.model_def_script = ""
self.dp_model = dp_model
self.zbl_model = zbl_model
Expand All @@ -314,7 +343,9 @@
"@class": "Model",
"@version": 1,
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"models": LinearAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
"smin_alpha": self.smin_alpha,
Expand All @@ -330,7 +361,9 @@
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

dp_model, zbl_model = LinearAtomicModel.deserialize(data.pop("models"))
[dp_model, zbl_model], type_map = LinearAtomicModel.deserialize(
data.pop("models")
)

data.pop("@class", None)
data.pop("type", None)
Expand All @@ -339,6 +372,7 @@
zbl_model=zbl_model,
sw_rmin=sw_rmin,
sw_rmax=sw_rmax,
type_map=type_map,
smin_alpha=smin_alpha,
**data,
)
Expand Down
Loading