Chore: make type_map compulsory model attribute #3410

Merged Mar 5, 2024 · 62 commits (changes shown from 51 commits)

Commits
063f8e7
feat: add zbl training
anyangml Mar 3, 2024
8f06ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
2f7fa77
fix: add atom bias
anyangml Mar 3, 2024
672563c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
312973a
Merge branch 'devel' into devel
anyangml Mar 3, 2024
cf66829
Merge branch 'devel' into devel
anyangml Mar 3, 2024
52ab95f
chore: refactor
anyangml Mar 3, 2024
993efe9
fix: add pairtab stat
anyangml Mar 3, 2024
897f9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
e8320b6
Merge branch 'devel' into devel
anyangml Mar 3, 2024
701cb55
fix: add UTs
anyangml Mar 3, 2024
e27a816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
dc30bbd
fix: add UT input
anyangml Mar 3, 2024
a232cf3
fix: UTs
anyangml Mar 3, 2024
d9856e7
Merge branch 'devel' into devel
anyangml Mar 3, 2024
004b63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
ca99701
fix: UTs
anyangml Mar 3, 2024
162fc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
9c25175
fix: UTs
anyangml Mar 3, 2024
8fc3a70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
84fb816
chore: merge conflict
anyangml Mar 3, 2024
55e2b7f
fix: update numpy shape
anyangml Mar 3, 2024
0b9f7ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
6524694
fix: UTs
anyangml Mar 3, 2024
e3d9a7b
feat: add UTs
anyangml Mar 3, 2024
e648ab4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
7143aa9
Merge branch 'devel' into devel
anyangml Mar 4, 2024
6ed8fde
fix: UTs
anyangml Mar 4, 2024
aadddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f36988d
fix: UTs
anyangml Mar 4, 2024
9c9cbbe
feat: update UTs
anyangml Mar 4, 2024
d2adebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
7071608
Merge branch 'devel' into devel
anyangml Mar 4, 2024
00c877c
fix: UTs
anyangml Mar 4, 2024
eb36de2
Merge branch 'devel' into devel
anyangml Mar 4, 2024
5de7214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
8b35fa4
rix: revert abstract method
anyangml Mar 4, 2024
bbc7ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
c384b3b
fix: UTs
anyangml Mar 4, 2024
1d5fad0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
dc407e3
chore: refactor
anyangml Mar 4, 2024
a9f65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
18a4897
fix: precommit
anyangml Mar 4, 2024
94bea6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
09f9352
fix: precommit
anyangml Mar 4, 2024
a63089d
fix: UTs
anyangml Mar 4, 2024
bda547d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
e6be71b
fix: UTs
anyangml Mar 4, 2024
3482ef2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
a2afe7c
Merge branch 'devel' into devel
anyangml Mar 4, 2024
a20cecf
fix: remove optional
anyangml Mar 4, 2024
cbb0949
chore: revert all changes
anyangml Mar 4, 2024
c6f75c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
96d3f05
chore: pairtab typemap
anyangml Mar 5, 2024
b76545d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
57e6894
fix: UTs
anyangml Mar 5, 2024
5ac704f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
bb06219
fix: UTs
anyangml Mar 5, 2024
2b4c7ce
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
768a683
Merge branch 'devel' into fix/typemap
anyangml Mar 5, 2024
961ce6b
fix: UTs
anyangml Mar 5, 2024
940473d
chore: remove print
anyangml Mar 5, 2024
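
Before the file-level diffs, a minimal, self-contained Python sketch of the user-facing change this PR makes: `type_map` becomes a required constructor argument of the pair-table atomic model and `ntypes` is derived from it. The class below is an illustrative stand-in, not the deepmd-kit implementation; only the argument names mirror the diffs that follow.

from typing import List, Union


class PairTabModelSketch:
    def __init__(
        self,
        tab_file: str,
        rcut: float,
        sel: Union[int, List[int]],
        type_map: List[str],  # now compulsory: no default, no Optional
        **kwargs,
    ):
        self.tab_file = tab_file
        self.rcut = rcut
        self.sel = sel
        self.type_map = type_map
        self.ntypes = len(type_map)  # number of atom types follows from the map

    def get_type_map(self) -> List[str]:
        # previously a NotImplementedError stub; now it simply returns the stored map
        return self.type_map


model = PairTabModelSketch("table.txt", rcut=6.0, sel=120, type_map=["O", "H"])
print(model.get_type_map(), model.ntypes)  # ['O', 'H'] 2
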
25 changes: 22 additions & 3 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -47,19 +47,35 @@
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
type_map: List[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
"""

def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
self,
tab_file: str,
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
**kwargs,
):
super().__init__()
self.tab_file = tab_file
self.rcut = rcut

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map
self.ntypes = len(type_map)

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
)
else:
self.tab_info, self.tab_data = None, None

@@ -118,6 +134,7 @@
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
"type_map": self.type_map,
}
)
return dd
@@ -130,11 +147,13 @@
data.pop("type")
rcut = data.pop("rcut")
sel = data.pop("sel")
type_map = data.pop("type_map")

tab = PairTab.deserialize(data.pop("tab"))
tab_model = cls(None, rcut, sel, **data)
tab_model = cls(None, rcut, sel, type_map, **data)

tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
nspline, ntypes = tab_model.tab_info[-2:].astype(int)
tab_model.tab_data = tab_model.tab.tab_data.reshape(ntypes, ntypes, nspline, 4)

return tab_model

def forward_atomic(
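
The constructor change above also validates the user-supplied `type_map` against the pair table and reshapes the flat spline coefficients. A standalone numpy sketch of that check, assuming (as the diff suggests) that the last two entries of `tab_info` hold the number of spline points and the number of tabulated types; `check_and_reshape` is an illustrative helper, not a deepmd function.

import numpy as np


def check_and_reshape(tab_info: np.ndarray, tab_data: np.ndarray, type_map: list) -> np.ndarray:
    nspline, ntypes_tab = tab_info[-2:].astype(int)
    if len(type_map) != ntypes_tab:
        raise ValueError(
            "The `type_map` provided does not match the number of columns in the table."
        )
    # 4 cubic-spline coefficients per (type_i, type_j, spline segment)
    return tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)


# two types and three spline segments -> 2 * 2 * 3 * 4 coefficients
info = np.array([0.02, 6.0, 3, 2], dtype=np.float64)
data = np.arange(2 * 2 * 3 * 4, dtype=np.float64)
print(check_and_reshape(info, data, ["O", "H"]).shape)  # (2, 2, 3, 4)
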
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -50,7 +50,7 @@ def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]],
type_map: List[str],
**kwargs,
):
torch.nn.Module.__init__(self)
65 changes: 53 additions & 12 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -25,6 +25,9 @@
get_multiple_nlist_key,
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -77,10 +80,9 @@
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

@torch.jit.export
@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""
raise NotImplementedError("TODO: implement this method")

def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
@@ -103,8 +105,8 @@
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
rcuts,
nsels,
],
dim=0,
).T
@@ -184,14 +186,20 @@

weights = self._compute_weight(extended_coord, extended_atype, nlists_)

if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")
else:
fit_ret = {
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(weights), dim=0
),
} # (nframes, nloc, 1)
atype = extended_atype[:, :nloc]
for idx, model in enumerate(self.models):
if isinstance(model, DPAtomicModel):
bias_atom_e = model.fitting_net.bias_atom_e
elif isinstance(model, PairTabAtomicModel):
bias_atom_e = model.bias_atom_e

Check warning on line 194 in deepmd/pt/model/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/atomic_model/linear_atomic_model.py#L189-L194

Added lines #L189 - L194 were not covered by tests
else:
bias_atom_e = None
if bias_atom_e is not None:
ener_list[idx] += bias_atom_e[atype]

fit_ret = {
} # (nframes, nloc, 1)
return fit_ret

def fitting_output_def(self) -> FittingOutputDef:
@@ -307,6 +315,39 @@
# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

@torch.jit.export
def get_type_map(self) -> List[str]:
dp_map = self.dp_model.get_type_map()
zbl_map = self.zbl_model.get_type_map()
return self.dp_model.get_type_map()

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
"""
self.dp_model.compute_or_load_stat(sampled_func, stat_file_path)
self.zbl_model.compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
# need to implement
pass

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
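
The new `forward_atomic` logic above folds each sub-model's per-type energy bias into its atomic energies before the weighted sum over models. A standalone PyTorch sketch of that gather-and-sum pattern; all shapes and names here are illustrative, not the deepmd-kit implementation.

import torch

nframes, nloc, nmodels, ntypes = 2, 5, 2, 3
atype = torch.randint(0, ntypes, (nframes, nloc))  # local atom types
ener_list = [torch.randn(nframes, nloc, 1) for _ in range(nmodels)]
weights = [torch.full((nframes, nloc, 1), 1.0 / nmodels) for _ in range(nmodels)]
biases = [torch.randn(ntypes, 1), None]  # one model with a per-type bias, one without, to exercise both branches

for idx, bias_atom_e in enumerate(biases):
    if bias_atom_e is not None:
        # bias_atom_e[atype] gathers a per-atom bias and broadcasts to (nframes, nloc, 1)
        ener_list[idx] = ener_list[idx] + bias_atom_e[atype]

energy = torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0)
print(energy.shape)  # torch.Size([2, 5, 1])
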
96 changes: 90 additions & 6 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Callable,
Dict,
List,
Optional,
@@ -13,9 +14,18 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
from deepmd.utils.pair_tab import (
PairTab,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -47,29 +57,59 @@
The cutoff radius.
sel : int or list[int]
The maxmum number of atoms in the cut-off radius.
type_map: List[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
rcond : float, optional
The condition number for the regression of atomic energy.
atom_ener
Specifying atomic energy contribution in vacuum. The `set_davg_zero` key in the descrptor should be set.
"""

def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
self,
tab_file: str,
rcut: float,
sel: Union[int, List[int]],
type_map: List[str],
rcond: Optional[float] = None,
atom_ener: Optional[List[float]] = None,
**kwargs,
):
torch.nn.Module.__init__(self)
self.model_def_script = ""
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)

BaseAtomicModel.__init__(self, **kwargs)
self.rcond = rcond
self.atom_ener = atom_ener
self.type_map = type_map
self.ntypes = len(type_map)

# handle deserialization with no input file
if self.tab_file is not None:
(
tab_info,
tab_data,
) = self.tab.get() # this returns -> Tuple[np.array, np.array]
nspline, ntypes_tab = tab_info[-2:].astype(int)

self.register_buffer("tab_info", torch.from_numpy(tab_info))
self.register_buffer("tab_data", torch.from_numpy(tab_data))
self.register_buffer(
"tab_data",
torch.from_numpy(tab_data).reshape(ntypes_tab, ntypes_tab, nspline, 4),
)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
)
else:
self.register_buffer("tab_info", None)
self.register_buffer("tab_data", None)
self.bias_atom_e = torch.zeros(
self.ntypes, 1, dtype=env.GLOBAL_PT_ENER_FLOAT_PRECISION, device=env.DEVICE
)

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class
@@ -103,8 +143,8 @@
return self.rcut

@torch.jit.export
def get_type_map(self) -> Optional[List[str]]:
raise NotImplementedError("TODO: implement this method")
def get_type_map(self) -> List[str]:
return self.type_map

def get_sel(self) -> List[int]:
return [self.sel]
@@ -135,6 +175,9 @@
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
"type_map": self.type_map,
"rcond": self.rcond,
"atom_ener": self.atom_ener,
}
)
return dd
@@ -145,15 +188,56 @@
check_version_compatibility(data.pop("@version", 1), 1, 1)
rcut = data.pop("rcut")
sel = data.pop("sel")
type_map = data.pop("type_map")
rcond = data.pop("rcond")
atom_ener = data.pop("atom_ener")

tab = PairTab.deserialize(data.pop("tab"))
data.pop("@class", None)
data.pop("type", None)
tab_model = cls(None, rcut, sel, **data)
tab_model = cls(None, rcut, sel, type_map, rcond, atom_ener, **data)

tab_model.tab = tab
tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info))
tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data))
nspline, ntypes = tab_model.tab.tab_info[-2:].astype(int)
tab_model.register_buffer(
"tab_data",
torch.from_numpy(tab_model.tab.tab_data).reshape(
ntypes, ntypes, nspline, 4
),
)
return tab_model

def compute_or_load_stat(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
):
"""
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
bias_atom_e = compute_output_stats(
merged, stat_file_path, self.rcond, self.atom_ener
)
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view([self.ntypes, 1])
)

def change_energy_bias(self) -> None:
# need to implement
pass

def forward_atomic(
self,
extended_coord: torch.Tensor,
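
The PyTorch pair-table model above now owns a zero-initialized `bias_atom_e` buffer and fills it during `compute_or_load_stat`. A minimal PyTorch sketch of that register-then-copy pattern; `fake_output_stats` is a hypothetical stand-in for deepmd's `compute_output_stats`, and the class is illustrative only.

import torch


def fake_output_stats(frames, ntypes):
    # placeholder: a real implementation would fit per-type energy biases from the data
    return [float(i) for i in range(ntypes)]


class PairTabStatSketch(torch.nn.Module):
    def __init__(self, ntypes: int):
        super().__init__()
        self.ntypes = ntypes
        # a buffer (not a parameter): saved with the state dict, excluded from gradients
        self.register_buffer("bias_atom_e", torch.zeros(ntypes, 1, dtype=torch.float64))

    def compute_or_load_stat(self, sampled_frames):
        bias = fake_output_stats(sampled_frames, self.ntypes)
        # in-place copy keeps the registered buffer object (and any JIT references to it) intact
        self.bias_atom_e.copy_(torch.as_tensor(bias, dtype=torch.float64).view(self.ntypes, 1))


m = PairTabStatSketch(ntypes=2)
m.compute_or_load_stat(sampled_frames=[])
print(m.bias_atom_e.squeeze(-1))  # tensor([0., 1.], dtype=torch.float64)
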
5 changes: 4 additions & 1 deletion deepmd/pt/model/model/__init__.py
@@ -66,7 +66,10 @@ def get_zbl_model(model_params):
# pairtab
filepath = model_params["use_srtab"]
pt_model = PairTabAtomicModel(
filepath, model_params["descriptor"]["rcut"], model_params["descriptor"]["sel"]
filepath,
model_params["descriptor"]["rcut"],
model_params["descriptor"]["sel"],
type_map=model_params["type_map"],
)

rmin = model_params["sw_rmin"]
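
The call-site change above threads the top-level `type_map` from the model parameters into the pair-table model inside `get_zbl_model`. A hedged sketch of the keyword arguments involved; the values below are made up for illustration and only the dictionary keys mirror the diff.

model_params = {
    "use_srtab": "H2O_tab_potential.txt",  # hypothetical table file
    "type_map": ["O", "H"],
    "descriptor": {"rcut": 6.0, "sel": [46, 92]},
}

pairtab_kwargs = dict(
    tab_file=model_params["use_srtab"],
    rcut=model_params["descriptor"]["rcut"],
    sel=model_params["descriptor"]["sel"],
    type_map=model_params["type_map"],  # previously omitted, now required
)
print(pairtab_kwargs)
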