diff --git a/deepmd/pt/model/model/make_hessian_model.py b/deepmd/pt/model/model/make_hessian_model.py index f8410a3e37..4269976095 100644 --- a/deepmd/pt/model/model/make_hessian_model.py +++ b/deepmd/pt/model/model/make_hessian_model.py @@ -115,8 +115,8 @@ def forward_common( def _cal_hessian_all( self, - coord, - atype, + coord: torch.Tensor, + atype: torch.Tensor, box: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, @@ -184,7 +184,7 @@ def __init__( self, obj: CM, ci: int, - atype, + atype: torch.Tensor, box: Optional[torch.Tensor], fparam: Optional[torch.Tensor], aparam: Optional[torch.Tensor],