Skip to content

Commit

Permalink
Fix EnergyFittingNetDirect
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 08e18fe commit ae27607
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
EnergyFittingNetDirect,
)
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
Expand All @@ -36,7 +37,9 @@ def __new__(cls, descriptor, fitting, *args, **kwargs):
# according to the fitting network to decide the type of the model
if cls is DPModel:
# map fitting to model
if isinstance(fitting, EnergyFittingNet):
if isinstance(fitting, EnergyFittingNet) or isinstance(
fitting, EnergyFittingNetDirect
):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel
Expand Down

0 comments on commit ae27607

Please sign in to comment.