Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): se_atten_v2 #4289

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions deepmd/dpmodel/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
DEFAULT_PRECISION,
PRECISION_DICT,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
NetworkCollection,
)
Expand Down Expand Up @@ -146,8 +149,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
## to be updated when the options are supported.
"trainable": self.trainable,
Expand Down
13 changes: 13 additions & 0 deletions deepmd/jax/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
pass
10 changes: 10 additions & 0 deletions source/tests/array_api_strict/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP

from .dpa1 import (
DescrptDPA1,
)


class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
pass
95 changes: 95 additions & 0 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
CommonTest,
parameterized,
Expand All @@ -30,6 +32,18 @@
)
else:
DescrptSeAttenV2PT = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2 as DescrptSeAttenV2JAX,
)
else:
DescrptSeAttenV2JAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_atten_v2 import (
DescrptSeAttenV2 as DescrptSeAttenV2Strict,
)
else:
DescrptSeAttenV2Strict = None
DescrptSeAttenV2TF = None
from deepmd.utils.argcheck import (
descrpt_se_atten_args,
Expand Down Expand Up @@ -175,9 +189,70 @@ def skip_dp(self) -> bool:
def skip_tf(self) -> bool:
return True

@property
def skip_jax(self) -> bool:
(
tebd_dim,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)

@property
def skip_array_api_strict(self) -> bool:
(
tebd_dim,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

tf_class = DescrptSeAttenV2TF
dp_class = DescrptSeAttenV2DP
pt_class = DescrptSeAttenV2PT
jax_class = DescrptSeAttenV2JAX
array_api_strict_class = DescrptSeAttenV2Strict
args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))

def setUp(self):
Expand Down Expand Up @@ -244,6 +319,26 @@ def eval_pt(self, pt_obj: Any) -> Any:
mixed_types=True,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down