From 3ae012b4091f87912118ee172d3add0eb3c76374 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 22 Apr 2024 19:08:37 -0400 Subject: [PATCH] chore: move save_dp_model and load_dp_model to a seperated module Fix #3526. Signed-off-by: Jinzhe Zeng --- deepmd/backend/dpmodel.py | 4 +- deepmd/dpmodel/infer/deep_eval.py | 2 +- deepmd/dpmodel/utils/__init__.py | 10 +-- deepmd/dpmodel/utils/network.py | 107 ------------------------ deepmd/dpmodel/utils/serialization.py | 115 ++++++++++++++++++++++++++ 5 files changed, 123 insertions(+), 115 deletions(-) create mode 100644 deepmd/dpmodel/utils/serialization.py diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py index 64df95586d..30591fb51a 100644 --- a/deepmd/backend/dpmodel.py +++ b/deepmd/backend/dpmodel.py @@ -100,7 +100,7 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - from deepmd.dpmodel.utils.network import ( + from deepmd.dpmodel.utils.serialization import ( load_dp_model, ) @@ -115,7 +115,7 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - from deepmd.dpmodel.utils.network import ( + from deepmd.dpmodel.utils.serialization import ( save_dp_model, ) diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 22267c895a..1db5d539cf 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -24,7 +24,7 @@ from deepmd.dpmodel.utils.batch_size import ( AutoBatchSize, ) -from deepmd.dpmodel.utils.network import ( +from deepmd.dpmodel.utils.serialization import ( load_dp_model, ) from deepmd.infer.deep_dipole import ( diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index 60a4486d52..0ae70dc31d 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -12,12 +12,9 @@ NativeLayer, NativeNet, NetworkCollection, - load_dp_model, make_embedding_network, make_fitting_network, make_multilayer_network, - save_dp_model, - traverse_model_dict, ) from .nlist import ( build_multiple_neighbor_list, @@ -32,6 +29,11 @@ phys2inter, to_face_distance, ) +from .serialization import ( + load_dp_model, + save_dp_model, + traverse_model_dict, +) __all__ = [ "EnvMat", @@ -46,8 +48,6 @@ "load_dp_model", "save_dp_model", "traverse_model_dict", - "PRECISION_DICT", - "DEFAULT_PRECISION", "build_neighbor_list", "nlist_distinguish_types", "get_multiple_nlist_key", diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 661358ed70..1cc8fda347 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -6,10 +6,6 @@ import copy import itertools -import json -from datetime import ( - datetime, -) from typing import ( Callable, ClassVar, @@ -19,7 +15,6 @@ Union, ) -import h5py import numpy as np from deepmd.utils.version import ( @@ -38,108 +33,6 @@ ) -def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False): - """Traverse a model dict and call callback on each variable. - - Parameters - ---------- - model_obj : object - The model object to traverse. - callback : callable - The callback function to call on each variable. - is_variable : bool, optional - Whether the current node is a variable. - - Returns - ------- - object - The model object after traversing. - """ - if isinstance(model_obj, dict): - for kk, vv in model_obj.items(): - model_obj[kk] = traverse_model_dict( - vv, callback, is_variable=is_variable or kk == "@variables" - ) - elif isinstance(model_obj, list): - for ii, vv in enumerate(model_obj): - model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable) - elif model_obj is None: - return model_obj - elif is_variable: - model_obj = callback(model_obj) - return model_obj - - -class Counter: - """A callable counter. - - Examples - -------- - >>> counter = Counter() - >>> counter() - 0 - >>> counter() - 1 - """ - - def __init__(self): - self.count = -1 - - def __call__(self): - self.count += 1 - return self.count - - -# TODO: move save_dp_model and load_dp_model to a seperated module -# should be moved to otherwhere... -def save_dp_model(filename: str, model_dict: dict) -> None: - """Save a DP model to a file in the native format. - - Parameters - ---------- - filename : str - The filename to save to. - model_dict : dict - The model dict to save. - """ - model_dict = model_dict.copy() - variable_counter = Counter() - with h5py.File(filename, "w") as f: - model_dict = traverse_model_dict( - model_dict, - lambda x: f.create_dataset( - f"variable_{variable_counter():04d}", data=x - ).name, - ) - save_dict = { - "software": "deepmd-kit", - "version": __version__, - # use UTC+0 time - "time": str(datetime.utcnow()), - **model_dict, - } - f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) - - -def load_dp_model(filename: str) -> dict: - """Load a DP model from a file in the native format. - - Parameters - ---------- - filename : str - The filename to load from. - - Returns - ------- - dict - The loaded model dict, including meta information. - """ - with h5py.File(filename, "r") as f: - model_dict = json.loads(f.attrs["json"]) - model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) - return model_dict - - class NativeLayer(NativeOP): """Native representation of a layer. diff --git a/deepmd/dpmodel/utils/serialization.py b/deepmd/dpmodel/utils/serialization.py new file mode 100644 index 0000000000..a69170e51d --- /dev/null +++ b/deepmd/dpmodel/utils/serialization.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +from datetime import ( + datetime, +) +from typing import ( + Callable, +) + +import h5py + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + + +def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False): + """Traverse a model dict and call callback on each variable. + + Parameters + ---------- + model_obj : object + The model object to traverse. + callback : callable + The callback function to call on each variable. + is_variable : bool, optional + Whether the current node is a variable. + + Returns + ------- + object + The model object after traversing. + """ + if isinstance(model_obj, dict): + for kk, vv in model_obj.items(): + model_obj[kk] = traverse_model_dict( + vv, callback, is_variable=is_variable or kk == "@variables" + ) + elif isinstance(model_obj, list): + for ii, vv in enumerate(model_obj): + model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable) + elif model_obj is None: + return model_obj + elif is_variable: + model_obj = callback(model_obj) + return model_obj + + +class Counter: + """A callable counter. + + Examples + -------- + >>> counter = Counter() + >>> counter() + 0 + >>> counter() + 1 + """ + + def __init__(self): + self.count = -1 + + def __call__(self): + self.count += 1 + return self.count + + +def save_dp_model(filename: str, model_dict: dict) -> None: + """Save a DP model to a file in the native format. + + Parameters + ---------- + filename : str + The filename to save to. + model_dict : dict + The model dict to save. + """ + model_dict = model_dict.copy() + variable_counter = Counter() + with h5py.File(filename, "w") as f: + model_dict = traverse_model_dict( + model_dict, + lambda x: f.create_dataset( + f"variable_{variable_counter():04d}", data=x + ).name, + ) + save_dict = { + "software": "deepmd-kit", + "version": __version__, + # use UTC+0 time + "time": str(datetime.utcnow()), + **model_dict, + } + f.attrs["json"] = json.dumps(save_dict, separators=(",", ":")) + + +def load_dp_model(filename: str) -> dict: + """Load a DP model from a file in the native format. + + Parameters + ---------- + filename : str + The filename to load from. + + Returns + ------- + dict + The loaded model dict, including meta information. + """ + with h5py.File(filename, "r") as f: + model_dict = json.loads(f.attrs["json"]) + model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) + return model_dict