Skip to content

Commit

Permalink
chore(dpmodel): move save_dp_model and load_dp_model to a seperated m…
Browse files Browse the repository at this point in the history
…odule (deepmodeling#3701)

Fix deepmodeling#3526.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 2851a42 commit 8a7827d
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 115 deletions.
4 changes: 2 additions & 2 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
load_dp_model,
)

Expand All @@ -115,7 +115,7 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
save_dp_model,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from deepmd.dpmodel.utils.batch_size import (
AutoBatchSize,
)
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
load_dp_model,
)
from deepmd.infer.deep_dipole import (
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
NativeLayer,
NativeNet,
NetworkCollection,
load_dp_model,
make_embedding_network,
make_fitting_network,
make_multilayer_network,
save_dp_model,
traverse_model_dict,
)
from .nlist import (
build_multiple_neighbor_list,
Expand All @@ -32,6 +29,11 @@
phys2inter,
to_face_distance,
)
from .serialization import (
load_dp_model,
save_dp_model,
traverse_model_dict,
)

__all__ = [
"EnvMat",
Expand All @@ -46,8 +48,6 @@
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"build_neighbor_list",
"nlist_distinguish_types",
"get_multiple_nlist_key",
Expand Down
107 changes: 0 additions & 107 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

import copy
import itertools
import json
from datetime import (
datetime,
)
from typing import (
Callable,
ClassVar,
Expand All @@ -19,7 +15,6 @@
Union,
)

import h5py
import numpy as np

from deepmd.utils.version import (
Expand All @@ -38,108 +33,6 @@
)


def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False):
"""Traverse a model dict and call callback on each variable.
Parameters
----------
model_obj : object
The model object to traverse.
callback : callable
The callback function to call on each variable.
is_variable : bool, optional
Whether the current node is a variable.
Returns
-------
object
The model object after traversing.
"""
if isinstance(model_obj, dict):
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
)
elif isinstance(model_obj, list):
for ii, vv in enumerate(model_obj):
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
elif model_obj is None:
return model_obj
elif is_variable:
model_obj = callback(model_obj)
return model_obj


class Counter:
"""A callable counter.
Examples
--------
>>> counter = Counter()
>>> counter()
0
>>> counter()
1
"""

def __init__(self):
self.count = -1

def __call__(self):
self.count += 1
return self.count


# TODO: move save_dp_model and load_dp_model to a seperated module
# should be moved to otherwhere...
def save_dp_model(filename: str, model_dict: dict) -> None:
"""Save a DP model to a file in the native format.
Parameters
----------
filename : str
The filename to save to.
model_dict : dict
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))


def load_dp_model(filename: str) -> dict:
"""Load a DP model from a file in the native format.
Parameters
----------
filename : str
The filename to load from.
Returns
-------
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
return model_dict


class NativeLayer(NativeOP):
"""Native representation of a layer.
Expand Down
115 changes: 115 additions & 0 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from datetime import (
datetime,
)
from typing import (
Callable,
)

import h5py

try:
from deepmd._version import version as __version__
except ImportError:
__version__ = "unknown"


def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False):
"""Traverse a model dict and call callback on each variable.
Parameters
----------
model_obj : object
The model object to traverse.
callback : callable
The callback function to call on each variable.
is_variable : bool, optional
Whether the current node is a variable.
Returns
-------
object
The model object after traversing.
"""
if isinstance(model_obj, dict):
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
)
elif isinstance(model_obj, list):
for ii, vv in enumerate(model_obj):
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
elif model_obj is None:
return model_obj
elif is_variable:
model_obj = callback(model_obj)
return model_obj


class Counter:
"""A callable counter.
Examples
--------
>>> counter = Counter()
>>> counter()
0
>>> counter()
1
"""

def __init__(self):
self.count = -1

def __call__(self):
self.count += 1
return self.count


def save_dp_model(filename: str, model_dict: dict) -> None:
"""Save a DP model to a file in the native format.
Parameters
----------
filename : str
The filename to save to.
model_dict : dict
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))


def load_dp_model(filename: str) -> dict:
"""Load a DP model from a file in the native format.
Parameters
----------
filename : str
The filename to load from.
Returns
-------
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
return model_dict

0 comments on commit 8a7827d

Please sign in to comment.