Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Oct 23, 2024
1 parent 9b571d1 commit fb3df8b
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 8 deletions.
3 changes: 3 additions & 0 deletions deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.common import (
ArrayAPIVariable,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
Expand All @@ -11,6 +12,8 @@
def base_atomic_model_set_attr(name, value):
if name in {"out_bias", "out_std"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "pair_excl" and value is not None:
value = PairExcludeMask(value.ntypes, value.exclude_types)
elif name == "atom_excl" and value is not None:
Expand Down
14 changes: 14 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)

return FlaxModule


class ArrayAPIVariable(nnx.Variable):
def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs)

def __array_namespace__(self, *args, **kwargs):
return self.value.__array_namespace__(*args, **kwargs)

def __dlpack__(self, *args, **kwargs):
return self.value.__dlpack__(*args, **kwargs)

Check warning on line 94 in deepmd/jax/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/common.py#L94

Added line #L94 was not covered by tests

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)

Check warning on line 97 in deepmd/jax/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/common.py#L97

Added line #L97 was not covered by tests
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand Down Expand Up @@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings", "embeddings_strip"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
"aparam_inv_std",
}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
Expand Down
5 changes: 5 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)


Expand All @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)
24 changes: 17 additions & 7 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
BaseModel,
get_model,
)
from deepmd.jax.utils.network import (
ArrayAPIParam,
)


def deserialize_to_file(model_file: str, data: dict) -> None:
Expand All @@ -31,14 +28,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
state = nnx.state(model, ArrayAPIParam)
_, state = nnx.split(model)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
checkpointer.save(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardSave(state),
state=ocp.args.StandardSave(state.to_pure_dict()),
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
Expand Down Expand Up @@ -71,9 +68,22 @@ def serialize_from_file(model_file: str) -> dict:
),
)
state = data.state

# convert str "1" to int 1 key
def convert_str_to_int_key(item: dict):
for key, value in item.copy().items():
if isinstance(value, dict):
convert_str_to_int_key(value)
if key.isdigit():
item[int(key)] = item.pop(key)

convert_str_to_int_key(state)

model_def_script = data.model_def_script
model = get_model(model_def_script)
nnx.update(model, state)
abstract_model = get_model(model_def_script)
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(state)
model = nnx.merge(graphdef, abstract_state)
model_dict = model.serialize()
data = {
"backend": "JAX",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/utils/type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"econf_tebd"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
if name in {"embedding_net"}:
value = EmbeddingNet.deserialize(value.serialize())
return super().__setattr__(name, value)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ cu12 = [
]
jax = [
'jax>=0.4.33;python_version>="3.10"',
'flax>=0.8.0;python_version>="3.10"',
'flax>=0.10.0;python_version>="3.10"',
'orbax-checkpoint;python_version>="3.10"',
# The pinning of ml_dtypes may conflict with TF
# 'jax-ai-stack;python_version>="3.10"',
Expand Down

0 comments on commit fb3df8b

Please sign in to comment.