Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tf/pt): add/refact lammps & C++ support for spin model #4321

Merged
merged 100 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from 97 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
f34cbe1
feat(pt/tf): support spin lammps plugin
iProzd Sep 21, 2024
d5b544b
update typo
iProzd Sep 21, 2024
dd331fd
update pt backend
iProzd Sep 22, 2024
31bafb1
rm extend from pair-deepmd
iProzd Sep 22, 2024
15150f6
fix tf interface for spin
hztttt Sep 23, 2024
bdfe205
fix interface for multi model
hztttt Sep 23, 2024
be59313
support spin_norm & virtual_len in model graph and fix bug
hztttt Sep 25, 2024
ec7c16b
fix pt
iProzd Sep 28, 2024
6524e5e
Update pair_deepmd.cpp
iProzd Sep 28, 2024
2c66443
fix tensorflow bug
iProzd Oct 14, 2024
4f3d9d4
fix mag force bug
iProzd Oct 14, 2024
d24d7e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
593bf81
Update c_api.h
iProzd Oct 15, 2024
3466e34
Update c_api.h
iProzd Oct 15, 2024
c3a4f3e
extend sendlist nlist and other tensors but still bugs
CaRoLZhangxy Oct 18, 2024
e2e1e55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
cf85275
revert `extend sendlist nlist`
iProzd Oct 21, 2024
1d6defe
fix spin communication in lammps
iProzd Oct 21, 2024
2a38025
Merge branch 'devel' into spin_lmp
iProzd Oct 21, 2024
e5c0ecf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
85c934b
Update spin_model.py
iProzd Oct 22, 2024
35fd1c6
Update spin.py
iProzd Oct 22, 2024
11aeb17
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
6c5cb1d
add ut for spin c++
iProzd Oct 22, 2024
474a2b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2024
114898f
bump version
iProzd Oct 22, 2024
fef13f5
Merge branch 'devel' into spin_lmp
iProzd Oct 26, 2024
d4c7d1a
Spin lmp nlist (#35)
hztttt Oct 31, 2024
3afc2fd
Merge branch 'devel' into spin_lmp
iProzd Oct 31, 2024
9e82b8d
add deepspin pair style (#36)
iProzd Oct 31, 2024
5a9a0a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
605fb9b
Update pair_deepmd.cpp
iProzd Oct 31, 2024
2cc6d8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
a0b7996
Update plugin
iProzd Oct 31, 2024
3dc6fff
fix spin
iProzd Nov 1, 2024
36aee8d
Merge branch 'devel' into spin_lmp
iProzd Nov 1, 2024
c3bf841
Update pair_base.cpp
iProzd Nov 1, 2024
5451acd
Update test_deeppot_tf_spin.cc
iProzd Nov 1, 2024
319493a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
b1e4a03
Update build_lammps.sh
iProzd Nov 1, 2024
de6abef
reformat C/C++ interface
iProzd Nov 4, 2024
799b4e5
rm dead code
iProzd Nov 4, 2024
643e202
fix ut
iProzd Nov 4, 2024
fb4dfe0
add virtual methods
iProzd Nov 4, 2024
d1fd284
fix memory leak
iProzd Nov 5, 2024
99e1e05
add virtual methods
iProzd Nov 5, 2024
ae98964
Update deepmd.hpp
iProzd Nov 5, 2024
3e7501e
rename compute_spin to compute
iProzd Nov 5, 2024
2c4ca0d
update nopbc test
iProzd Nov 5, 2024
7eab6cc
fix lmp uts and rename pair base
iProzd Nov 6, 2024
0965a70
add old c api
iProzd Nov 6, 2024
af09efd
rename base to backend
iProzd Nov 6, 2024
a532c33
rename model filename in lammps tests
njzjz Nov 6, 2024
919654e
add tf nlist nopbc UT for spin
iProzd Nov 7, 2024
c30091b
add tf lmp nopbc UT for spin
iProzd Nov 7, 2024
10b163e
fix torch lmp UT bug
hztttt Nov 7, 2024
811a0b9
Merge branch 'devel' into spin_lmp
iProzd Nov 7, 2024
0039aa4
fix nopbc spin test
iProzd Nov 7, 2024
ab15e47
Merge branch 'devel' into spin_lmp
iProzd Nov 7, 2024
e572e37
Update c_api.h
iProzd Nov 7, 2024
01e7745
Update test_deeppot_a.cc
iProzd Nov 7, 2024
3d1fce6
Update test_deeppot_a.cc
iProzd Nov 7, 2024
960f71a
fix error handle
iProzd Nov 7, 2024
4d71247
rm spin from pairdeepmd
iProzd Nov 7, 2024
41ad708
make pair modification readable
iProzd Nov 7, 2024
dc0f496
add #4269
iProzd Nov 7, 2024
8fd95f8
Update pair_deepmd.cpp
iProzd Nov 7, 2024
edb1e9f
Update DeepSpinTF.cc
iProzd Nov 7, 2024
5c9fda1
rm spin args from deeppottf
iProzd Nov 7, 2024
24896f0
rm black space
iProzd Nov 7, 2024
ba46f54
rm black space and comment
iProzd Nov 7, 2024
809b471
Update DeepPot.h
iProzd Nov 7, 2024
388bb22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
d20d668
resolve conversations
iProzd Nov 7, 2024
e09bf5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
5f53a46
update docs
iProzd Nov 7, 2024
223502d
Update deepmd.hpp
iProzd Nov 7, 2024
1215097
add uts
iProzd Nov 8, 2024
292a68f
Update test_deepspin_a_hpp.cc
iProzd Nov 8, 2024
665f001
update uts
iProzd Nov 8, 2024
58c15ed
Update test_deepspin_a.cc
iProzd Nov 8, 2024
b481274
Update test_deepspin_a.cc
iProzd Nov 8, 2024
29ace48
Update test_deepspin_a.cc
iProzd Nov 8, 2024
e68de42
Update test_deepspin_a.cc
iProzd Nov 8, 2024
cef5817
fix space
iProzd Nov 8, 2024
2085804
Create test_deepspin_a_hpp_tf.cc
iProzd Nov 8, 2024
b367f97
Update test_deepspin_a_hpp_tf.cc
iProzd Nov 8, 2024
b76e272
update ntypes_spin
iProzd Nov 9, 2024
68cfb94
Update test_deepspin_a_hpp_tf.cc
iProzd Nov 9, 2024
ea19b35
Update test_deepspin_a_hpp_tf.cc
iProzd Nov 9, 2024
43e8baf
Update test_deepspin_a_hpp_tf.cc
iProzd Nov 9, 2024
82eca9a
Update test_deepspin_a_hpp_tf.cc
iProzd Nov 9, 2024
7c69066
add uts
iProzd Nov 9, 2024
31d69db
Delete test_deepspin_model_devi_hpp.cc
iProzd Nov 9, 2024
8fb6498
Delete test_deepspin_a_hpp_tf.cc
iProzd Nov 9, 2024
4bc0e42
Create test_deepspin_model_devi_hpp.cc
iProzd Nov 10, 2024
1b7c79b
Update test_deepspin_model_devi_hpp.cc
iProzd Nov 10, 2024
bb8d38e
Update test_deepspin_model_devi_hpp.cc
iProzd Nov 10, 2024
e6bfebe
Update deepmd.hpp
iProzd Nov 10, 2024
117f4c9
add ut for lammps atomic energy
iProzd Nov 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
Expand Down Expand Up @@ -422,6 +425,7 @@ def forward(
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
g1 = self.act(atype_embd)
ng1 = g1.shape[-1]
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
if not self.direct_dist:
g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
Expand All @@ -448,10 +452,27 @@ def forward(
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
has_spin = "has_spin" in comm_dict
if not has_spin:
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
real_nloc = nloc
real_nall = nall
else:
# for spin
real_nloc = nloc // 2
real_nall = nall // 2
real_n_padding = real_nall - real_nloc
g1_real, g1_virtual = torch.split(g1, [real_nloc, real_nloc], dim=1)
# mix_g1: nb x real_nloc x (ng1 * 2)
mix_g1 = torch.cat([g1_real, g1_virtual], dim=2)
# nb x real_nall x (ng1 * 2)
g1 = torch.nn.functional.pad(
mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0
)

assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
Expand All @@ -467,17 +488,22 @@ def forward(
g1,
comm_dict["communicator"],
torch.tensor(
nloc,
real_nloc,
dtype=torch.int32,
device=env.DEVICE,
), # should be int of c++
torch.tensor(
nall - nloc,
real_nall - real_nloc,
dtype=torch.int32,
device=env.DEVICE,
), # should be int of c++
)
g1_ext = ret[0].unsqueeze(0)
if has_spin:
g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2)
g1_ext = concat_switch_virtual(
g1_real_ext, g1_virtual_ext, real_nloc
)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
40 changes: 10 additions & 30 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)
Expand Down Expand Up @@ -79,15 +82,15 @@ def process_spin_input_lower(
self.virtual_scale_mask.to(extended_atype.device)
)[extended_atype].reshape([nframes, nall, 1])
virtual_extended_atype = extended_atype + self.ntypes_real
extended_coord_updated = self.concat_switch_virtual(
extended_coord_updated = concat_switch_virtual(
extended_coord, virtual_extended_coord, nloc
)
extended_atype_updated = self.concat_switch_virtual(
extended_atype_updated = concat_switch_virtual(
extended_atype, virtual_extended_atype, nloc
)
if mapping is not None:
virtual_mapping = mapping + nloc
mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc)
mapping_updated = concat_switch_virtual(mapping, virtual_mapping, nloc)
else:
mapping_updated = None
# extend the nlist
Expand Down Expand Up @@ -203,33 +206,6 @@ def extend_nlist(extended_atype, nlist):
extended_nlist[second_part_index] -= nall - nloc
return extended_nlist

@staticmethod
def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
"""
Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.
- [:, :nloc]: original nloc real atoms.
- [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
- [:, nloc + nloc: nloc + nall]: ghost real atoms.
- [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.
"""
nframes, nall = extended_tensor.shape[:2]
out_shape = list(extended_tensor.shape)
out_shape[1] *= 2
extended_tensor_updated = torch.zeros(
out_shape,
dtype=extended_tensor.dtype,
device=extended_tensor.device,
)
extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc]
extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[
:, :nloc
]
extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[
:, nloc:
]
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
return extended_tensor_updated.view(out_shape)

@staticmethod
def expand_aparam(aparam, nloc: int):
"""Expand the atom parameters for virtual atoms if necessary."""
Expand Down Expand Up @@ -469,6 +445,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
extra_nlist_sort: bool = False,
):
nframes, nloc = nlist.shape[:2]
Expand All @@ -490,6 +467,7 @@ def forward_common_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
)
model_output_type = self.backbone_model.model_output_type()
Expand Down Expand Up @@ -605,6 +583,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -615,6 +594,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.backbone_model.need_sorted_nlist_for_lower(),
)
model_predict = {}
Expand Down
30 changes: 30 additions & 0 deletions deepmd/pt/utils/spin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch


def concat_switch_virtual(
extended_tensor,
extended_tensor_virtual,
nloc: int,
):
"""
Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.
- [:, :nloc]: original nloc real atoms.
- [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
- [:, nloc + nloc: nloc + nall]: ghost real atoms.
- [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.
"""
nframes, nall = extended_tensor.shape[:2]
out_shape = list(extended_tensor.shape)
out_shape[1] *= 2
extended_tensor_updated = torch.zeros(
out_shape,
dtype=extended_tensor.dtype,
device=extended_tensor.device,
)
extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc]
extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[:, :nloc]
extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:]
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
return extended_tensor_updated.view(out_shape)
4 changes: 4 additions & 0 deletions deepmd/tf/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def _make_node_names(
"o_atom_energy",
"o_atom_virial",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
"fitting_attr/dfparam",
"fitting_attr/daparam",
"fitting_attr/aparam_nall",
Expand Down Expand Up @@ -258,6 +260,8 @@ def freeze_graph(
"train_attr/min_nbor_dist",
"fitting_attr/aparam_nall",
"spin_attr/ntypes_spin",
"spin_attr/virtual_len",
"spin_attr/spin_norm",
]
different_set = set(output_node) - set(input_node)
if different_set:
Expand Down
Loading
Loading