diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d01f25be80..ff1c350f47 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -558,46 +558,50 @@ def update_single_finetune_params( ] self.wrapper.load_state_dict(state_dict) - def single_model_finetune( - _model, - _model_params, - _sample_func, - ): - old_type_map, new_type_map = ( - _model_params["type_map"], - _model_params["new_type_map"], - ) - if isinstance(_model, EnergyModel): - _model.change_out_bias( - _sample_func, - bias_adjust_mode=_model_params.get( - "bias_adjust_mode", "change-by-statistic" - ), - origin_type_map=new_type_map, - full_type_map=old_type_map, - ) - else: - # need to updated - pass + if finetune_model is not None: - # finetune - if not self.multi_task: - single_model_finetune( - self.model, model_params, self.get_sample_func - ) - else: - for model_key in self.model_keys: - if model_key in self.finetune_links: - log.info( - f"Model branch {model_key} will be fine-tuned. This may take a long time..." - ) - single_model_finetune( - self.model[model_key], - model_params["model_dict"][model_key], - self.get_sample_func[model_key], + def single_model_finetune( + _model, + _model_params, + _sample_func, + ): + old_type_map, new_type_map = ( + _model_params["type_map"], + _model_params["new_type_map"], + ) + if isinstance(_model, EnergyModel): + _model.change_out_bias( + _sample_func, + bias_adjust_mode=_model_params.get( + "bias_adjust_mode", "change-by-statistic" + ), + origin_type_map=new_type_map, + full_type_map=old_type_map, ) else: - log.info(f"Model branch {model_key} will resume training.") + # need to updated + pass + + # finetune + if not self.multi_task: + single_model_finetune( + self.model, model_params, self.get_sample_func + ) + else: + for model_key in self.model_keys: + if model_key in self.finetune_links: + log.info( + f"Model branch {model_key} will be fine-tuned. This may take a long time..." + ) + single_model_finetune( + self.model[model_key], + model_params["model_dict"][model_key], + self.get_sample_func[model_key], + ) + else: + log.info( + f"Model branch {model_key} will resume training." + ) if init_frz_model is not None: frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)