From 42724ce17bcb367862715d8978374ef1f377009f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 17 May 2024 02:22:04 -0400 Subject: [PATCH] chore: improve type anotations in deepmd.infer (#3792) Fix several incorrect type anotations. ## Summary by CodeRabbit - **New Features** - Enhanced flexibility in function and method parameters, allowing for more versatile use cases. - **Improvements** - Streamlined type annotations for improved code maintainability and readability. - Updated import statements for better module organization and efficiency. - **Bug Fixes** - Corrected parameter types to ensure proper handling of optional and varied input types. These changes aim to improve the overall usability and robustness of the application, making it more adaptable to different scenarios. Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/infer/deep_eval.py | 8 +++--- deepmd/entrypoints/test.py | 9 +----- deepmd/infer/deep_dos.py | 3 +- deepmd/infer/deep_eval.py | 23 +++++++-------- deepmd/infer/deep_polar.py | 2 +- deepmd/infer/deep_pot.py | 47 +++++++++++++++++++++++++++++-- deepmd/infer/model_devi.py | 14 +++++++-- deepmd/pt/infer/deep_eval.py | 11 ++++---- deepmd/tf/infer/deep_eval.py | 11 ++++---- deepmd/tf/infer/deep_tensor.py | 4 +-- 10 files changed, 89 insertions(+), 43 deletions(-) diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 1db5d539cf..2a006f1bd6 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -76,10 +76,10 @@ def __init__( self, model_file: str, output_def: ModelOutputDef, - *args: List[Any], + *args: Any, auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ): self.output_def = output_def self.model_path = model_file @@ -161,12 +161,12 @@ def get_ntypes_spin(self): def eval( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Dict[str, np.ndarray]: """Evaluate the energy, force and virial by using this DP. diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 5f4e758f0b..1d2248aa0d 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -46,14 +46,7 @@ ) if TYPE_CHECKING: - from deepmd.tf.infer import ( - DeepDipole, - DeepDOS, - DeepPolar, - DeepPot, - DeepWFC, - ) - from deepmd.tf.infer.deep_tensor import ( + from deepmd.infer.deep_tensor import ( DeepTensor, ) diff --git a/deepmd/infer/deep_dos.py b/deepmd/infer/deep_dos.py index 7823f02999..c8d55560b6 100644 --- a/deepmd/infer/deep_dos.py +++ b/deepmd/infer/deep_dos.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, - Dict, List, Optional, Tuple, @@ -70,7 +69,7 @@ def eval( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, mixed_type: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Tuple[np.ndarray, ...]: """Evaluate energy, force, and virial. If atomic is True, also return atomic energy and atomic virial. diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index aae2082e13..5a00ba616d 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -11,6 +11,7 @@ List, Optional, Tuple, + Type, Union, ) @@ -82,10 +83,10 @@ def __init__( self, model_file: str, output_def: ModelOutputDef, - *args: List[Any], + *args: Any, auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: pass @@ -99,12 +100,12 @@ def __new__(cls, model_file: str, *args, **kwargs): def eval( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Dict[str, np.ndarray]: """Evaluate the energy, force and virial by using this DP. @@ -166,13 +167,13 @@ def get_dim_aparam(self) -> int: def eval_descriptor( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, mixed_type: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> np.ndarray: """Evaluate descriptors by using this DP. @@ -246,11 +247,11 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool: # assume mixed_types if there are virtual types, even when # the atom types of all frames are the same return False - return np.all(np.equal(atom_types, atom_types[0])) + return np.all(np.equal(atom_types, atom_types[0])).item() @property @abstractmethod - def model_type(self) -> "DeepEval": + def model_type(self) -> Type["DeepEval"]: """The the evaluator of the model type.""" @abstractmethod @@ -316,10 +317,10 @@ def __new__(cls, model_file: str, *args, **kwargs): def __init__( self, model_file: str, - *args: List[Any], + *args: Any, auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: self.deep_eval = DeepEvalBackend( model_file, @@ -387,7 +388,7 @@ def eval_descriptor( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, mixed_type: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> np.ndarray: """Evaluate descriptors by using this DP. diff --git a/deepmd/infer/deep_polar.py b/deepmd/infer/deep_polar.py index f857619871..c2089b278d 100644 --- a/deepmd/infer/deep_polar.py +++ b/deepmd/infer/deep_polar.py @@ -50,7 +50,7 @@ def eval( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, mixed_type: bool = False, - **kwargs: dict, + **kwargs, ) -> np.ndarray: """Evaluate the model. diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index bc0bfc9599..401698bb14 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, - Dict, List, + Literal, Optional, Tuple, Union, + overload, ) import numpy as np @@ -89,6 +90,48 @@ def output_def_mag(self) -> ModelOutputDef: ) ) + @overload + def eval( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: Union[List[int], np.ndarray], + atomic: Literal[True], + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], + mixed_type: bool, + **kwargs: Any, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + pass + + @overload + def eval( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: Union[List[int], np.ndarray], + atomic: Literal[False], + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], + mixed_type: bool, + **kwargs: Any, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + pass + + @overload + def eval( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: Union[List[int], np.ndarray], + atomic: bool, + fparam: Optional[np.ndarray], + aparam: Optional[np.ndarray], + mixed_type: bool, + **kwargs: Any, + ) -> Tuple[np.ndarray, ...]: + pass + def eval( self, coords: np.ndarray, @@ -98,7 +141,7 @@ def eval( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, mixed_type: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Tuple[np.ndarray, ...]: """Evaluate energy, force, and virial. If atomic is True, also return atomic energy and atomic virial. diff --git a/deepmd/infer/model_devi.py b/deepmd/infer/model_devi.py index 477acf0282..61025bcb70 100644 --- a/deepmd/infer/model_devi.py +++ b/deepmd/infer/model_devi.py @@ -28,7 +28,7 @@ def calc_model_devi_f( fs: np.ndarray, real_f: Optional[np.ndarray] = None, relative: Optional[float] = None, - atomic: Literal[False] = False, + atomic: Literal[False] = ..., ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ... @@ -37,11 +37,19 @@ def calc_model_devi_f( fs: np.ndarray, real_f: Optional[np.ndarray] = None, relative: Optional[float] = None, - *, - atomic: Literal[True], + atomic: Literal[True] = ..., ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ... +@overload +def calc_model_devi_f( + fs: np.ndarray, + real_f: Optional[np.ndarray] = None, + relative: Optional[float] = None, + atomic: bool = False, +) -> Tuple[np.ndarray, ...]: ... + + def calc_model_devi_f( fs: np.ndarray, real_f: Optional[np.ndarray] = None, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 8a3a61400d..0e3dd292cb 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -7,6 +7,7 @@ List, Optional, Tuple, + Type, Union, ) @@ -87,11 +88,11 @@ def __init__( self, model_file: str, output_def: ModelOutputDef, - *args: List[Any], + *args: Any, auto_batch_size: Union[bool, int, AutoBatchSize] = True, neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, head: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ): self.output_def = output_def self.model_path = model_file @@ -165,7 +166,7 @@ def get_dim_aparam(self) -> int: return self.dp.model["Default"].get_dim_aparam() @property - def model_type(self) -> "DeepEvalWrapper": + def model_type(self) -> Type["DeepEvalWrapper"]: """The the evaluator of the model type.""" model_output_type = self.dp.model["Default"].model_output_type() if "energy" in model_output_type: @@ -211,12 +212,12 @@ def get_has_spin(self): def eval( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Dict[str, np.ndarray]: """Evaluate the energy, force and virial by using this DP. diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index ccbd44cf97..825ac6704a 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -10,6 +10,7 @@ List, Optional, Tuple, + Type, Union, ) @@ -262,7 +263,7 @@ def _init_attr(self): @property @lru_cache(maxsize=None) - def model_type(self) -> "DeepEvalWrapper": + def model_type(self) -> Type["DeepEvalWrapper"]: """Get type of model. :type:str @@ -693,13 +694,13 @@ def _get_natoms_and_nframes( def eval( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Dict[str, np.ndarray]: """Evaluate the energy, force and virial by using this DP. @@ -1023,7 +1024,7 @@ def _get_output_shape(self, odef, nframes, natoms): def eval_descriptor( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, @@ -1080,7 +1081,7 @@ def eval_descriptor( def _eval_descriptor_inner( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: np.ndarray, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, diff --git a/deepmd/tf/infer/deep_tensor.py b/deepmd/tf/infer/deep_tensor.py index 59fdab7cd1..c3ca22847e 100644 --- a/deepmd/tf/infer/deep_tensor.py +++ b/deepmd/tf/infer/deep_tensor.py @@ -146,7 +146,7 @@ def get_dim_aparam(self) -> int: def eval( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: List[int], atomic: bool = True, fparam: Optional[np.ndarray] = None, @@ -276,7 +276,7 @@ def eval( def eval_full( self, coords: np.ndarray, - cells: np.ndarray, + cells: Optional[np.ndarray], atom_types: List[int], atomic: bool = False, fparam: Optional[np.array] = None,