Skip to content

Commit

Permalink
feat(jax/array-api): se_t_tebd (#4288)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced support for JAX as a backend for the "se_e3_tebd"
descriptor, enhancing flexibility in computational options.
- Added serialization and deserialization methods to the descriptor
classes for better state management.

- **Bug Fixes**
- Improved handling of attributes in the descriptor classes to ensure
correct data types and transformations.

- **Tests**
- Enhanced the test suite to support multiple backends, including JAX
and Array API Strict, improving the robustness of testing.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Nov 2, 2024
1 parent 8355947 commit 6a75c6b
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 41 deletions.
150 changes: 111 additions & 39 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand All @@ -26,9 +34,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -318,11 +323,15 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
del mapping
nf, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nf, -1).shape[1] // 3
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
# nf x nall x tebd_dim
atype_embd_ext = self.type_embedding.call()[atype_ext]
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
# nfnl x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :]
grrg, g2, h2, rot_mat, sw = self.se_ttebd(
Expand All @@ -334,8 +343,8 @@ def call(
)
# nf x nloc x (ng + tebd_dim)
if self.concat_output_tebd:
grrg = np.concatenate(
[grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1
grrg = xp.concat(
[grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1
)
return grrg, rot_mat, None, None, sw

Expand Down Expand Up @@ -368,8 +377,8 @@ def serialize(self) -> dict:
"env_protection": obj.env_protection,
"smooth": self.smooth,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
"trainable": self.trainable,
}
Expand Down Expand Up @@ -491,33 +500,35 @@ def __init__(
else:
self.embd_input_dim = 1

self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings[0] = EmbeddingNet(
embeddings[0] = EmbeddingNet(
self.embd_input_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, 0),
)
self.embeddings = embeddings
if self.tebd_input_mode in ["strip"]:
self.embeddings_strip = NetworkCollection(
embeddings_strip = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings_strip[0] = EmbeddingNet(
embeddings_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, 1),
)
self.embeddings_strip = embeddings_strip
else:
self.embeddings_strip = None

Expand Down Expand Up @@ -652,82 +663,85 @@ def call(
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
)
nf, nloc, nnei, _ = dmatrix.shape
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
# nfnl x nnei
nlist = nlist.reshape(nf * nloc, nnei)
nlist = np.where(exclude_mask, nlist, -1)
nlist = xp.reshape(nlist, (nf * nloc, nnei))
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
# nfnl x nnei
nlist_mask = nlist != -1
# nfnl x nnei x 1
sw = np.where(nlist_mask[:, :, None], sw.reshape(nf * nloc, nnei, 1), 0.0)
sw = xp.where(
nlist_mask[:, :, None],
xp.reshape(sw, (nf * nloc, nnei, 1)),
xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype),
)

# nfnl x nnei x 4
dmatrix = dmatrix.reshape(nf * nloc, nnei, 4)
dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
# nfnl x nnei x 4
rr = dmatrix
rr = rr * exclude_mask[:, :, None]
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
# nfnl x nt_i x 3
rr_i = rr[:, :, 1:]
# nfnl x nt_j x 3
rr_j = rr[:, :, 1:]
# nfnl x nt_i x nt_j
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
# nfnl x nt_i x nt_j x 1
ss = np.expand_dims(env_ij, axis=-1)
ss = env_ij[..., None]

nlist_masked = np.where(nlist_mask, nlist, 0)
index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim))
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape(
nf * nloc, nnei, self.tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = np.tile(
np.expand_dims(atype_embd_nlist, axis=2), [1, 1, self.nnei, 1]
)
nlist_tebd_j = np.tile(
np.expand_dims(atype_embd_nlist, axis=1), [1, self.nnei, 1, 1]
)
nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
ng = self.neuron[-1]

if self.tebd_input_mode in ["concat"]:
# nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
ss = np.concatenate([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg = self.cal_g(ss, 0)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nt_i x nt_j x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
# nfnl x nt_i x nt_j x (tebd_dim * 2)
tt = np.concatenate([nlist_tebd_i, nlist_tebd_j], axis=-1)
tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg_t = self.cal_g_strip(tt, 0)
if self.smooth:
gg_t = (
gg_t
* sw.reshape(nf * nloc, self.nnei, 1, 1)
* sw.reshape(nf * nloc, 1, self.nnei, 1)
* xp.reshape(sw, (nf * nloc, self.nnei, 1, 1))
* xp.reshape(sw, (nf * nloc, 1, self.nnei, 1))
)
# nfnl x nt_i x nt_j x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError

# nfnl x ng
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
res_ij = xp.sum(env_ij[:, :, :, None] * gg[:, :, :, :], axis=(1, 2))
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = res_ij.reshape(nf, nloc, self.filter_neuron[-1]).astype(
GLOBAL_NP_FLOAT_PRECISION
)
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
return (
result,
None,
Expand All @@ -743,3 +757,61 @@ def has_message_passing(self) -> bool:
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False

def serialize(self) -> dict:
    """Serialize this object to a plain dict.

    Returns
    -------
    dict
        The class/type/version tags, the hyperparameters, the serialized
        ``embeddings`` and ``env_mat``, and the ``davg``/``dstd`` statistics
        (passed through ``to_numpy_array``) under ``"@variables"``. The key
        ``"embeddings_strip"`` is only present when ``tebd_input_mode`` is
        ``"strip"``.
    """
    serialized = {
        "@class": "Descriptor",
        "type": "se_e3_tebd",
        "@version": 1,
    }
    # Hyperparameters stored under the name of the attribute they come from.
    for key in (
        "rcut",
        "rcut_smth",
        "sel",
        "ntypes",
        "neuron",
        "tebd_dim",
        "tebd_input_mode",
        "set_davg_zero",
        "activation_function",
        "resnet_dt",
    ):
        serialized[key] = getattr(self, key)
    # Store the numpy dtype name rather than the configured precision string
    # so that the serialized value is deterministic.
    serialized["precision"] = np.dtype(PRECISION_DICT[self.precision]).name
    serialized["embeddings"] = self.embeddings.serialize()
    serialized["env_mat"] = self.env_mat.serialize()
    for key in ("exclude_types", "env_protection", "smooth"):
        serialized[key] = getattr(self, key)
    serialized["@variables"] = {
        "davg": to_numpy_array(self["davg"]),
        "dstd": to_numpy_array(self["dstd"]),
    }
    # The strip network only exists in "strip" mode.
    if self.tebd_input_mode in ["strip"]:
        serialized["embeddings_strip"] = self.embeddings_strip.serialize()
    return serialized

@classmethod
def deserialize(cls, data: dict) -> "DescrptBlockSeTTebd":
    """Rebuild an instance from the dict produced by ``serialize``.

    Parameters
    ----------
    data : dict
        The serialized object. It is shallow-copied before any key is
        removed, so the caller's dict is left untouched.

    Returns
    -------
    DescrptBlockSeTTebd
        A new instance of ``cls`` built from the remaining hyperparameters,
        with ``davg``/``dstd`` and the embedding networks restored.

    Raises
    ------
    KeyError
        If one of the required keys (``"@version"``, ``"@class"``,
        ``"type"``, ``"@variables"``, ``"embeddings"``, ``"env_mat"``,
        ``"tebd_input_mode"``, or ``"embeddings_strip"`` in strip mode) is
        missing.
    """
    data = data.copy()
    check_version_compatibility(data.pop("@version"), 1, 1)
    # Bookkeeping tags are not constructor arguments.
    data.pop("@class")
    data.pop("type")
    variables = data.pop("@variables")
    embeddings = data.pop("embeddings")
    # Removed so it is not forwarded to the constructor; its content is not
    # used any further here.
    data.pop("env_mat")
    # "tebd_input_mode" stays in `data`: the constructor needs it as well.
    strip_mode = data["tebd_input_mode"] in ["strip"]
    embeddings_strip = data.pop("embeddings_strip") if strip_mode else None
    se_ttebd = cls(**data)

    # Replace the freshly initialised state with the serialized one.
    se_ttebd["davg"] = variables["davg"]
    se_ttebd["dstd"] = variables["dstd"]
    se_ttebd.embeddings = NetworkCollection.deserialize(embeddings)
    if strip_mode:
        se_ttebd.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)

    return se_ttebd
56 changes: 56 additions & 0 deletions deepmd/jax/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t_tebd import (
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
)
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)
from deepmd.jax.utils.type_embed import (
TypeEmbedNet,
)


@flax_module
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
    """JAX counterpart of the dpmodel ``DescrptBlockSeTTebd``.

    All computation is inherited; only attribute assignment is intercepted so
    that selected attributes are stored as their JAX-side equivalents.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert selected attributes before delegating to the parent class.

        - ``mean`` / ``stddev``: passed through ``to_jax_array`` and, unless
          the result is ``None``, wrapped in ``ArrayAPIVariable``.
        - ``embeddings`` / ``embeddings_strip``: when not ``None``, rebuilt
          with the JAX ``NetworkCollection`` through a serialize/deserialize
          round trip.
        - ``emask``: rebuilt as the JAX ``PairExcludeMask`` from its
          ``ntypes`` and ``exclude_types``.
        - everything else (including ``env_mat``, which doesn't store any
          value) is assigned unchanged.
        """
        if name in ("mean", "stddev"):
            converted = to_jax_array(value)
            value = converted if converted is None else ArrayAPIVariable(converted)
        elif name in ("embeddings", "embeddings_strip") and value is not None:
            value = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            value = PairExcludeMask(value.ntypes, value.exclude_types)

        return super().__setattr__(name, value)


@BaseDescriptor.register("se_e3_tebd")
@flax_module
class DescrptSeTTebd(DescrptSeTTebdDP):
    """JAX counterpart of the dpmodel ``DescrptSeTTebd``, registered as ``"se_e3_tebd"``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild sub-modules with their JAX classes on assignment.

        ``se_ttebd`` and ``type_embedding`` are round-tripped through
        ``serialize()`` / ``deserialize()`` of the JAX ``DescrptBlockSeTTebd``
        and ``TypeEmbedNet`` respectively; any other attribute is assigned
        unchanged.
        """
        # Attribute name -> JAX class used to rebuild the assigned value.
        jax_class = {
            "se_ttebd": DescrptBlockSeTTebd,
            "type_embedding": TypeEmbedNet,
        }.get(name)
        if jax_class is not None:
            value = jax_class.deserialize(value.serialize())
        return super().__setattr__(name, value)
4 changes: 2 additions & 2 deletions doc/model/train-se-e3-tebd.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e3_tebd` is short for the three-body embedding descriptor with type embeddings, where the notation `se` denotes the Deep Potential Smooth Edition (DeepPot-SE).
Expand Down
47 changes: 47 additions & 0 deletions source/tests/array_api_strict/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t_tebd import (
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
)
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)
from ..utils.type_embed import (
TypeEmbedNet,
)


class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
    """Array-API-strict test counterpart of the dpmodel ``DescrptBlockSeTTebd``.

    All computation is inherited; only attribute assignment is intercepted.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert selected attributes before delegating to the parent class.

        - ``mean`` / ``stddev``: passed through ``to_array_api_strict_array``.
        - ``embeddings`` / ``embeddings_strip``: when not ``None``, rebuilt
          with the local ``NetworkCollection`` through a
          serialize/deserialize round trip.
        - ``emask``: rebuilt as the local ``PairExcludeMask`` from its
          ``ntypes`` and ``exclude_types``.
        - everything else (including ``env_mat``, which doesn't store any
          value) is assigned unchanged.
        """
        is_statistic = name in ("mean", "stddev")
        is_network = name in ("embeddings", "embeddings_strip")

        if is_statistic:
            value = to_array_api_strict_array(value)
        elif is_network and value is not None:
            value = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            value = PairExcludeMask(value.ntypes, value.exclude_types)

        return super().__setattr__(name, value)


class DescrptSeTTebd(DescrptSeTTebdDP):
    """Array-API-strict test counterpart of the dpmodel ``DescrptSeTTebd``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild sub-modules with their local classes on assignment.

        ``se_ttebd`` and ``type_embedding`` are round-tripped through
        ``serialize()`` / ``deserialize()`` of the local
        ``DescrptBlockSeTTebd`` and ``TypeEmbedNet`` respectively; any other
        attribute is assigned unchanged.
        """
        # Attribute name -> local class used to rebuild the assigned value.
        local_class = {
            "se_ttebd": DescrptBlockSeTTebd,
            "type_embedding": TypeEmbedNet,
        }.get(name)
        if local_class is not None:
            value = local_class.deserialize(value.serialize())
        return super().__setattr__(name, value)
Loading

0 comments on commit 6a75c6b

Please sign in to comment.