Skip to content

Commit

Permalink
Update se_r.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 463f9fb commit e8575af
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,34 @@ def mixed_types(self) -> bool:
"""
return False

def share_params(self, base_class, shared_level, resume=False):
assert (
self.__class__ == base_class.__class__
), "Only descriptors of the same type can share params!"
# For SeR descriptors, the user-defined share-level
# shared_level: 0
if shared_level == 0:
# link buffers
if hasattr(self, "mean") and not resume:
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
# the following will successfully link all the params except buffers
for item in self._modules:
self._modules[item] = base_class._modules[item]
# Other shared levels
else:
raise NotImplementedError

def compute_input_stats(
self, merged: Union[Callable, List[dict]], path: Optional[DPPath] = None
):
Expand Down

0 comments on commit e8575af

Please sign in to comment.