pt: support multitask dp test #3573

Merged: 1 commit, Mar 21, 2024
20 changes: 18 additions & 2 deletions deepmd/pt/infer/deep_eval.py
```diff
@@ -90,6 +90,7 @@
         *args: List[Any],
         auto_batch_size: Union[bool, int, AutoBatchSize] = True,
         neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
+        head: Optional[str] = None,
         **kwargs: Dict[str, Any],
     ):
         self.output_def = output_def
@@ -99,9 +100,24 @@
         if "model" in state_dict:
             state_dict = state_dict["model"]
         self.input_param = state_dict["_extra_state"]["model_params"]
-        self.input_param["resuming"] = True
         self.multi_task = "model_dict" in self.input_param
-        assert not self.multi_task, "multitask mode currently not supported!"
+        if self.multi_task:
+            model_keys = list(self.input_param["model_dict"].keys())
+            assert (
+                head is not None
+            ), f"Head must be set for multitask model! Available heads are: {model_keys}"
+            assert (
+                head in model_keys
+            ), f"No head named {head} in model! Available heads are: {model_keys}"
+            self.input_param = self.input_param["model_dict"][head]
+            state_dict_head = {"_extra_state": state_dict["_extra_state"]}
+            for item in state_dict:
+                if f"model.{head}." in item:
+                    state_dict_head[
+                        item.replace(f"model.{head}.", "model.Default.")
+                    ] = state_dict[item].clone()
+            state_dict = state_dict_head
+        self.input_param["resuming"] = True
         model = get_model(self.input_param).to(DEVICE)
         model = torch.jit.script(model)
         self.dp = ModelWrapper(model)
```
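The core of this change is the state-dict surgery: when the checkpoint is a multitask model, the evaluator validates the requested head and copies only that head's parameters, renaming them into the single-task `model.Default.` namespace so the rest of the loading path is unchanged. Below is a minimal standalone sketch of that remapping logic, using plain floats in place of tensors (so no `.clone()` is needed) and hypothetical key names for illustration.

```python
# Sketch of the multitask head-selection logic from the diff above.
# Assumes checkpoint parameter keys of the form "model.<head>.<param>";
# the key names and values here are made up for illustration.

def select_head(state_dict, head, model_keys):
    """Extract one task head from a multitask state dict, renaming its
    keys into the single-task "model.Default." namespace."""
    assert (
        head is not None
    ), f"Head must be set for multitask model! Available heads are: {model_keys}"
    assert (
        head in model_keys
    ), f"No head named {head} in model! Available heads are: {model_keys}"
    # Shared metadata is carried over as-is; only the chosen head's
    # parameters are copied, under renamed keys.
    state_dict_head = {"_extra_state": state_dict["_extra_state"]}
    for item in state_dict:
        if f"model.{head}." in item:
            new_key = item.replace(f"model.{head}.", "model.Default.")
            state_dict_head[new_key] = state_dict[item]
    return state_dict_head


ckpt = {
    "_extra_state": {"model_params": {}},
    "model.water.descriptor.weight": 1.0,
    "model.metal.descriptor.weight": 2.0,
}
out = select_head(ckpt, "water", ["water", "metal"])
print(sorted(out))  # the "metal" head's parameters are dropped
```

Renaming to a fixed `Default` namespace means the downstream `get_model` / `load_state_dict` path needs no awareness of which head was chosen.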