Skip to content

Commit

Permalink
Remove data stat from model init
Browse files (browse the repository at this point in the history)
  • Loading branch information
iProzd committed Feb 5, 2024
1 parent 701b913 commit 3386133
Show file tree
Hide file tree
Showing 31 changed files with 225 additions and 316 deletions.
16 changes: 16 additions & 0 deletions deepmd/dpmodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
PRECISION_DICT,
NativeOP,
)
from .descriptor import (
DescrptSeA,
)
from .model import (
DPAtomicModel,
DPModel,
Expand All @@ -17,13 +20,26 @@
get_reduce_name,
model_check_output,
)
from .utils import (
EmbeddingNet,
EnvMat,
FittingNet,
NativeLayer,
NativeNet,
)

__all__ = [
"DPModel",
"DPAtomicModel",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"NativeOP",
"EnvMat",
"NativeLayer",
"NativeNet",
"EmbeddingNet",
"FittingNet",
"DescrptSeA",
"ModelOutputDef",
"FittingOutputDef",
"OutputVariableDef",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
assert not self.multi_task, "multitask mode currently not supported!"
self.type_split = self.input_param["descriptor"]["type"] in ["se_e2_a"]
self.type_map = self.input_param["type_map"]
self.dp = ModelWrapper(get_model(self.input_param, None).to(DEVICE))
self.dp = ModelWrapper(get_model(self.input_param).to(DEVICE))
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model["Default"].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model["Default"].descriptor.get_sel())
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,19 @@ class SomeDescript(Descriptor):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_stat_name(config)

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_stat_key(config)

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return f'stat_file_dpa1_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,15 +316,22 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return (
f'stat_file_dpa2_repinit_rcut{config["repinit_rcut"]:.2f}_smth{config["repinit_rcut_smth"]:.2f}_sel{config["repinit_nsel"]}'
f'_repformer_rcut{config["repformer_rcut"]:.2f}_smth{config["repformer_rcut_smth"]:.2f}_sel{config["repformer_nsel"]}.npz'
)

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return {
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,19 @@ def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2):

@classmethod
def get_stat_name(cls, config):
"""Get the name for the statistic file of the descriptor."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return f'stat_file_sea_rcut{config["rcut"]:.2f}_smth{config["rcut_smth"]:.2f}_sel{config["sel"]}.npz'

@classmethod
def get_data_stat_key(cls, config):
"""Get the keys for the data statistic of the descriptor."""
return ["sumr", "suma", "sumn", "sumr2", "suma2"]

@classmethod
def get_data_process_key(cls, config):
"""Get the keys for the data preprocess."""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return {"sel": config["sel"], "rcut": config["rcut"]}
Expand Down
13 changes: 2 additions & 11 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


def get_model(model_params, sampled=None):
def get_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
# descriptor
Expand All @@ -35,16 +35,7 @@ def get_model(model_params, sampled=None):
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)

return EnergyModel(
descriptor,
fitting,
type_map=model_params["type_map"],
type_embedding=model_params.get("type_embedding", None),
resuming=model_params.get("resuming", False),
stat_file_dir=model_params.get("stat_file_dir", None),
stat_file_path=model_params.get("stat_file_path", None),
sampled=sampled,
)
return EnergyModel(descriptor, fitting, type_map=model_params["type_map"])


__all__ = [
Expand Down
43 changes: 2 additions & 41 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,9 @@ class DPAtomicModel(BaseModel, BaseAtomicModel):
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
type_embedding
Type embedding net
resuming
Whether to resume/fine-tune from checkpoint or not.
stat_file_dir
The directory to the statistic files.
stat_file_path
The path to the statistic files.
sampled
Sampled frames to compute the statistics.
"""

# I am enough with the shit interface!
def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]],
type_embedding: Optional[dict] = None,
resuming: bool = False,
stat_file_dir=None,
stat_file_path=None,
sampled=None,
**kwargs,
):
def __init__(self, descriptor, fitting, type_map: Optional[List[str]]):
super().__init__()
ntypes = len(type_map)
self.type_map = type_map
Expand All @@ -72,17 +50,6 @@ def __init__(
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
# Statistics
fitting_net = None # TODO: hack!!! not sure if it is correct.
self.compute_or_load_stat(
fitting_net,
ntypes,
resuming=resuming,
type_map=type_map,
stat_file_dir=stat_file_dir,
stat_file_path=stat_file_path,
sampled=sampled,
)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -122,13 +89,7 @@ def deserialize(cls, data) -> "DPAtomicModel":
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
# TODO: dirty hack to provide type_map and avoid data stat!!!
obj = cls(
descriptor_obj,
fitting_obj,
type_map=data["type_map"],
resuming=True,
)
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

def forward_atomic(
Expand Down
Loading

0 comments on commit 3386133

Please sign in to comment.