Skip to content

Commit

Permalink
feat(jax): DPA-2
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Oct 31, 2024
1 parent ff04d8b commit 1cac90b
Show file tree
Hide file tree
Showing 8 changed files with 533 additions and 137 deletions.
25 changes: 15 additions & 10 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EnvMat,
NetworkCollection,
Expand Down Expand Up @@ -787,6 +791,7 @@ def call(
The smooth switch function. shape: nf x nloc x nnei
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
Expand Down Expand Up @@ -823,16 +828,16 @@ def call(
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
g1 = xp.concatenate([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
assert self.tebd_transform is not None
g1 = g1 + self.tebd_transform(g1_inp)
# mapping g1
assert mapping is not None
mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
g1_ext = np.take_along_axis(g1, mapping_ext, axis=1)
mapping_ext = xp.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1]))
g1_ext = xp.take_along_axis(g1, mapping_ext, axis=1)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
Expand All @@ -846,7 +851,7 @@ def call(
mapping,
)
if self.concat_output_tebd:
g1 = np.concatenate([g1, g1_inp], axis=-1)
g1 = xp.concatenate([g1, g1_inp], axis=-1)
return g1, rot_mat, g2, h2, sw

def serialize(self) -> dict:
Expand Down Expand Up @@ -883,8 +888,8 @@ def serialize(self) -> dict:
"embeddings": repinit.embeddings.serialize(),
"env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(),
"@variables": {
"davg": repinit["davg"],
"dstd": repinit["dstd"],
"davg": to_numpy_array(repinit["davg"]),
"dstd": to_numpy_array(repinit["dstd"]),
},
}
if repinit.tebd_input_mode in ["strip"]:
Expand All @@ -896,8 +901,8 @@ def serialize(self) -> dict:
"repformer_layers": [layer.serialize() for layer in repformers.layers],
"env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(),
"@variables": {
"davg": repformers["davg"],
"dstd": repformers["dstd"],
"davg": to_numpy_array(repformers["davg"]),
"dstd": to_numpy_array(repformers["dstd"]),
},
}
data.update(
Expand All @@ -913,8 +918,8 @@ def serialize(self) -> dict:
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
"davg": to_numpy_array(repinit_three_body["davg"]),
"dstd": to_numpy_array(repinit_three_body["dstd"]),
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
Expand Down
Loading

0 comments on commit 1cac90b

Please sign in to comment.