Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: convert model files between backends #3323

Merged
merged 11 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@
"""Support Deep Eval backend."""
NEIGHBOR_STAT = auto()
"""Support neighbor statistics."""
IO = auto()
"""Support IO hook."""

name: ClassVar[str] = "Unknown"
"""The formal name of the backend.
Expand Down Expand Up @@ -199,3 +201,27 @@
The neighbor statistics of the backend.
"""
pass

@property
@abstractmethod
def serialize_hook(self) -> Callable[[str], dict]:
    """Hook that converts a model file into a dictionary.

    Abstract property: the body consists of this docstring only, so
    concrete backends are expected to supply the actual hook.

    Returns
    -------
    Callable[[str], dict]
        The serialize hook of the backend.
    """

Check warning on line 215 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L215

Added line #L215 was not covered by tests

@property
@abstractmethod
def deserialize_hook(self) -> Callable[[str, dict], None]:
    """Hook that converts a dictionary into a model file.

    Abstract property: the body consists of this docstring only, so
    concrete backends are expected to supply the actual hook.

    Returns
    -------
    Callable[[str, dict], None]
        The deserialize hook of the backend.
    """

Check warning on line 227 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L227

Added line #L227 was not covered by tests
34 changes: 33 additions & 1 deletion deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class DPModelBackend(Backend):

name = "DPModel"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT
features: ClassVar[Backend.Feature] = (
Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
"""The suffixes of the backend."""
Expand Down Expand Up @@ -84,3 +86,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
    """Hook that converts a model file into a dictionary.

    Returns
    -------
    Callable[[str], dict]
        ``load_dp_model`` from ``deepmd.dpmodel.utils.network``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.dpmodel.utils.network import load_dp_model as hook

    return hook

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
    """Hook that converts a dictionary into a model file.

    Returns
    -------
    Callable[[str, dict], None]
        ``save_dp_model`` from ``deepmd.dpmodel.utils.network``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.dpmodel.utils.network import save_dp_model as hook

    return hook
31 changes: 31 additions & 0 deletions deepmd/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TensorFlowBackend(Backend):
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".pth", ".pt"]
Expand Down Expand Up @@ -93,3 +94,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
    """Hook that converts a model file into a dictionary.

    Returns
    -------
    Callable[[str], dict]
        ``serialize_from_file`` from ``deepmd.pt.utils.serialization``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.pt.utils.serialization import serialize_from_file as hook

    return hook

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
    """Hook that converts a dictionary into a model file.

    Returns
    -------
    Callable[[str, dict], None]
        ``deserialize_to_file`` from ``deepmd.pt.utils.serialization``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.pt.utils.serialization import deserialize_to_file as hook

    return hook
31 changes: 31 additions & 0 deletions deepmd/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TensorFlowBackend(Backend):
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".pb"]
Expand Down Expand Up @@ -102,3 +103,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
    """Hook that converts a model file into a dictionary.

    Returns
    -------
    Callable[[str], dict]
        ``serialize_from_file`` from ``deepmd.tf.utils.serialization``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.tf.utils.serialization import serialize_from_file as hook

    return hook

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
    """Hook that converts a dictionary into a model file.

    Returns
    -------
    Callable[[str, dict], None]
        ``deserialize_to_file`` from ``deepmd.tf.utils.serialization``.
    """
    # The import happens on property access, not at module import time.
    from deepmd.tf.utils.serialization import deserialize_to_file as hook

    return hook
19 changes: 10 additions & 9 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import copy
import itertools
import json
from datetime import (
datetime,
)
from typing import (
ClassVar,
Dict,
Expand Down Expand Up @@ -54,6 +57,8 @@ def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False
elif isinstance(model_obj, list):
for ii, vv in enumerate(model_obj):
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
elif model_obj is None:
return model_obj
elif is_variable:
model_obj = callback(model_obj)
return model_obj
Expand All @@ -79,7 +84,8 @@ def __call__(self):
return self.count


def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] = None):
# TODO: should be moved elsewhere...
def save_dp_model(filename: str, model_dict: dict) -> None:
"""Save a DP model to a file in the native format.

Parameters
Expand All @@ -88,15 +94,9 @@ def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] =
The filename to save to.
model_dict : dict
The model dict to save.
extra_info : dict, optional
Extra meta information to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
if extra_info is not None:
extra_info = extra_info.copy()
else:
extra_info = {}
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
Expand All @@ -105,10 +105,11 @@ def save_dp_model(filename: str, model_dict: dict, extra_info: Optional[dict] =
).name,
)
save_dict = {
"model": model_dict,
"software": "deepmd-kit",
"version": __version__,
**extra_info,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))

