Commit 48846b9: From v1json (#103)
* add example for loading a model from v1 json, using AlAs.

* fix:
  1. change default of push from None to False
  2. fix bug in checking the model_options between the model and the input jdata.
  3. fix bug in loading a v1 json model.

* update nnsk.py and test_sktb.py

* add test for loading from v1 json.

* update config_sk.py and config_skenv.py

* update deeptb.py and nnsk.py

* update .gitignore and plot_from_v1.ipynb

* update example of AlAs

* chore: create input.json

* add example for loading a strain model from v1 json, using the silicon example

* add test for strain v1 json model

* fix: update train.py to make it able to load a v1 json model for further training.

* test: update test_sktb.py and test_trainer.py

* update train.py
QG-phy authored Mar 27, 2024
1 parent b1e7fb6 commit 48846b9
Showing 34 changed files with 1,587 additions and 67 deletions.
3 changes: 1 addition & 2 deletions .gitignore

@@ -8,10 +8,9 @@ dptb/tests/**/*.traj
dptb/tests/**/out*/*
examples/_*
*.dat
*.vasp
*log*
dptb/tests/data/**/out*/config_*.json

bandstructure.npy
dptb/tests/data/hBN/data/set.0/xdat2.traj
dptb/tests/data/postrun/run_config.json
dptb/tests/data/test_all/test_config.json
91 changes: 48 additions & 43 deletions dptb/entrypoints/train.py

@@ -98,54 +98,59 @@ def train(
     # since here we want to output jdata as a config file to inform the user what model options are used, we need to update the jdata
 
     torch.set_default_dtype(getattr(torch, jdata["common_options"]["dtype"]))
 
     if restart or init_model:
 
         f = restart if restart else init_model
-        f = torch.load(f)
-
-        if jdata.get("model_options", None) is None:
-            jdata["model_options"] = f["config"]["model_options"]
-
-        # update basis
-        basis = f["config"]["common_options"]["basis"]
-        # nnsk
-        if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None:
-            for asym, orb in jdata["common_options"]["basis"].items():
-                assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
-                if orb != basis[asym]:
-                    log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
-            # we have the orbitals in jdata basis correct, now we need to make sure all atoms in basis are also contained in jdata basis
-            for asym, orb in basis.items():
-                if asym not in jdata["common_options"]["basis"].keys():
-                    jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset
-        else: # not nnsk
-            for asym, orb in jdata["common_options"]["basis"].items():
-                assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
-                assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"
-
-        jdata["common_options"]["basis"] = basis
-
-        # update model options and train_options
-        if restart:
-            if jdata.get("train_options", None) is not None:
-                for obj in Trainer.object_keys:
-                    if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj):
-                        log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint")
-                        jdata["train_options"][obj] = f["config"]["train_options"][obj]
-            else:
-                jdata["train_options"] = f["config"]["train_options"]
-
-            if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]:
-                log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint")
-                jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options
+        if f.split(".")[-1] == "json":
+            assert not restart, "json model can not be used as restart! should be a checkpoint file"
         else:
-            # init model mode, allow model_options change
-            if jdata.get("train_options", None) is None:
-                jdata["train_options"] = f["config"]["train_options"]
-            if jdata.get("model_options") is None:
+            f = torch.load(f)
+
+            if jdata.get("model_options", None) is None:
                 jdata["model_options"] = f["config"]["model_options"]
-        del f
+
+            # update basis
+            basis = f["config"]["common_options"]["basis"]
+            # nnsk
+            if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None:
+                for asym, orb in jdata["common_options"]["basis"].items():
+                    assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
+                    if orb != basis[asym]:
+                        log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
+                # we have the orbitals in jdata basis correct, now we need to make sure all atoms in basis are also contained in jdata basis
+                for asym, orb in basis.items():
+                    if asym not in jdata["common_options"]["basis"].keys():
+                        jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset
+            else: # not nnsk
+                for asym, orb in jdata["common_options"]["basis"].items():
+                    assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
+                    assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"
+
+            jdata["common_options"]["basis"] = basis
+
+            # update model options and train_options
+            if restart:
+                if jdata.get("train_options", None) is not None:
+                    for obj in Trainer.object_keys:
+                        if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj):
+                            log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint")
+                            jdata["train_options"][obj] = f["config"]["train_options"][obj]
+                else:
+                    jdata["train_options"] = f["config"]["train_options"]
+
+                if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]:
+                    log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint")
+                    jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options
+            else:
+                # init model mode, allow model_options change
+                if jdata.get("train_options", None) is None:
+                    jdata["train_options"] = f["config"]["train_options"]
+                if jdata.get("model_options") is None:
+                    jdata["model_options"] = f["config"]["model_options"]
+            del f
     else:
        j_must_have(jdata, "model_options")
        j_must_have(jdata, "train_options")
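
The rule this hunk encodes, sketched as a standalone helper (hypothetical, not part of the commit): a v1 json model carries parameters only, so it can seed a fresh run via init_model but cannot resume one, while a .pth checkpoint also carries the config block (model_options, train_options) needed for restart.

import torch

def load_reference(path: str, restart: bool):
    # json: parameters only, no optimizer/scheduler state, so init only
    if path.split(".")[-1] == "json":
        assert not restart, "json model can not be used as restart! should be a checkpoint file"
        return None  # parameters are read later, e.g. by NNSK.from_reference
    # .pth: full checkpoint carrying f["config"]["model_options"] etc.
    return torch.load(path)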
31 changes: 28 additions & 3 deletions dptb/nn/build.py

@@ -3,6 +3,7 @@
 from dptb.nn.nnsk import NNSK
 import torch
 from dptb.utils.tools import j_must_have
+import copy
 
 log = logging.getLogger(__name__)

@@ -135,7 +136,31 @@ def build_model(run_options, model_options: dict={}, common_options: dict={}, st
         # mix model can be initilized with a mixed reference model or a nnsk model.
         model = MIX.from_reference(checkpoint, **model_options, **common_options)
 
-    if model.model_options != model_options:
-        # log.error("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
-        raise ValueError("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
+    for k, v in model.model_options.items():
+        if k not in model_options:
+            log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
+        else:
+            deep_dict_difference(k, v, model_options)
 
     return model
+
+
+def deep_dict_difference(base_key, expected_value, model_options):
+    """
+    Recursively log option differences in nested dictionaries.
+    :param base_key: base key name, used as the prefix of warning messages.
+    :param expected_value: the expected value, which may or may not be a dict.
+    :param model_options: the model options dict to compare against.
+    """
+    target_dict = copy.deepcopy(model_options)  # avoid mutating the original dict
+    if isinstance(expected_value, dict):
+        for subk, subv in expected_value.items():
+            if subk not in target_dict.get(base_key, {}):
+                log.warning(f"The model option {subk} in {base_key} is not defined in input model_options, set to {subv}.")
+            else:
+                target2 = copy.deepcopy(target_dict[base_key])
+                deep_dict_difference(f"{subk}", subv, target2)
+    else:
+        if expected_value != target_dict[base_key]:
+            log.warning(f"The model option {base_key} is set to {expected_value}, but in input it is {target_dict[base_key]}, make sure it is correct!")
1 change: 1 addition & 0 deletions dptb/nn/deeptb.py

@@ -329,6 +329,7 @@ def __init__(
             transform=False,
         )
         self.idp = self.nnsk.idp
+        assert not self.nnsk.push, "The push option is not supported in the mixed model. The push option is only supported in the nnsk model."
 
         self.model_options = self.nnsk.model_options
         self.model_options.update(self.dptb.model_options)
17 changes: 11 additions & 6 deletions dptb/nn/nnsk.py

@@ -30,7 +30,7 @@ def __init__(
         device: Union[str, torch.device] = torch.device("cpu"),
         transform: bool = True,
         freeze: bool = False,
-        push: Dict = None,
+        push: Union[bool, dict] = False,
         std: float = 0.01,
         **kwargs,
     ) -> None:

@@ -179,7 +179,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         # reflect_params[:,self.idp_sk.pair_maps[k],:] += params[:,self.idp_sk.pair_maps[k_r],:]
         # self.strain_param.data = reflect_params + params
 
-        if self.push is not None:
+        if self.push is not None and self.push is not False:
             if abs(self.push.get("rs_thr")) + abs(self.push.get("rc_thr")) + abs(self.push.get("w_thr")) > 0:
                 self.push_decay(**self.push)
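
Judging from the keys this guard reads, a push configuration is either False (the new default) or a dict of decay thresholds. A minimal sketch with placeholder values, not recommended settings:

# enabled form, carrying the three threshold keys the guard above reads
push = {"rs_thr": -0.01, "rc_thr": 0.0, "w_thr": 0.0}
# disabled form (the new default): push = False

# mirror of the guard: push_decay runs only for a dict with a non-zero threshold
if push is not None and push is not False:
    if abs(push.get("rs_thr")) + abs(push.get("rc_thr")) + abs(push.get("w_thr")) > 0:
        print("push_decay(**push) would run here")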

@@ -305,7 +305,7 @@ def from_reference(
         dtype: Union[str, torch.dtype] = None,
         device: Union[str, torch.device] = None,
         push: Dict = None,
-        freeze: bool = None,
+        freeze: bool = False,
         std: float = 0.01,
         **kwargs,
     ):

@@ -332,7 +332,8 @@
         for k, v in common_options.items():
             assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
         for k, v in nnsk.items():
-            assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
+            if k != 'push':
+                assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
 
         v1_model = j_loader(checkpoint)
         model = cls._from_model_v1(

@@ -349,7 +350,7 @@
             if v is None:
                 common_options[k] = f["config"]["common_options"][k]
         for k, v in nnsk.items():
+            if v is None and not k is "push":
                 nnsk[k] = f["config"]["model_options"]["nnsk"][k]
 
         model = cls(**common_options, **nnsk)

@@ -458,6 +459,9 @@ def _from_model_v1(
         dtype: Union[str, torch.dtype] = torch.float32,
         device: Union[str, torch.device] = torch.device("cpu"),
         std: float = 0.01,
+        freeze: bool = False,
+        push: Union[bool, None, dict] = False,
+        **kwargs
     ):
         # could support json file and .pth file checkpoint of nnsk

@@ -477,7 +481,8 @@
         idp_sk.get_orbpair_maps()
         idp_sk.get_skonsite_maps()
 
-        nnsk_model = cls(basis=basis, idp_sk=idp_sk, dtype=dtype, device=device, onsite=onsite, hopping=hopping, overlap=overlap, std=std)
+        nnsk_model = cls(basis=basis, idp_sk=idp_sk, onsite=onsite,
+                         hopping=hopping, overlap=overlap, std=std, freeze=freeze, push=push, dtype=dtype, device=device)
 
         onsite_param = v1_model["onsite"]
         hopping_param = v1_model["hopping"]
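
Putting the from_reference changes together, initializing from a v1 json file plausibly looks like the sketch below. The checkpoint, dtype, device, push, freeze, and std parameters are visible in the hunks above; the remaining argument names and all option values are assumptions for illustration only.

from dptb.nn.nnsk import NNSK

# hypothetical call: a v1 json model has no embedded config block, so every
# common/nnsk option except push must be supplied (per the asserts above)
model = NNSK.from_reference(
    checkpoint="examples/AlAs/nnsk_v1.json",                      # assumed path
    basis={"Al": ["3s", "3p", "d*"], "As": ["4s", "4p", "d*"]},   # assumed basis
    onsite={"method": "strain", "rs": 6.0, "w": 0.1},             # assumed values
    hopping={"method": "powerlaw", "rs": 2.6, "w": 0.35},         # assumed values
    overlap=False,
    dtype="float32",
    device="cpu",
    push=False,    # may now be omitted: the new code no longer requires it
    freeze=False,
    std=0.01,
)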
10 changes: 10 additions & 0 deletions dptb/tests/data/json_model/AlAs.vasp

@@ -0,0 +1,10 @@
+Primitive Cell
+1.000000
+  0.00000000000000  2.83790087700000  2.83790087700000
+  2.83790087700000  0.00000000000000  2.83790087700000
+  2.83790087700000  2.83790087700000  0.00000000000000
+Al As
+1 1
+DIRECT
+0.0000000000000000 0.0000000000000000 0.0000000000000000 Al1
+0.7500000000000000 0.7500000000000000 0.7500000000000000 As1