Fix fine-tuning entries bug when doing restart. #3616

Merged · 1 commit · Mar 28, 2024
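Why the guard matters: before this change, the fine-tuning entries (the statistics-based output-bias adjustment done through `change_out_bias`) ran unconditionally after the checkpoint's state dict was loaded, so a plain restart went through the fine-tuning path as well, either re-adjusting a bias the checkpoint had already adjusted or tripping over parameters such as `new_type_map` that only exist when fine-tuning. The fix wraps the whole block in `if finetune_model is not None:`. The sketch below is a minimal toy reproduction of the failure mode under that reading; `ToyModel`, `load_and_setup`, and `delta` are invented names, not deepmd-kit APIs.

```python
# Minimal sketch (hypothetical names) of the bug class this PR fixes:
# a restart must not re-run the output-bias adjustment, because the
# checkpoint already stores the adjusted bias.
from typing import Optional


class ToyModel:
    def __init__(self) -> None:
        self.out_bias = 0.0

    def change_out_bias(self, delta: float) -> None:
        # stand-in for the statistics-based adjustment in the real code
        self.out_bias += delta


def load_and_setup(model: ToyModel, finetune_model: Optional[str], delta: float) -> None:
    # ... imagine the checkpoint's state dict being loaded here ...
    # the fix: adjust the bias only when a model to fine-tune from is
    # given, never on a plain restart (finetune_model is None)
    if finetune_model is not None:
        model.change_out_bias(delta)


model = ToyModel()
load_and_setup(model, "pretrained.pt", 1.5)  # first fine-tuning run
load_and_setup(model, None, 1.5)             # restart: bias left untouched
assert model.out_bias == 1.5                 # without the guard this would be 3.0
```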
78 changes: 41 additions & 37 deletions deepmd/pt/train/training.py
```diff
@@ -558,46 +558,50 @@
             ]
         self.wrapper.load_state_dict(state_dict)
 
-        def single_model_finetune(
-            _model,
-            _model_params,
-            _sample_func,
-        ):
-            old_type_map, new_type_map = (
-                _model_params["type_map"],
-                _model_params["new_type_map"],
-            )
-            if isinstance(_model, EnergyModel):
-                _model.change_out_bias(
-                    _sample_func,
-                    bias_adjust_mode=_model_params.get(
-                        "bias_adjust_mode", "change-by-statistic"
-                    ),
-                    origin_type_map=new_type_map,
-                    full_type_map=old_type_map,
-                )
-            else:
-                # needs to be updated
-                pass
-
-        # finetune
-        if not self.multi_task:
-            single_model_finetune(
-                self.model, model_params, self.get_sample_func
-            )
-        else:
-            for model_key in self.model_keys:
-                if model_key in self.finetune_links:
-                    log.info(
-                        f"Model branch {model_key} will be fine-tuned. This may take a long time..."
-                    )
-                    single_model_finetune(
-                        self.model[model_key],
-                        model_params["model_dict"][model_key],
-                        self.get_sample_func[model_key],
-                    )
-                else:
-                    log.info(f"Model branch {model_key} will resume training.")
+        if finetune_model is not None:
+
+            def single_model_finetune(
+                _model,
+                _model_params,
+                _sample_func,
+            ):
+                old_type_map, new_type_map = (
+                    _model_params["type_map"],
+                    _model_params["new_type_map"],
+                )
+                if isinstance(_model, EnergyModel):
+                    _model.change_out_bias(
+                        _sample_func,
+                        bias_adjust_mode=_model_params.get(
+                            "bias_adjust_mode", "change-by-statistic"
+                        ),
+                        origin_type_map=new_type_map,
+                        full_type_map=old_type_map,
+                    )
+                else:
+                    # needs to be updated
+                    pass
+
+            # finetune
+            if not self.multi_task:
+                single_model_finetune(
+                    self.model, model_params, self.get_sample_func
+                )
+            else:
+                for model_key in self.model_keys:
+                    if model_key in self.finetune_links:
+                        log.info(
+                            f"Model branch {model_key} will be fine-tuned. This may take a long time..."
+                        )
+                        single_model_finetune(
+                            self.model[model_key],
+                            model_params["model_dict"][model_key],
+                            self.get_sample_func[model_key],
+                        )
+                    else:
+                        log.info(
+                            f"Model branch {model_key} will resume training."
+                        )
 
         if init_frz_model is not None:
             frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
```

Codecov flagged the added lines (#L561-L602 in `deepmd/pt/train/training.py`) as not covered by tests.
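Two details in the merged block are worth calling out. First, the type maps are intentionally crossed when calling `change_out_bias`: the checkpoint's `type_map` goes in as `full_type_map` while the fine-tuning data's `new_type_map` goes in as `origin_type_map`, which, as far as the hunk shows, lets bias statistics gathered under the new data's type ordering be mapped back onto the pre-trained model's full type map; non-`EnergyModel` types are still a `pass`-through, per the `# needs to be updated` comment. Second, in the multi-task case only branches listed in `self.finetune_links` are fine-tuned, and every other branch resumes from the checkpoint unchanged. A minimal sketch of that branch selection, with hypothetical branch names (the real `finetune_links` is built elsewhere in the trainer):

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# hypothetical example data; the real trainer derives these from the input
model_keys = ["water", "alloy"]
finetune_links = {"water": "pretrained_water_branch"}

for model_key in model_keys:
    if model_key in finetune_links:
        # only linked branches get their output bias re-fit
        log.info(f"Model branch {model_key} will be fine-tuned.")
    else:
        # unlinked branches keep their checkpointed state as-is
        log.info(f"Model branch {model_key} will resume training.")
```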