Skip to content

Commit

Permalink
feat(jax/array-api): se_e2_r (#4257)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a new descriptor class, `DescrptSeR`, enhancing
compatibility with JAX and Array API.
- Added custom logic for attribute handling in the new descriptor class.

- **Bug Fixes**
  - Improved error handling and type conversion for tensor operations.

- **Tests**
- Enhanced testing framework for the `DescrptSeR` descriptor, including
support for JAX and Array API Strict backends.
- Updated test class to better reflect the focus on the `DescrptSeR`
descriptor.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Oct 26, 2024
1 parent 659f90d commit fa61d69
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 15 deletions.
38 changes: 24 additions & 14 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand All @@ -25,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -144,31 +146,33 @@ def __init__(
self.env_protection = env_protection

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
if not self.type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
for ii in range(self.ntypes):
self.embeddings[(ii,)] = EmbeddingNet(
embeddings[(ii,)] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = np.sum(self.sel).item()
self.davg = np.zeros(
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
)
self.dstd = np.ones(
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
)
self.orig_sel = self.sel
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

def __setitem__(self, key, value):
if key in ("avg", "data_avg", "davg"):
Expand Down Expand Up @@ -279,8 +283,9 @@ def cal_g(
ss,
ll,
):
xp = array_api_compat.array_namespace(ss)
nf, nloc, nnei = ss.shape[0:3]
ss = ss.reshape(nf, nloc, nnei, 1)
ss = xp.reshape(ss, (nf, nloc, nnei, 1))
# nf x nloc x nnei x ng
gg = self.embeddings[(ll,)].call(ss)
return gg
Expand Down Expand Up @@ -321,29 +326,34 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(coord_ext)
del mapping
# nf x nloc x nnei x 1
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd, True
)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))
sec = self.sel_cumsum

ng = self.neuron[-1]
xyz_scatter = np.zeros([nf, nloc, ng], dtype=PRECISION_DICT[self.precision])
xyz_scatter = xp.zeros(
[nf, nloc, ng], dtype=get_xp_precision(xp, self.precision)
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
rr = xp.astype(rr, xyz_scatter.dtype)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
tr = tr * xp.astype(mm[:, :, :, None], tr.dtype)
gg = self.cal_g(tr, tt)
gg = np.mean(gg, axis=2)
gg = xp.mean(gg, axis=2)
# nf x nloc x ng x 1
xyz_scatter += gg * (self.sel[tt] / self.nnei)

res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = res.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand All @@ -369,8 +379,8 @@ def serialize(self) -> dict:
"env_mat": self.env_mat.serialize(),
"embeddings": self.embeddings.serialize(),
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
}
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.jax.descriptor.se_e2_r import (
DescrptSeR,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
]
41 changes: 41 additions & 0 deletions deepmd/jax/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
@flax_module
class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
55 changes: 54 additions & 1 deletion source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -33,14 +35,25 @@
descrpt_se_r_args,
)

if INSTALLED_JAX:
from deepmd.jax.descriptor.se_e2_r import DescrptSeR as DescrptSeRJAX
else:
DescrptSeRJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_e2_r import (
DescrptSeR as DescrptSeRArrayAPIStrict,
)
else:
DescrptSeRArrayAPIStrict = None


@parameterized(
(True, False), # resnet_dt
(True, False), # type_one_side
([], [[0, 1]]), # excluded_types
("float32", "float64"), # precision
)
class TestSeA(CommonTest, DescriptorTest, unittest.TestCase):
class TestSeR(CommonTest, DescriptorTest, unittest.TestCase):
@property
def data(self) -> dict:
(
Expand Down Expand Up @@ -81,9 +94,31 @@ def skip_dp(self) -> bool:
) = self.param
return not type_one_side or CommonTest.skip_dp

@property
def skip_jax(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
) = self.param
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeRTF
dp_class = DescrptSeRDP
pt_class = DescrptSeRPT
jax_class = DescrptSeRJAX
array_api_strict_class = DescrptSeRArrayAPIStrict
args = descrpt_se_r_args()

def setUp(self):
Expand Down Expand Up @@ -148,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down

0 comments on commit fa61d69

Please sign in to comment.