Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

From v1json #103

Merged
merged 14 commits into from
Mar 27, 2024
Merged
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ dptb/tests/**/*.traj
dptb/tests/**/out*/*
examples/_*
*.dat
*.vasp
*log*
dptb/tests/data/**/out*/config_*.json

bandstructure.npy
dptb/tests/data/hBN/data/set.0/xdat2.traj
dptb/tests/data/postrun/run_config.json
dptb/tests/data/test_all/test_config.json
Expand Down
91 changes: 48 additions & 43 deletions dptb/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,54 +98,59 @@ def train(
# since here we want to output jdata as a config file to inform the user what model options are used, we need to update the jdata

torch.set_default_dtype(getattr(torch, jdata["common_options"]["dtype"]))

if restart or init_model:

f = restart if restart else init_model
f = torch.load(f)

if jdata.get("model_options", None) is None:
jdata["model_options"] = f["config"]["model_options"]

# update basis
basis = f["config"]["common_options"]["basis"]
# nnsk
if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None:
for asym, orb in jdata["common_options"]["basis"].items():
assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
if orb != basis[asym]:
log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
# we have the orbitals in jdata basis correct, now we need to make sure all atom in basis are also contained in jdata basis
for asym, orb in basis.items():
if asym not in jdata["common_options"]["basis"].keys():
jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset
else: # not nnsk
for asym, orb in jdata["common_options"]["basis"].items():
assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"

jdata["common_options"]["basis"] = basis

# update model options and train_options
if restart:
#
if jdata.get("train_options", None) is not None:
for obj in Trainer.object_keys:
if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj):
log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint")
jdata["train_options"][obj] = f["config"]["train_options"][obj]
else:
jdata["train_options"] = f["config"]["train_options"]

if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]:
log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint")
jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options
if f.split(".")[-1] == "json":
assert not restart, "json model can not be used as restart! should be a checkpoint file"
else:
# init model mode, allow model_options change
if jdata.get("train_options", None) is None:
jdata["train_options"] = f["config"]["train_options"]
if jdata.get("model_options") is None:
f = torch.load(f)

if jdata.get("model_options", None) is None:
jdata["model_options"] = f["config"]["model_options"]
del f

# update basis
basis = f["config"]["common_options"]["basis"]
# nnsk
if len(f["config"]["model_options"])==1 and f["config"]["model_options"].get("nnsk") != None:
for asym, orb in jdata["common_options"]["basis"].items():
assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
if orb != basis[asym]:
log.info(f"Initializing Orbital {orb} of Atom {asym} from {basis[asym]}")
# we have the orbitals in jdata basis correct, now we need to make sure all atom in basis are also contained in jdata basis
for asym, orb in basis.items():
if asym not in jdata["common_options"]["basis"].keys():
jdata["common_options"]["basis"][asym] = orb # add the atomtype in the checkpoint but not in the jdata basis, because it will be used to build the orbital mapper for dataset
else: # not nnsk
for asym, orb in jdata["common_options"]["basis"].items():
assert asym in basis.keys(), f"Atom {asym} not found in model's basis"
assert orb == basis[asym], f"Orbital {orb} of Atom {asym} not consistent with the model's basis, which is only allowed in nnsk training"

jdata["common_options"]["basis"] = basis

# update model options and train_options
if restart:
#
if jdata.get("train_options", None) is not None:
for obj in Trainer.object_keys:
if jdata["train_options"].get(obj) != f["config"]["train_options"].get(obj):
log.warning(f"{obj} in config file is not consistent with the checkpoint, using the one in checkpoint")
jdata["train_options"][obj] = f["config"]["train_options"][obj]
else:
jdata["train_options"] = f["config"]["train_options"]