Expand Down
27 changes: 27 additions & 0 deletions deepmd/entrypoints/convert_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.backend.backend import (

Check warning on line 2 in deepmd/entrypoints/convert_backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/convert_backend.py#L2

Added line #L2 was not covered by tests
Backend,
)


def convert_backend(
    *,  # Enforce keyword-only arguments
    INPUT: str,
    OUTPUT: str,
    **kwargs,
) -> None:
    """Convert a model file from one backend to another.

    The backend of each file is detected with
    ``Backend.detect_backend_by_model``. The input backend's serialize hook
    is called on ``INPUT``, and the data it returns is passed, together with
    ``OUTPUT``, to the output backend's deserialize hook.

    Parameters
    ----------
    INPUT : str
        The input model file.
    OUTPUT : str
        The output model file.
    **kwargs
        Additional keyword arguments; accepted but not used by this function.
    """
    # detect_backend_by_model returns a backend class; instantiate it.
    inp_backend: Backend = Backend.detect_backend_by_model(INPUT)()
    out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)()
    inp_hook = inp_backend.serialize_hook
    out_hook = out_backend.deserialize_hook
    data = inp_hook(INPUT)
    out_hook(OUTPUT, data)

Check warning on line 27 in deepmd/entrypoints/convert_backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/convert_backend.py#L22-L27

Added lines #L22 - L27 were not covered by tests
5 changes: 5 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from deepmd.backend.suffix import (
format_model_suffix,
)
from deepmd.entrypoints.convert_backend import (

Check warning on line 15 in deepmd/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/main.py#L15

Added line #L15 was not covered by tests
convert_backend,
)
from deepmd.entrypoints.doc import (
doc_train_input,
)
Expand Down Expand Up @@ -76,5 +79,7 @@
neighbor_stat(**dict_args)
elif args.command == "gui":
start_dpgui(**dict_args)
elif args.command == "convert-backend":
convert_backend(**dict_args)

Check warning on line 83 in deepmd/entrypoints/main.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/main.py#L82-L83

Added lines #L82 - L83 were not covered by tests
else:
raise ValueError(f"Unknown command: {args.command}")
18 changes: 18 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,23 @@ def main_parser() -> argparse.ArgumentParser:
"to the network on both IPv4 and IPv6 (where available)."
),
)

# convert_backend
parser_convert_backend = subparsers.add_parser(
"convert-backend",
parents=[parser_log],
help="Convert model to another backend.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp convert-backend model.pb model.pth
dp convert-backend model.pb model.dp
"""
),
)
parser_convert_backend.add_argument("INPUT", help="The input model file.")
parser_convert_backend.add_argument("OUTPUT", help="The output model file.")
return parser


Expand Down Expand Up @@ -767,6 +784,7 @@ def main():
"model-devi",
"neighbor-stat",
"gui",
"convert-backend",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ class BaseAtomicModel(BaseAtomicModel_):
# export public methods that are not abstract
get_nsel = torch.jit.export(BaseAtomicModel_.get_nsel)
get_nnei = torch.jit.export(BaseAtomicModel_.get_nnei)

@torch.jit.export
def get_model_param(self) -> str:
    """Return the string stored in ``self.model_param``."""
    model_param: str = self.model_param
    return model_param
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class DPAtomicModel(torch.nn.Module, BaseAtomicModel):

def __init__(self, descriptor, fitting, type_map: Optional[List[str]]):
super().__init__()
self.model_param = ""
ntypes = len(type_map)
self.type_map = type_map
self.ntypes = ntypes
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def __init__(
):
models = [dp_model, zbl_model]
super().__init__(models, **kwargs)
self.model_param = ""
self.dp_model = dp_model
self.zbl_model = zbl_model

Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs
):
super().__init__()
self.model_param = ""
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(
self.ffn = ffn
self.ffn_embed_dim = ffn_embed_dim
self.activation = activation
# TODO: To be fixed: precision should be given from inputs
self.prec = torch.float64
self.scaling_factor = scaling_factor
self.head_num = head_num
self.normalize = normalize
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

import copy
import json

from deepmd.pt.model.atomic_model import (
DPAtomicModel,
Expand Down Expand Up @@ -98,7 +99,9 @@ def get_model(model_params):
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)

return EnergyModel(descriptor, fitting, type_map=model_params["type_map"])
model = EnergyModel(descriptor, fitting, type_map=model_params["type_map"])
model.model_param = json.dumps(model_params)
return model


__all__ = [
Expand Down
Loading