Skip to content

Commit

Permalink
Fix single-task training&data stat
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 004ebd6 commit 3812866
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
}
for item in merged
]
descrpt.compute_input_stats(merged_tmp)
descrpt.compute_input_stats(merged_tmp, path)

def serialize(self) -> dict:
"""Serialize the obj to dict."""
Expand Down
9 changes: 4 additions & 5 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
BaseDescriptor,
)
from deepmd.pt.model.task import (
Fitting,
BaseFitting,
)

from .dp_model import (
Expand Down Expand Up @@ -61,7 +61,7 @@ def get_zbl_model(model_params):
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)
fitting = BaseFitting(**fitting_net)
dp_model = DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
# pairtab
filepath = model_params["use_srtab"]
Expand Down Expand Up @@ -97,9 +97,8 @@ def get_model(model_params):
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)

model = EnergyModel(descriptor, fitting, type_map=model_params["type_map"])
fitting = BaseFitting(**fitting_net)
model = DPModel(descriptor, fitting, type_map=model_params["type_map"])
model.model_def_script = json.dumps(model_params)
return model

Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
# in DPAtomicModel (and other classes), but this requires the developer aware
# of it when developing it...
class BaseModel(make_base_model()):
def __init__(self):
def __init__(self, *args, **kwargs):
"""Construct a basic model for different tasks."""
super().__init__()
super().__init__(*args, **kwargs)

def compute_or_load_stat(
self,
Expand Down
1 change: 1 addition & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def save_numpy(self, arr: np.ndarray) -> None:
if self._name in self._keys:
del self.root[self._name]
self.root.create_dataset(self._name, data=arr)
self.root.flush()

def glob(self, pattern: str) -> List["DPPath"]:
"""Search path using the glob pattern.
Expand Down
8 changes: 2 additions & 6 deletions examples/water/dpa2/input_torch.json
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
{
"_comment": "that's all",
"model": {
"type_embedding": {
"neuron": [
8
],
"tebd_input_mode": "concat"
},
"type_map": [
"O",
"H"
],
"descriptor": {
"type": "dpa2",
"tebd_dim": 8,
"repinit_rcut": 9.0,
"repinit_rcut_smth": 8.0,
"repinit_nsel": 120,
Expand Down Expand Up @@ -74,6 +69,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa2",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 2 additions & 0 deletions examples/water/se_atten/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
50,
100
],
"tebd_dim": 8,
"axis_neuron": 16,
"attn": 128,
"attn_layer": 2,
Expand Down Expand Up @@ -59,6 +60,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa1",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
1 change: 1 addition & 0 deletions examples/water/se_e2_a/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./se_e2_a",
"training_data": {
"systems": [
"../data/data_0",
Expand Down

0 comments on commit 3812866

Please sign in to comment.