if jdata.get("model_options", None) is None or jdata["model_options"] != f["config"]["model_options"]:
log.warning("model_options in config file is not consistent with the checkpoint, using the one in checkpoint")
jdata["model_options"] = f["config"]["model_options"] # restart does not allow to change model options
else:
# init model mode, allow model_options change
if jdata.get("train_options", None) is None:
jdata["train_options"] = f["config"]["train_options"]
if jdata.get("model_options") is None:
jdata["model_options"] = f["config"]["model_options"]
del f
else:
j_must_have(jdata, "model_options")
j_must_have(jdata, "train_options")
Expand Down
31 changes: 28 additions & 3 deletions dptb/nn/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dptb.nn.nnsk import NNSK
import torch
from dptb.utils.tools import j_must_have
import copy

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,7 +136,31 @@ def build_model(run_options, model_options: dict={}, common_options: dict={}, st
# mix model can be initilized with a mixed reference model or a nnsk model.
model = MIX.from_reference(checkpoint, **model_options, **common_options)

if model.model_options != model_options:
# log.error("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
raise ValueError("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
for k, v in model.model_options.items():
if k not in model_options:
log.warning(f"The model options {k} is not defined in input model_options, set to {v}.")
else:
deep_dict_difference(k, v, model_options)

return model


def deep_dict_difference(base_key, expected_value, model_options):
    """Recursively log differences between nested option dictionaries.

    Walks ``expected_value`` (the options recorded in the checkpoint/model)
    and compares it against the user-supplied ``model_options``; every key
    that is missing or holds a different value is reported via
    ``log.warning``. The function is read-only: it never mutates either
    argument, so no defensive copying is needed.

    :param base_key: key under ``model_options`` to compare against; also
        used as the prefix in warning messages.
    :param expected_value: the expected value (dict or scalar) taken from
        the model/checkpoint options.
    :param model_options: the user-provided options dict to compare with.
    """
    if isinstance(expected_value, dict):
        # Compare each sub-option against the corresponding sub-dict in the
        # user options; a missing parent key is treated as an empty dict so
        # every expected sub-key is reported as "not defined".
        sub_options = model_options.get(base_key, {})
        for subk, subv in expected_value.items():
            if subk not in sub_options:
                log.warning(f"The model option {subk} in {base_key} is not defined in input model_options, set to {subv}.")
            else:
                # Recurse one level down: subk becomes the new base key.
                deep_dict_difference(subk, subv, sub_options)
    else:
        # Leaf value: warn when the user-supplied value disagrees with the
        # expected one (the expected value is the one actually in effect).
        if expected_value != model_options[base_key]:
            log.warning(f"The model option {base_key} is set to {expected_value}, but in input it is {model_options[base_key]}, make sure it is correct!")
1 change: 1 addition & 0 deletions dptb/nn/deeptb.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def __init__(
transform=False,
)
self.idp = self.nnsk.idp
assert not self.nnsk.push, "The push option is not supported in the mixed model. The push option is only supported in the nnsk model."

self.model_options = self.nnsk.model_options
self.model_options.update(self.dptb.model_options)
Expand Down
17 changes: 11 additions & 6 deletions dptb/nn/nnsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
device: Union[str, torch.device] = torch.device("cpu"),
transform: bool = True,
freeze: bool = False,
push: Dict=None,
push: Union[bool,dict]=False,
std: float = 0.01,
**kwargs,
) -> None:
Expand Down Expand Up @@ -179,7 +179,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
# reflect_params[:,self.idp_sk.pair_maps[k],:] += params[:,self.idp_sk.pair_maps[k_r],:]
# self.strain_param.data = reflect_params + params

if self.push is not None:
if self.push is not None and self.push is not False:
if abs(self.push.get("rs_thr")) + abs(self.push.get("rc_thr")) + abs(self.push.get("w_thr")) > 0:
self.push_decay(**self.push)

Expand Down Expand Up @@ -305,7 +305,7 @@ def from_reference(
dtype: Union[str, torch.dtype]=None,
device: Union[str, torch.device]=None,
push: Dict=None,
freeze: bool = None,
freeze: bool = False,
std: float = 0.01,
**kwargs,
):
Expand All @@ -332,7 +332,8 @@ def from_reference(
for k,v in common_options.items():
assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
for k,v in nnsk.items():
assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."
if k != 'push':
assert v is not None, f"You need to provide {k} when you are initializing a model from a json file."

v1_model = j_loader(checkpoint)
model = cls._from_model_v1(
Expand All @@ -349,7 +350,7 @@ def from_reference(
if v is None:
common_options[k] = f["config"]["common_options"][k]
for k,v in nnsk.items():
if v is None and not k is "push" :
if v is None and k != "push" :
nnsk[k] = f["config"]["model_options"]["nnsk"][k]

model = cls(**common_options, **nnsk)
Expand Down Expand Up @@ -458,6 +459,9 @@ def _from_model_v1(
dtype: Union[str, torch.dtype] = torch.float32,
device: Union[str, torch.device] = torch.device("cpu"),
std: float = 0.01,
freeze: bool = False,
push: Union[bool,None,dict] = False,
**kwargs
):
# could support json file and .pth file checkpoint of nnsk

Expand All @@ -477,7 +481,8 @@ def _from_model_v1(
idp_sk.get_orbpair_maps()
idp_sk.get_skonsite_maps()

nnsk_model = cls(basis=basis, idp_sk=idp_sk, dtype=dtype, device=device, onsite=onsite, hopping=hopping, overlap=overlap, std=std)
nnsk_model = cls(basis=basis, idp_sk=idp_sk, onsite=onsite,
hopping=hopping, overlap=overlap, std=std,freeze=freeze, push=push, dtype=dtype, device=device,)

onsite_param = v1_model["onsite"]
hopping_param = v1_model["hopping"]
Expand Down
10 changes: 10 additions & 0 deletions dptb/tests/data/json_model/AlAs.vasp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Primitive Cell
1.000000
0.00000000000000 2.83790087700000 2.83790087700000
2.83790087700000 0.00000000000000 2.83790087700000
2.83790087700000 2.83790087700000 0.00000000000000
Al As
1 1
DIRECT
0.0000000000000000 0.0000000000000000 0.0000000000000000 Al1
0.7500000000000000 0.7500000000000000 0.7500000000000000 As1
Loading