pt: Fix multitask neighbor stat #3367

Merged 1 commit on Feb 29, 2024
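This PR fixes neighbor-statistic computation for multitask PyTorch training: `preprocess_shared_params` now runs in `train()` before the neighbor-stat step, and `BaseModel.update_sel` is applied per model branch rather than once to the whole multitask config.

For orientation, here is a minimal sketch of the multitask input layout the fix operates on, written as a Python dict. The branch names and placeholder values are hypothetical; the keys `model_dict`, `data_dict`, and `model_prob` are the ones read in the diff below.

# Hypothetical multitask input.json, as a Python dict. Branch names
# ("water_1", "water_2") and all values are placeholders for illustration.
config = {
    "model": {
        "model_dict": {
            "water_1": {"descriptor": {}, "fitting_net": {}},  # per-branch model params
            "water_2": {"descriptor": {}, "fitting_net": {}},
        },
    },
    "training": {
        "model_prob": {"water_1": 0.5, "water_2": 0.5},  # branch sampling weights
        "data_dict": {  # per-branch training/validation data
            "water_1": {"training_data": {"systems": ["data/sys1"]}},
            "water_2": {"training_data": {"systems": ["data/sys2"]}},
        },
        "numb_steps": 100000,  # shared options stay at this level
    },
}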
36 changes: 30 additions & 6 deletions deepmd/pt/entrypoints/main.py
@@ -3,6 +3,9 @@
 import json
 import logging
 import os
+from copy import (
+    deepcopy,
+)
 from pathlib import (
     Path,
 )
@@ -75,9 +78,11 @@
     model_branch="",
     force_load=False,
     init_frz_model=None,
+    shared_links=None,
 ):
+    multi_task = "model_dict" in config.get("model", {})
     # argcheck
-    if "model_dict" not in config.get("model", {}):
+    if not multi_task:
         config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
         config = normalize(config)

@@ -88,7 +93,6 @@
         assert dist.is_nccl_available()
         dist.init_process_group(backend="nccl")

-    multi_task = "model_dict" in config["model"]
     ckpt = init_model if init_model is not None else restart_model
     config["model"] = change_finetune_model_params(
         ckpt,
@@ -98,9 +102,6 @@
         model_branch=model_branch,
     )
     config["model"]["resuming"] = (finetune_model is not None) or (ckpt is not None)
-    shared_links = None
-    if multi_task:
-        config["model"], shared_links = preprocess_shared_params(config["model"])

     def prepare_trainer_input_single(
         model_params_single, data_dict_single, loss_dict_single, suffix=""
@@ -252,11 +253,33 @@
     SummaryPrinter()()
     with open(FLAGS.INPUT) as fin:
         config = json.load(fin)
+
+    # update multitask config
+    multi_task = "model_dict" in config["model"]
+    shared_links = None
+    if multi_task:
+        config["model"], shared_links = preprocess_shared_params(config["model"])
+
     # do neighbor stat
     if not FLAGS.skip_neighbor_stat:
         log.info(
             "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
         )
-        config["model"] = BaseModel.update_sel(config, config["model"])
+        if not multi_task:
+            config["model"] = BaseModel.update_sel(config, config["model"])
+        else:
+            training_jdata = deepcopy(config["training"])
+            training_jdata.pop("data_dict", {})
+            training_jdata.pop("model_prob", {})
+            for model_item in config["model"]["model_dict"]:
+                fake_global_jdata = {
+                    "model": deepcopy(config["model"]["model_dict"][model_item]),
+                    "training": deepcopy(config["training"]["data_dict"][model_item]),
+                }
+                fake_global_jdata["training"].update(training_jdata)
+                config["model"]["model_dict"][model_item] = BaseModel.update_sel(
+                    fake_global_jdata, config["model"]["model_dict"][model_item]
+                )

     trainer = get_trainer(
         config,
@@ -266,6 +289,7 @@
         FLAGS.model_branch,
         FLAGS.force_load,
         FLAGS.init_frz_model,
+        shared_links=shared_links,
     )
     trainer.run()

Codecov / codecov/patch: added lines 261, 271-275, and 279-280 in deepmd/pt/entrypoints/main.py were not covered by tests.
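The core of the fix is that `BaseModel.update_sel` expects a single-task layout: one `model` section plus a `training` section whose data options sit at the top level. The multitask branch of the diff synthesizes that shape once per branch. Below is a standalone sketch of the same idea; the helper name `update_sel_per_branch` is hypothetical, and `update_sel` is passed in as a callable instead of being imported from deepmd.

from copy import deepcopy

def update_sel_per_branch(config, update_sel):
    """Hypothetical helper mirroring the loop in this PR: run a
    single-task update_sel once per multitask branch."""
    # Shared training options, i.e. everything except the per-branch data.
    training_jdata = deepcopy(config["training"])
    training_jdata.pop("data_dict", {})
    training_jdata.pop("model_prob", {})
    for branch, model_params in config["model"]["model_dict"].items():
        # Build a fake single-task global config for this branch only.
        fake_global_jdata = {
            "model": deepcopy(model_params),
            "training": deepcopy(config["training"]["data_dict"][branch]),
        }
        fake_global_jdata["training"].update(training_jdata)
        config["model"]["model_dict"][branch] = update_sel(
            fake_global_jdata, fake_global_jdata["model"]
        )
    return config

With the multitask config sketched near the top of this page, this runs one `update_sel` call per branch, each seeing only that branch's `training_data`, which is exactly the shape the existing single-task code path already handles.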