Skip to content

Commit

Permalink
Resolve conversations
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 8, 2024
1 parent d2b6f41 commit 993ee55
Show file tree
Hide file tree
Showing 22 changed files with 196 additions and 178 deletions.
8 changes: 3 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,13 @@ def distinguish_types(self) -> bool:
"""
pass

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for descriptor elements."""
pass
raise NotImplementedError

@abstractmethod
def init_desc_stat(self, stat_dict):
def init_desc_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
pass
raise NotImplementedError

@abstractmethod
def fwd(
Expand Down
8 changes: 3 additions & 5 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,13 @@ def fwd(
"""Calculate fitting."""
pass

@abstractmethod
def compute_output_stats(self, merged):
"""Update the output bias for fitting net."""
pass
raise NotImplementedError

@abstractmethod
def init_fitting_stat(self, result_dict):
def init_fitting_stat(self, **kwargs):
"""Initialize the model bias by the statistics."""
pass
raise NotImplementedError

@abstractmethod
def serialize(self) -> dict:
Expand Down
26 changes: 0 additions & 26 deletions deepmd/dpmodel/model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,32 +84,6 @@ def serialize(self) -> dict:
def deserialize(cls):
pass

def compute_or_load_stat(
self,
type_map=None,
sampled=None,
stat_file_path=None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.
Parameters
----------
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics files.
"""
pass

def do_grad(
self,
var_name: Optional[str] = None,
Expand Down
36 changes: 7 additions & 29 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
)
from deepmd.pt.utils.stat import (
make_stat_input,
process_stat_path,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

Expand Down Expand Up @@ -128,37 +129,14 @@ def prepare_trainer_input_single(

# stat files
hybrid_descrpt = model_params_single["descriptor"]["type"] == "hybrid"
has_stat_file_path = True
if not hybrid_descrpt:
model_params_single["stat_file_dir"] = data_dict_single.get(
"stat_file_dir", f"stat_files{suffix}"
has_stat_file_path = process_stat_path(
data_dict_single.get("stat_file", None),
data_dict_single.get("stat_file_dir", f"stat_files{suffix}"),
model_params_single,
Descriptor,
Fitting,
)
stat_file = data_dict_single.get("stat_file", None)
if stat_file is None:
stat_file = {}
if "descriptor" in model_params_single:
default_stat_file_name_descrpt = Descriptor.get_stat_name(
model_params_single["descriptor"],
len(model_params_single["type_map"]),
)
stat_file["descriptor"] = default_stat_file_name_descrpt
if "fitting_net" in model_params_single:
default_stat_file_name_fitting = Fitting.get_stat_name(
model_params_single["fitting_net"],
len(model_params_single["type_map"]),
)
stat_file["fitting_net"] = default_stat_file_name_fitting
model_params_single["stat_file_path"] = {
key: os.path.join(model_params_single["stat_file_dir"], stat_file[key])
for key in stat_file
}

has_stat_file_path_list = [
os.path.exists(model_params_single["stat_file_path"][key])
for key in stat_file
]
if False in has_stat_file_path_list:
has_stat_file_path = False
else: ### TODO hybrid descriptor not implemented
raise NotImplementedError(
"data stat for hybrid descriptor is not implemented!"
Expand Down
35 changes: 18 additions & 17 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,17 @@ class SomeDescript(Descriptor):
return Descriptor.__plugins.register(key)

@classmethod
def get_stat_name(cls, config, ntypes):
def get_stat_name(cls, ntypes, type_name, **kwargs):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
if cls is not Descriptor:
raise NotImplementedError("get_stat_name is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config, ntypes)
descrpt_type = type_name
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(
ntypes, type_name, **kwargs
)

@classmethod
def get_data_process_key(cls, config):
Expand All @@ -82,21 +84,22 @@ def get_data_process_key(cls, config):
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

def get_data_stat_key(self):
@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
"""
raise NotImplementedError("get_data_stat_key is not implemented!")
raise NotImplementedError("data_stat_key is not implemented!")

def set_stats(
def compute_or_load_stat(
self,
type_map: List[str],
sampled,
sampled=None,
stat_file_path: Optional[Union[str, List[str]]] = None,
):
"""
Set the statistics parameters for the descriptor.
Compute or load the statistics parameters of the descriptor.
Calculate and save the mean and standard deviation of the descriptor to `stat_file_path`
if `sampled` is not None, otherwise load them from `stat_file_path`.
Expand All @@ -111,7 +114,7 @@ def set_stats(
The path to the statistics files.
"""
# TODO support hybrid descriptor
descrpt_stat_key = self.get_data_stat_key()
descrpt_stat_key = self.data_stat_key
if sampled is not None: # compute the statistics results
tmp_dict = self.compute_input_stats(sampled)
result_dict = {key: tmp_dict[key] for key in descrpt_stat_key}
Expand All @@ -121,7 +124,7 @@ def set_stats(
else: # load the statistics results
assert stat_file_path is not None, "No stat file to load!"
result_dict = self.load_stats(type_map, stat_file_path)
self.init_desc_stat(result_dict)
self.init_desc_stat(**result_dict)

def save_stats(self, result_dict, stat_file_path: Union[str, List[str]]):
"""
Expand Down Expand Up @@ -159,7 +162,7 @@ def load_stats(self, type_map, stat_file_path: Union[str, List[str]]):
result_dict
The dictionary of statistics results.
"""
descrpt_stat_key = self.get_data_stat_key()
descrpt_stat_key = self.data_stat_key
target_type_map = type_map
if not isinstance(stat_file_path, list):
log.info(f"Loading stat file from {stat_file_path}")
Expand Down Expand Up @@ -272,15 +275,13 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension."""
pass

@abstractmethod
def compute_input_stats(self, merged):
"""Update mean and stddev for DescriptorBlock elements."""
pass
raise NotImplementedError

@abstractmethod
def init_desc_stat(self, stat_dict):
"""Initialize the model bias by the statistics."""
pass
def init_desc_stat(self, **kwargs):
"""Initialize mean and stddev by the statistics."""
raise NotImplementedError

def share_params(self, base_class, shared_level, resume=False):
assert (
Expand Down
18 changes: 12 additions & 6 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,23 @@ def dim_emb(self):
def compute_input_stats(self, merged):
return self.se_atten.compute_input_stats(merged)

def init_desc_stat(self, stat_dict):
self.se_atten.init_desc_stat(stat_dict)
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
self.se_atten.init_desc_stat(sumr, suma, sumn, sumr2, suma2)

@classmethod
def get_stat_name(cls, config, ntypes):
def get_stat_name(
cls, ntypes, type_name, rcut=None, rcut_smth=None, sel=None, **kwargs
):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
descrpt_type = config["type"]
descrpt_type = type_name
assert descrpt_type in ["dpa1", "se_atten"]
return f'stat_file_descrpt_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}_ntypes{ntypes}.npz'
return f"stat_file_descrpt_dpa1_rcut{rcut:.2f}_smth{rcut_smth:.2f}_sel{sel}_ntypes{ntypes}.npz"

@classmethod
def get_data_process_key(cls, config):
Expand All @@ -149,7 +154,8 @@ def get_data_process_key(cls, config):
assert descrpt_type in ["dpa1", "se_atten"]
return {"sel": config["sel"], "rcut": config["rcut"]}

def get_data_stat_key(self):
@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
Expand Down
47 changes: 33 additions & 14 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,10 @@ def compute_input_stats(self, merged):
"suma2": suma2,
}

def init_desc_stat(self, stat_dict):
for key in ["sumr", "suma", "sumn", "sumr2", "suma2"]:
assert key in stat_dict, f"Statistics {key} not found in the dictionary!"
sumr = stat_dict["sumr"]
suma = stat_dict["suma"]
sumn = stat_dict["sumn"]
sumr2 = stat_dict["sumr2"]
suma2 = stat_dict["suma2"]
def init_desc_stat(
self, sumr=None, suma=None, sumn=None, sumr2=None, suma2=None, **kwargs
):
assert True not in [x is None for x in [sumr, suma, sumn, sumr2, suma2]]
for ii, descrpt in enumerate([self.repinit, self.repformers]):
stat_dict_ii = {
"sumr": sumr[ii],
Expand All @@ -329,19 +325,41 @@ def init_desc_stat(self, stat_dict):
"sumr2": sumr2[ii],
"suma2": suma2[ii],
}
descrpt.init_desc_stat(stat_dict_ii)
descrpt.init_desc_stat(**stat_dict_ii)

@classmethod
def get_stat_name(cls, config, ntypes):
def get_stat_name(
cls,
ntypes,
type_name,
repinit_rcut=None,
repinit_rcut_smth=None,
repinit_nsel=None,
repformer_rcut=None,
repformer_rcut_smth=None,
repformer_nsel=None,
**kwargs,
):
"""
Get the name for the statistic file of the descriptor.
Usually use the combination of descriptor name, rcut, rcut_smth and sel as the statistic file name.
"""
descrpt_type = config["type"]
descrpt_type = type_name
assert descrpt_type in ["dpa2"]
assert True not in [
x is None
for x in [
repinit_rcut,
repinit_rcut_smth,
repinit_nsel,
repformer_rcut,
repformer_rcut_smth,
repformer_nsel,
]
]
return (
f'stat_file_descrpt_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}'
f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}_ntypes{ntypes}.npz'
f"stat_file_descrpt_dpa2_repinit_rcut{repinit_rcut:.2f}_smth{repinit_rcut_smth:.2f}_sel{repinit_nsel}"
f"_repformer_rcut{repformer_rcut:.2f}_smth{repformer_rcut_smth:.2f}_sel{repformer_nsel}_ntypes{ntypes}.npz"
)

@classmethod
Expand All @@ -358,7 +376,8 @@ def get_data_process_key(cls, config):
"rcut": [config["repinit_rcut"], config["repformer_rcut"]],
}

def get_data_stat_key(self):
@property
def data_stat_key(self):
"""
Get the keys for the data statistic of the descriptor.
Return a list of statistic names needed, such as "sumr", "suma" or "sumn".
Expand Down
9 changes: 1 addition & 8 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,7 @@ def compute_input_stats(self, merged):
"suma2": suma2,
}

def init_desc_stat(self, stat_dict):
for key in ["sumr", "suma", "sumn", "sumr2", "suma2"]:
assert key in stat_dict, f"Statistics {key} not found in the dictionary!"
sumr = stat_dict["sumr"]
suma = stat_dict["suma"]
sumn = stat_dict["sumn"]
sumr2 = stat_dict["sumr2"]
suma2 = stat_dict["suma2"]
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):
all_davg = []
all_dstd = []
for type_i in range(self.ntypes):
Expand Down
Loading

0 comments on commit 993ee55

Please sign in to comment.