From 14d9364b41cbdd2389a64e02903fc71c88075f2e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 28 May 2024 18:01:32 -0400 Subject: [PATCH] fix: fix DeepGlobalPolar and DeepWFC initlization Fix #3561. Fix #3562. Not sure if some one uses them, but it's good to keep compatibility. Signed-off-by: Jinzhe Zeng --- deepmd/infer/deep_eval.py | 3 ++ deepmd/infer/deep_polar.py | 27 +++++++++++++++++- deepmd/infer/deep_tensor.py | 21 ++++++++++++++ deepmd/infer/deep_wfc.py | 28 ++++++++++++++++-- source/tests/tf/test_get_potential.py | 41 ++++++++++++++++----------- 5 files changed, 101 insertions(+), 19 deletions(-) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 5a00ba616d..879455b942 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -76,6 +76,9 @@ class DeepEvalBackend(ABC): "dos_redu": "dos", "mask_mag": "mask_mag", "mask": "mask", + # old models in v1 + "global_polar": "global_polar", + "wfc": "wfc", } @abstractmethod diff --git a/deepmd/infer/deep_polar.py b/deepmd/infer/deep_polar.py index c2089b278d..6650c349a2 100644 --- a/deepmd/infer/deep_polar.py +++ b/deepmd/infer/deep_polar.py @@ -7,8 +7,14 @@ import numpy as np +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) from deepmd.infer.deep_tensor import ( DeepTensor, + OldDeepTensor, ) @@ -36,7 +42,7 @@ def output_tensor_name(self) -> str: return "polar" -class DeepGlobalPolar(DeepTensor): +class DeepGlobalPolar(OldDeepTensor): @property def output_tensor_name(self) -> str: return "global_polar" @@ -95,3 +101,22 @@ def eval( mixed_type=mixed_type, **kwargs, ) + + @property + def output_def(self) -> ModelOutputDef: + """Get the output definition of this model.""" + # no atomic or differentiable output is defined + return ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + self.output_tensor_name, + shape=[-1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + atomic=False, + ), + ] + ) + ) diff --git a/deepmd/infer/deep_tensor.py b/deepmd/infer/deep_tensor.py index 14e13e7f84..106bc3156c 100644 --- a/deepmd/infer/deep_tensor.py +++ b/deepmd/infer/deep_tensor.py @@ -234,3 +234,24 @@ def output_def(self) -> ModelOutputDef: ] ) ) + + +class OldDeepTensor(DeepTensor): + """Old tensor models from v1, which has no gradient output.""" + + # See https://github.com/deepmodeling/deepmd-kit/blob/1d1b251a2c5f05d1401aa89be792f9ed18b8f096/source/train/Model.py#L264 + def eval_full( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + atomic: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + mixed_type: bool = False, + **kwargs: dict, + ) -> Tuple[np.ndarray, ...]: + """Unsupported method.""" + raise RuntimeError( + "This model does not support eval_full method. Use eval instead." + ) diff --git a/deepmd/infer/deep_wfc.py b/deepmd/infer/deep_wfc.py index deed938e04..d92af28f5a 100644 --- a/deepmd/infer/deep_wfc.py +++ b/deepmd/infer/deep_wfc.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) from deepmd.infer.deep_tensor import ( - DeepTensor, + OldDeepTensor, ) -class DeepWFC(DeepTensor): +class DeepWFC(OldDeepTensor): """Deep WFC model. Parameters @@ -26,3 +31,22 @@ class DeepWFC(DeepTensor): @property def output_tensor_name(self) -> str: return "wfc" + + @property + def output_def(self) -> ModelOutputDef: + """Get the output definition of this model.""" + # no reduciable or differentiable output is defined + return ModelOutputDef( + FittingOutputDef( + [ + OutputVariableDef( + self.output_tensor_name, + shape=[-1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + atomic=True, + ), + ] + ) + ) diff --git a/source/tests/tf/test_get_potential.py b/source/tests/tf/test_get_potential.py index 47462a20a3..fb39d41d2e 100644 --- a/source/tests/tf/test_get_potential.py +++ b/source/tests/tf/test_get_potential.py @@ -1,8 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Test if `DeepPotential` facto function returns the right type of potential.""" +import tempfile import unittest +from deepmd.infer.deep_polar import ( + DeepGlobalPolar, +) +from deepmd.infer.deep_wfc import ( + DeepWFC, +) from deepmd.tf.infer import ( DeepDipole, DeepPolar, @@ -35,16 +42,19 @@ def setUp(self): str(self.work_dir / "deeppolar.pbtxt"), str(self.work_dir / "deep_polar.pb") ) - # TODO add model files for globalpolar and WFC - # convert_pbtxt_to_pb( - # str(self.work_dir / "deepglobalpolar.pbtxt"), - # str(self.work_dir / "deep_globalpolar.pb") - # ) + with open(self.work_dir / "deeppolar.pbtxt") as f: + deeppolar_pbtxt = f.read() - # convert_pbtxt_to_pb( - # str(self.work_dir / "deepwfc.pbtxt"), - # str(self.work_dir / "deep_wfc.pb") - # ) + # not an actual globalpolar and wfc model, but still good enough for testing factory + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(deeppolar_pbtxt.replace("polar", "global_polar")) + f.flush() + convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_globalpolar.pb")) + + with tempfile.NamedTemporaryFile(mode="w") as f: + f.write(deeppolar_pbtxt.replace("polar", "wfc")) + f.flush() + convert_pbtxt_to_pb(f.name, str(self.work_dir / "deep_wfc.pb")) def tearDown(self): for f in self.work_dir.glob("*.pb"): @@ -62,11 +72,10 @@ def test_factory(self): dp = DeepPotential(self.work_dir / "deep_pot.pb") self.assertIsInstance(dp, DeepPot, msg.format(DeepPot, type(dp))) - # TODO add model files for globalpolar and WFC - # dp = DeepPotential(self.work_dir / "deep_globalpolar.pb") - # self.assertIsInstance( - # dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp)) - # ) + dp = DeepPotential(self.work_dir / "deep_globalpolar.pb") + self.assertIsInstance( + dp, DeepGlobalPolar, msg.format(DeepGlobalPolar, type(dp)) + ) - # dp = DeepPotential(self.work_dir / "deep_wfc.pb") - # self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp))) + dp = DeepPotential(self.work_dir / "deep_wfc.pb") + self.assertIsInstance(dp, DeepWFC, msg.format(DeepWFC, type(dp)))