Skip to content

Commit

Permalink
support combining frozen models into a pairwise DPRc model (deepmodel…
Browse files Browse the repository at this point in the history
…ing#2902)

The "frozen" model has been previously supported in deepmodeling#2781, so this PR
allows it to be used in the pairwise DPRc model.
Allow QM model type maps to be just the first $N$ type maps of the whole
models, as the QM model doesn't need MM types.

One can load two separated models into one model or load the existing QM
model and train the rest of the model.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 7, 2023
1 parent 47d985d commit d0edb3a
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions deepmd/model/pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
TypeEmbedNet,
)

from .ener import (
EnerModel,
)


class PairwiseDPRc(Model):
"""Pairwise Deep Potential - Range Correction."""
Expand Down Expand Up @@ -87,13 +83,13 @@ def __init__(
padding=True,
)

self.qm_model = EnerModel(
self.qm_model = Model(
**qm_model,
type_map=type_map,
type_embedding=self.typeebd,
compress=compress,
)
self.qmmm_model = EnerModel(
self.qmmm_model = Model(
**qmmm_model,
type_map=type_map,
type_embedding=self.typeebd,
Expand Down Expand Up @@ -187,6 +183,14 @@ def build(

mesh_mixed_type = make_default_mesh(False, True)

# allow loading a frozen QM model that has only QM types
# Note: here we don't map the type between models, so
# the type of the frozen model must be the same as
# the first Ntypes of the current model
if self.get_ntypes() > self.qm_model.get_ntypes():
natoms_qm = tf.slice(natoms_qm, [0], [self.qm_model.get_ntypes() + 2])
assert self.get_ntypes() == self.qmmm_model.get_ntypes()

qm_dict = self.qm_model.build(
coord_qm,
atype_qm,
Expand Down Expand Up @@ -301,7 +305,7 @@ def get_rcut(self):
return max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())

def get_ntypes(self) -> int:
return self.qm_model.get_ntypes()
return self.ntypes

def data_stat(self, data):
self.qm_model.data_stat(data)
Expand Down

0 comments on commit d0edb3a

Please sign in to comment.