Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix DeepGlobalPolar and DeepWFC initlization #3834

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class DeepEvalBackend(ABC):
"dos_redu": "dos",
"mask_mag": "mask_mag",
"mask": "mask",
# old models in v1
"global_polar": "global_polar",
"wfc": "wfc",
}

@abstractmethod
Expand Down
27 changes: 26 additions & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)
from deepmd.infer.deep_tensor import (
DeepTensor,
OldDeepTensor,
)


Expand Down Expand Up @@ -36,7 +42,7 @@ def output_tensor_name(self) -> str:
return "polar"


class DeepGlobalPolar(DeepTensor):
class DeepGlobalPolar(OldDeepTensor):
@property
def output_tensor_name(self) -> str:
return "global_polar"
Expand Down Expand Up @@ -95,3 +101,22 @@ def eval(
mixed_type=mixed_type,
**kwargs,
)

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
# no atomic or differentiable output is defined
return ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
self.output_tensor_name,
shape=[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
atomic=False,
),
]
)
)
21 changes: 21 additions & 0 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,24 @@
]
)
)


class OldDeepTensor(DeepTensor):
"""Old tensor models from v1, which has no gradient output."""

# See https://github.com/deepmodeling/deepmd-kit/blob/1d1b251a2c5f05d1401aa89be792f9ed18b8f096/source/train/Model.py#L264
def eval_full(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: dict,
) -> Tuple[np.ndarray, ...]:
"""Unsupported method."""
raise RuntimeError(

Check warning on line 255 in deepmd/infer/deep_tensor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_tensor.py#L255

Added line #L255 was not covered by tests
"This model does not support eval_full method. Use eval instead."
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 26 additions & 2 deletions deepmd/infer/deep_wfc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)
from deepmd.infer.deep_tensor import (
DeepTensor,
OldDeepTensor,
)


class DeepWFC(DeepTensor):
class DeepWFC(OldDeepTensor):
"""Deep WFC model.

Parameters
Expand All @@ -26,3 +31,22 @@ class DeepWFC(DeepTensor):
@property
def output_tensor_name(self) -> str:
return "wfc"

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
# no reduciable or differentiable output is defined
return ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
self.output_tensor_name,
shape=[-1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
atomic=True,
),
]
)
)
41 changes: 25 additions & 16 deletions source/tests/tf/test_get_potential.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test if `DeepPotential` facto function returns the right type of potential."""

import tempfile
import unittest

from deepmd.infer.deep_polar import (
DeepGlobalPolar,
)
from deepmd.infer.deep_wfc import (
DeepWFC,
)
from deepmd.tf.infer import (
DeepDipole,
DeepPolar,
Expand Down Expand Up @@ -35,16 +42,19 @@ def setUp(self):
str(self.work_dir / "deeppolar.pbtxt"), str(self.work_dir / "deep_polar.pb")
)

# TODO add model files for globalpolar and WFC
# convert_pbtxt_to_pb(
# str(self.work_dir / "deepglobalpolar.pbtxt"),
# str(self.work_dir / "deep_globalpolar.pb")
# )
with open(self.work_dir / "deeppolar.pbtxt") as f:
deeppolar_pbtxt = f.read()

# convert_pbtxt_to_pb(
# str(self.work_dir / "deepwfc.pbtxt"),
# str(self.work_dir / "deep_wfc.pb")
# )
# not an actual globalpolar and wfc model, but still good enough for testing factory
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(deeppolar_pbtxt.replace("polar", "global_polar"))
f.flush()
convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_globalpolar.pb"))

with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(deeppolar_pbtxt.replace("polar", "wfc"))
f.flush()
convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_wfc.pb"))

def tearDown(self):
for f in self.work_dir.glob("*.pb"):
Expand All @@ -62,11 +72,10 @@ def test_factory(self):
dp = DeepPotential(self.work_dir / "deep_pot.pb")
self.assertIsInstance(dp, DeepPot, msg.format(DeepPot, type(dp)))

# TODO add model files for globalpolar and WFC
# dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
# self.assertIsInstance(
# dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
# )
dp = DeepPotential(self.work_dir / "deep_globalpolar.pb")
self.assertIsInstance(
dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp))
)

# dp = DeepPotential(self.work_dir / "deep_wfc.pb")
# self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))
dp = DeepPotential(self.work_dir / "deep_wfc.pb")
self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))