Skip to content

Commit

Permalink
get the receptive field
Browse files Browse the repository at this point in the history
  • Loading branch information
theAfish committed Nov 27, 2024
1 parent e695a91 commit d30f0d1
Show file tree
Hide file tree
Showing 11 changed files with 870 additions and 7 deletions.
50 changes: 50 additions & 0 deletions deepmd/pt/model/atomic_model/post_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
)

import torch

from deepmd.pt.model.task.ener import (
EnergyFittingNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)


class DPPostAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, post_net, type_map, **kwargs):
assert isinstance(fitting, EnergyFittingNet)
super().__init__(descriptor, fitting, type_map, **kwargs)
self.post_net = post_net

def apply_out_stat(
self,
ret: Dict[str, torch.Tensor],
atype: torch.Tensor,
):
"""Apply the stat to each atomic output.
This function defines how the bias is applied to the atomic output of the model.
Parameters
----------
ret
The returned dict by the forward_atomic method
atype
The atom types. nf x nloc
"""
if self.fitting_net.get_bias_method() == "normal":
out_bias, out_std = self._fetch_out_stat(self.bias_keys)
for kk in self.bias_keys:
# nf x nloc x odims, out_bias: ntypes x odims
ret[kk] = ret[kk] + out_bias[kk][atype]
return ret
elif self.fitting_net.get_bias_method() == "no_bias":
return ret
else:
raise NotImplementedError(
"Only 'normal' and 'no_bias' is supported for parameter 'bias_method'."
)
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def forward(
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
if "debug" in model_ret:
model_predict["debug"] = model_ret["debug"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
Expand Down
133 changes: 133 additions & 0 deletions deepmd/pt/model/model/ener_rec_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
DPEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPEnergyModel_ = make_model(DPEnergyAtomicModel)


@BaseModel.register("ener_rec")
class EnergyModel(DPModelCommon, DPEnergyModel_):
model_type = "ener_rec"

def __init__(
self,
*args,
**kwargs,
) -> None:
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)

def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_energy": out_def_data["energy"],
"energy": out_def_data["energy_redu"],
}
if self.do_grad_r("energy"):
output_def["force"] = out_def_data["energy_derv_r"]
output_def["force"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = out_def_data["energy_derv_c_redu"]
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = out_def_data["energy_derv_c"]
output_def["atom_virial"].squeeze(-3)
if "mask" in out_def_data:
output_def["mask"] = out_def_data["mask"]
return output_def

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
else:
model_predict = model_ret
model_predict["updated_coord"] += coord
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
if self.get_fitting_net() is not None:
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret[
"energy_derv_c"
].squeeze(-3)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
else:
model_predict = model_ret
return model_predict
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,8 @@ def output_type_cast(
)
pp = self.precision_dict[input_prec]
odef = self.model_output_def()
if "debug" in model_ret.keys():
model_ret["debug"] = model_ret["debug"].to(pp)
for kk in odef.keys():
if kk not in model_ret.keys():
# do not return energy_derv_c if not do_atomic_virial
Expand Down
113 changes: 113 additions & 0 deletions deepmd/pt/model/model/post_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)
from typing import (
Dict,
Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
DPPostAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPPostModel_ = make_model(DPPostAtomicModel)


@BaseModel.register("post")
class PostModel(DPModelCommon, DPPostModel_):
model_type = "post"

def __init__(
self,
*args,
**kwargs,
):
DPModelCommon.__init__(self)
DPPostModel_.__init__(self, *args, **kwargs)

def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_property": deepcopy(out_def_data["property"]),
"property": deepcopy(out_def_data["property_redu"]),
}
if "mask" in out_def_data:
output_def["mask"] = deepcopy(out_def_data["mask"])
return output_def

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_property"] = model_ret["property"]
model_predict["property"] = model_ret["property_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

@torch.jit.export
def get_task_dim(self) -> int:
"""Get the output dimension of PropertyFittingNet."""
return self.get_fitting_net().dim_out

@torch.jit.export
def get_intensive(self) -> bool:
"""Get whether the property is intensive."""
return self.model_output_def()["property"].intensive

@torch.jit.export
def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
model_predict = {}
model_predict["atom_property"] = model_ret["property"]
model_predict["property"] = model_ret["property_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict
Loading

0 comments on commit d30f0d1

Please sign in to comment.