fix(pt): add finetune_head to argcheck #3967

Merged (3 commits) on Jul 12, 2024
10 changes: 10 additions & 0 deletions deepmd/utils/argcheck.py
@@ -1533,6 +1533,10 @@
doc_spin = "The settings for systems with spin."
doc_atom_exclude_types = "Exclude the atomic contribution of the listed atom types"
doc_pair_exclude_types = "The atom pairs of the listed types are not treated to be neighbors, i.e. they do not see each other."
+doc_finetune_head = (
+    "The chosen fitting net to fine-tune on, when doing multi-task fine-tuning. "
+    "If not set or set to 'RANDOM', the fitting net will be randomly initialized."
+)

hybrid_models = []
if not exclude_hybrid:
@@ -1629,6 +1633,12 @@
            fold_subdoc=True,
        ),
        Argument("spin", dict, spin_args(), [], optional=True, doc=doc_spin),
+        Argument(
+            "finetune_head",
+            str,
+            optional=True,
+            doc=doc_only_pt_supported + doc_finetune_head,
+        ),
    ],
    [
        Variant(
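For context, the new schema entry lets a fine-tuning input select one fitting net of a pretrained multi-task checkpoint by name. A minimal sketch of where the key sits in a model section; every value other than "finetune_head" is illustrative, not taken from this PR:

    # Illustrative model section for fine-tuning from a multi-task
    # checkpoint; the descriptor/fitting_net settings here are made up.
    model_section = {
        "type_map": ["O", "H"],
        "descriptor": {"type": "se_e2_a", "rcut": 6.0, "rcut_smth": 0.5, "sel": "auto"},
        "fitting_net": {"neuron": [240, 240, 240], "resnet_dt": True},
        # Name of the pretrained fitting net to start from; omitting the key
        # or setting "RANDOM" gives a randomly initialized head (see the doc
        # string added above).
        "finetune_head": "model_1",
    }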
1 change: 0 additions & 1 deletion source/tests/pt/model/water/multitask.json
Original file line number Diff line number Diff line change
@@ -68,7 +68,6 @@
"_comment": "that's all"
},
"loss_dict": {
"_comment": " that's all",
"model_1": {
"type": "ener",
"start_pref_e": 0.02,
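A guess at the motivation, not stated in the PR: once the test configs pass through multi-task normalization (see the test changes below), each key of "loss_dict" is matched against a model branch, so a stray "_comment" entry would no longer validate. A sketch of the shape that passes, with "model_2" as an assumed second branch name:

    # Hypothetical loss_dict after the fix: branch names only as keys.
    loss_dict = {
        "model_1": {"type": "ener", "start_pref_e": 0.02},
        "model_2": {"type": "ener", "start_pref_e": 0.02},
    }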
10 changes: 10 additions & 0 deletions source/tests/pt/test_multitask.py
@@ -21,6 +21,12 @@
from deepmd.pt.utils.multi_task import (
    preprocess_shared_params,
)
+from deepmd.utils.argcheck import (
+    normalize,
+)
+from deepmd.utils.compat import (
+    update_deepmd_input,
+)

from .model.test_permutation import (
    model_dpa1,
@@ -39,6 +45,8 @@ def setUpModule():
class MultiTaskTrainTest:
    def test_multitask_train(self):
        # test multitask training
+        self.config = update_deepmd_input(self.config, warning=True)
+        self.config = normalize(self.config, multi_task=True)
        trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links)
        trainer.run()
        # check model keys
@@ -124,6 +132,8 @@ def test_multitask_train(self):
            finetune_model,
            self.origin_config["model"],
        )
+        self.origin_config = update_deepmd_input(self.origin_config, warning=True)
+        self.origin_config = normalize(self.origin_config, multi_task=True)
        trainer_finetune = get_trainer(
            deepcopy(self.origin_config),
            finetune_model=finetune_model,
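Taken together, the tests now push their configs through the same upgrade-then-validate pipeline used for real training inputs, which is what exercises the new "finetune_head" schema entry. A minimal sketch of that flow, assuming a placeholder "input.json" holding any multi-task training config:

    import json

    from deepmd.utils.argcheck import normalize
    from deepmd.utils.compat import update_deepmd_input

    # "input.json" is a stand-in for any multi-task training config file.
    with open("input.json") as f:
        config = json.load(f)

    # Upgrade legacy keys, then validate against the argument schema.
    # After this PR, "finetune_head" inside a model definition is accepted
    # instead of being rejected as an unknown key.
    config = update_deepmd_input(config, warning=True)
    config = normalize(config, multi_task=True)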