test(pt/dp): add universal uts for all models #3873

Merged
30 commits merged on Jun 28, 2024
Commits (30)
c93269f - test(pt/dp): add universal uts for all models (iProzd, Jun 13, 2024)
108cdee - update se_r (iProzd, Jun 13, 2024)
5a5f9e2 - fix ut (iProzd, Jun 13, 2024)
6f59a27 - add parametrize to models (iProzd, Jun 20, 2024)
fbce582 - Merge branch 'devel' into add_universal_ut (iProzd, Jun 20, 2024)
14ae51f - fix squeeze (iProzd, Jun 20, 2024)
b6afbb1 - Update se_r.py (iProzd, Jun 20, 2024)
5584432 - Update path.py (iProzd, Jun 20, 2024)
16a15c3 - Update model.py (iProzd, Jun 20, 2024)
2542651 - Update common.py (iProzd, Jun 20, 2024)
264718b - Update common.py (iProzd, Jun 20, 2024)
0098ac5 - fix spin aparam (iProzd, Jun 20, 2024)
569d526 - fix ut (iProzd, Jun 21, 2024)
4e71bba - split tests for descriptor and fitting (iProzd, Jun 21, 2024)
5e15857 - Merge branch 'devel' into add_universal_ut (iProzd, Jun 21, 2024)
a715178 - Merge branch 'devel' into add_universal_ut (iProzd, Jun 24, 2024)
faa62ed - Update common.py (iProzd, Jun 24, 2024)
50cffda - fix conversations (iProzd, Jun 24, 2024)
6c93a35 - Merge branch 'devel' into add_universal_ut (iProzd, Jun 24, 2024)
e6fcb58 - Update spin_model.py (iProzd, Jun 24, 2024)
733b9a4 - fix ut (iProzd, Jun 24, 2024)
481d08c - fix warnings (iProzd, Jun 25, 2024)
151edb6 - Update test_model.py (iProzd, Jun 26, 2024)
8495659 - skip uts on cpu and gpu (iProzd, Jun 26, 2024)
ba63f73 - fix device (iProzd, Jun 26, 2024)
860a01c - Update utils.py (iProzd, Jun 26, 2024)
0f06f17 - use CUDA_VISIBLE_DEVICES (iProzd, Jun 27, 2024)
87187cf - Merge branch 'devel' into add_universal_ut (iProzd, Jun 27, 2024)
28d32ba - Merge branch 'devel' into add_universal_ut (iProzd, Jun 27, 2024)
3838281 - Update test_cuda.yml (iProzd, Jun 27, 2024)
1 change: 1 addition & 0 deletions .github/workflows/test_cuda.yml
@@ -60,6 +60,7 @@ jobs:
- run: python -m pytest source/tests --durations=0
env:
NUM_WORKERS: 0
CUDA_VISIBLE_DEVICES: 0
- name: Download libtorch
run: |
wget https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.1%2Bcu121.zip -O libtorch.zip
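Pinning CUDA_VISIBLE_DEVICES to a single GPU keeps the CUDA tests from spilling onto a second device. A minimal sketch of how test code can pick its device under this variable (hypothetical helper, not part of this PR):

import os

import torch

def pick_test_device() -> torch.device:
    # CUDA_VISIBLE_DEVICES="" hides all GPUs; "0" exposes exactly one,
    # which the process then sees as cuda:0.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None and visible.strip() == "":
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")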
13 changes: 10 additions & 3 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -196,11 +196,11 @@ def forward_atomic(
]
ener_list = []
for i, model in enumerate(self.models):
mapping = self.mapping_list[i]
type_map_model = self.mapping_list[i]
ener_list.append(
model.forward_atomic(
extended_coord,
mapping[extended_atype],
type_map_model[extended_atype],
nlists_[i],
mapping,
fparam,
@@ -414,7 +414,12 @@ def _compute_weight(
)

numerator = np.sum(
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha), axis=-1
np.where(
nlist_larger != -1,
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
np.zeros_like(nlist_larger),
),
axis=-1,
) # masked nnei will be zero, no need to handle
denominator = np.sum(
np.where(
@@ -436,5 +441,7 @@
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0
# to handle masked atoms
coef = np.where(sigma != 0, coef, np.zeros_like(coef))
self.zbl_weight = coef
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
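The np.where guards above follow the usual padded-neighbor convention: entries with nlist == -1 are placeholders and must contribute zero to the reduction. A minimal sketch of the same masking pattern on hypothetical arrays (not data from this PR):

import numpy as np

# 1 frame, 2 local atoms, 3 neighbor slots; -1 marks padding.
nlist = np.array([[[0, 1, -1], [0, -1, -1]]])
pairwise_rr = np.array([[[1.0, 2.0, 9.9], [1.5, 9.9, 9.9]]])
smin_alpha = 0.1

# Zero out padded slots before summing, exactly as in the hunk above.
numerator = np.sum(
    np.where(nlist != -1, pairwise_rr * np.exp(-pairwise_rr / smin_alpha), 0.0),
    axis=-1,
)
print(numerator)  # padded slots contribute nothing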
2 changes: 2 additions & 0 deletions deepmd/dpmodel/common.py
@@ -24,6 +24,7 @@
"double": np.float64,
"int32": np.int32,
"int64": np.int64,
"bool": bool,
"default": GLOBAL_NP_FLOAT_PRECISION,
# NumPy doesn't have bfloat16 (and doesn't plan to add)
# ml_dtypes is a solution, but it seems not supporting np.save/np.load
@@ -39,6 +40,7 @@
np.int32: "int32",
np.int64: "int64",
ml_dtypes.bfloat16: "bfloat16",
bool: "bool",
}
assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
DEFAULT_PRECISION = "float64"
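PRECISION_DICT and RESERVED_PRECISON_DICT are kept mutually inverse (the assert just below the hunk enforces it), so registering "bool" in one requires the matching entry in the other. A trimmed-down sketch of the round trip, using stand-in dictionaries rather than the real module:

import numpy as np

PRECISION_DICT = {"float64": np.float64, "bool": bool}
RESERVED_PRECISON_DICT = {np.float64: "float64", bool: "bool"}

dtype = PRECISION_DICT["bool"]                   # <class 'bool'>
assert RESERVED_PRECISON_DICT[dtype] == "bool"   # lossless round trip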
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
@@ -61,6 +61,7 @@


def np_softmax(x, axis=-1):
x = np.nan_to_num(x) # to avoid value warning
e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return e_x / np.sum(e_x, axis=axis, keepdims=True)

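np.nan_to_num here silences the invalid-value RuntimeWarning that np.exp/np.max would emit when a row is fully masked and arrives as NaN. A small usage sketch, assuming the function exactly as shown above:

import numpy as np

def np_softmax(x, axis=-1):
    x = np.nan_to_num(x)  # fully masked rows become finite before exp/max
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

row = np.array([[np.nan, np.nan, np.nan]])  # hypothetical fully masked row
print(np_softmax(row))  # uniform weights, and no RuntimeWarning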
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/__init__.py
@@ -15,6 +15,9 @@
from .dp_model import (
DPModelCommon,
)
from .ener_model import (
EnergyModel,
)
from .make_model import (
make_model,
)
@@ -23,6 +26,7 @@
)

__all__ = [
"EnergyModel",
"DPModelCommon",
"SpinModel",
"make_model",
75 changes: 66 additions & 9 deletions deepmd/dpmodel/model/spin_model.py
@@ -10,15 +10,21 @@
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.model.make_model import (
make_model,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
from deepmd.utils.spin import (
Spin,
)


class SpinModel:
class SpinModel(NativeOP):
"""A spin model wrapper, with spin input preprocess and output split."""

def __init__(
@@ -152,15 +158,20 @@
nlist_shift = nlist + nall
nlist[~nlist_mask] = -1
nlist_shift[~nlist_mask] = -1
self_spin = np.arange(0, nloc, dtype=nlist.dtype) + nall
self_spin = self_spin.reshape(1, -1, 1).repeat(nframes, axis=0)
# self spin + real neighbor + virtual neighbor
self_real = (
np.arange(0, nloc, dtype=nlist.dtype)
.reshape(1, -1, 1)
.repeat(nframes, axis=0)
)
self_spin = self_real + nall
# real atom's neighbors: self spin + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
real_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1)
# spin atom's neighbors: real + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
extended_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1)
spin_nlist = np.concatenate([self_real, nlist, nlist_shift], axis=-1)
# nf x (nloc + nloc) x (1 + nnei + nnei)
extended_nlist = np.concatenate(
[extended_nlist, -1 * np.ones_like(extended_nlist)], axis=-2
)
extended_nlist = np.concatenate([real_nlist, spin_nlist], axis=-2)
# update the index for switch
first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall)
second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
@@ -187,12 +198,40 @@
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
return extended_tensor_updated.reshape(out_shape)

@staticmethod
def expand_aparam(aparam, nloc: int):
"""Expand the atom parameters for virtual atoms if necessary."""
nframes, natom, numb_aparam = aparam.shape
if natom == nloc: # good
pass

Codecov warning: added line deepmd/dpmodel/model/spin_model.py#L206 was not covered by tests.
elif natom < nloc: # for spin with virtual atoms
aparam = np.concatenate(
[
aparam,
np.zeros(
[nframes, nloc - natom, numb_aparam],
dtype=aparam.dtype,
),
],
axis=1,
)
else:
raise ValueError(

Codecov warning: added line deepmd/dpmodel/model/spin_model.py#L219 was not covered by tests.
f"get an input aparam with {aparam.shape[1]} inputs, ",
f"which is larger than {nloc} atoms.",
)
return aparam

def get_type_map(self) -> List[str]:
"""Get the type map."""
tmap = self.backbone_model.get_type_map()
ntypes = len(tmap) // 2 # ignore the virtual type
return tmap[:ntypes]

def get_ntypes(self):
"""Returns the number of element types."""
return len(self.get_type_map())

def get_rcut(self):
"""Get the cut-off radius."""
return self.backbone_model.get_rcut()
Expand Down Expand Up @@ -251,6 +290,16 @@
"""Returns whether it has spin input and output."""
return True

def model_output_def(self):
"""Get the output def for the model."""
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
var_name = model_output_type[0]
backbone_model_atomic_output_def = self.backbone_model.atomic_output_def()
backbone_model_atomic_output_def[var_name].magnetic = True
return ModelOutputDef(backbone_model_atomic_output_def)

def __getattr__(self, name):
"""Get attribute from the wrapped model."""
if name in self.__dict__:
@@ -313,8 +362,12 @@
The keys are defined by the `ModelOutputDef`.

"""
nframes, nloc = coord.shape[:2]
nframes, nloc = atype.shape[:2]
coord = coord.reshape(nframes, nloc, 3)
spin = spin.reshape(nframes, nloc, 3)
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
if aparam is not None:
aparam = self.expand_aparam(aparam, nloc * 2)
model_predict = self.backbone_model.call(
coord_updated,
atype_updated,
@@ -383,6 +436,8 @@
) = self.process_spin_input_lower(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
if aparam is not None:
aparam = self.expand_aparam(aparam, nloc * 2)
model_predict = self.backbone_model.call_lower(
extended_coord_updated,
extended_atype_updated,
Expand All @@ -401,3 +456,5 @@
)[0]
# for now omit the grad output
return model_predict

forward_lower = call_lower
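For spin systems the backbone model sees twice as many atoms (each real atom plus a virtual spin atom), so per-atom parameters provided for the nloc real atoms are zero-padded before the backbone call. A minimal sketch of the padding done by expand_aparam, on hypothetical shapes:

import numpy as np

aparam = np.ones([1, 2, 1])      # 1 frame, 2 real atoms, 1 parameter each
nloc_with_virtual = 4            # real atoms + virtual spin atoms

pad = np.zeros([1, nloc_with_virtual - aparam.shape[1], 1], dtype=aparam.dtype)
aparam_expanded = np.concatenate([aparam, pad], axis=1)
assert aparam_expanded.shape == (1, 4, 1)  # virtual atoms carry zeros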
16 changes: 15 additions & 1 deletion deepmd/dpmodel/output_def.py
@@ -228,6 +228,11 @@ def __init__(
def size(self):
return self.output_size

def squeeze(self, dim):
# squeeze the shape on given dimension
if -len(self.shape) <= dim < len(self.shape) and self.shape[dim] == 1:
self.shape.pop(dim)


class FittingOutputDef:
"""Defines the shapes and other properties of the fitting network outputs.
@@ -306,7 +311,6 @@ def __getitem__(

def get_data(
self,
key: str,
) -> Dict[str, OutputVariableDef]:
return self.var_defs

@@ -402,6 +406,16 @@ def check_operation_applied(
return var_def.category & op.value == op.value


def check_deriv(var_def: OutputVariableDef) -> bool:
"""Check if a variable is obtained by derivative."""
deriv = (
check_operation_applied(var_def, OutputVariableOperation.DERV_R)
or check_operation_applied(var_def, OutputVariableOperation._SEC_DERV_R)
or check_operation_applied(var_def, OutputVariableOperation.DERV_C)
)
return deriv


def do_reduce(
def_outp_data: Dict[str, OutputVariableDef],
) -> Dict[str, OutputVariableDef]:
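check_deriv simply ORs three bit-flag tests on a variable's category. A minimal sketch of the same idea on a plain IntFlag bitmask (hypothetical flag values, not the ones defined in output_def.py):

from enum import IntFlag

class Op(IntFlag):  # stand-in for OutputVariableOperation
    DERV_R = 1
    DERV_C = 2
    _SEC_DERV_R = 4

def check_deriv(category: int) -> bool:
    # True if any derivative flag is set in the category bitmask.
    return bool(category & (Op.DERV_R | Op.DERV_C | Op._SEC_DERV_R))

assert check_deriv(Op.DERV_R | Op.DERV_C)
assert not check_deriv(0)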
9 changes: 6 additions & 3 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -224,12 +224,12 @@ def forward_atomic(
ener_list = []

for i, model in enumerate(self.models):
mapping = self.mapping_list[i]
type_map_model = self.mapping_list[i].to(extended_atype.device)
# apply bias to each individual model
ener_list.append(
model.forward_common_atomic(
extended_coord,
mapping[extended_atype],
type_map_model[extended_atype],
nlists_[i],
mapping,
fparam,
@@ -239,7 +239,10 @@
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

fit_ret = {
"energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(weights).to(extended_atype.device),
dim=0,
),
} # (nframes, nloc, 1)
return fit_ret

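The added .to(extended_atype.device) calls fix a device mismatch: the mapping buffers are built on CPU at construction time while the batch may arrive on GPU. A minimal sketch of the failure mode and the fix, with hypothetical tensors:

import torch

mapping = torch.arange(4)                 # lookup table created on CPU
atype = torch.zeros(3, dtype=torch.long)  # batch input, possibly on cuda:N
if torch.cuda.is_available():
    atype = atype.cuda()

# mapping[atype] raises a device-mismatch error when atype lives on the GPU;
# moving the table to the input's device first makes the lookup safe.
remapped = mapping.to(atype.device)[atype]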
3 changes: 2 additions & 1 deletion deepmd/pt/model/atomic_model/polar_atomic_model.py
@@ -49,7 +49,8 @@ def apply_out_stat(

# (nframes, nloc, 1)
modified_bias = (
modified_bias.unsqueeze(-1) * self.fitting_net.scale[atype]
modified_bias.unsqueeze(-1)
* (self.fitting_net.scale.to(atype.device))[atype]
)

eye = torch.eye(3, dtype=dtype, device=device)
8 changes: 6 additions & 2 deletions deepmd/pt/model/descriptor/hybrid.py
@@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, torch.nn.Module):
The descriptor can be either an object or a dictionary.
"""

nlist_cut_idx: List[torch.Tensor]

def __init__(
self,
list: List[Union[BaseDescriptor, Dict[str, Any]]],
@@ -278,11 +280,13 @@ def forward(
for ii, descrpt in enumerate(self.descrpt_list):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = nlist[:, :, self.nlist_cut_idx[ii]]
nl = nlist[:, :, self.nlist_cut_idx[ii].to(atype_ext.device)]
else:
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
nl = nl_distinguish_types[:, :, self.nlist_cut_idx[ii]]
nl = nl_distinguish_types[
:, :, self.nlist_cut_idx[ii].to(atype_ext.device)
]
odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping)
out_descriptor.append(odescriptor)
if gr is not None:
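The new class-level annotation nlist_cut_idx: List[torch.Tensor] gives TorchScript the element type it cannot infer for a plain Python list attribute. A minimal sketch of the pattern on a hypothetical module (not DescrptHybrid itself):

from typing import List

import torch

class SliceNlist(torch.nn.Module):
    # TorchScript needs the explicit element type for list attributes.
    idx_list: List[torch.Tensor]

    def __init__(self) -> None:
        super().__init__()
        self.idx_list = [torch.tensor([0, 2]), torch.tensor([1])]

    def forward(self, nlist: torch.Tensor) -> torch.Tensor:
        # move the stored indices to the input's device before slicing
        return nlist[:, :, self.idx_list[0].to(nlist.device)]

scripted = torch.jit.script(SliceNlist())  # compiles thanks to the annotation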
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
@@ -308,6 +308,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -321,6 +322,8 @@
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
comm_dict
The data needed for communication for parallel inference.

Returns
-------
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/__init__.py
@@ -210,6 +210,9 @@ def get_model(model_params):
"get_model",
"DPModelCommon",
"EnergyModel",
"DipoleModel",
"PolarModel",
"DOSModel",
"FrozenModel",
"SpinModel",
"SpinEnergyModel",