diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index 4fcf1964a44d..e19db0659d9c 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -547,30 +547,16 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             optim_state_dict = self.optimizer.state_dict()
             optim_state_dict.pop("LR_Scheduler", None)
-
+        opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]
         for p_name, p in model.state_dict().items():
             if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
                 var_name = p.name
-                if (
-                    var_name + "_moment1_0" in optim_state_dict
-                    and not optim_state_dict[var_name + "_moment1_0"].is_dist()
-                ):
-                    optim_state_dict.pop(var_name + "_moment1_0")
-                if (
-                    var_name + "_moment2_0" in optim_state_dict
-                    and not optim_state_dict[var_name + "_moment2_0"].is_dist()
-                ):
-                    optim_state_dict.pop(var_name + "_moment2_0")
-                if (
-                    var_name + "_beta1_pow_acc_0" in optim_state_dict
-                    and not optim_state_dict[var_name + "_beta1_pow_acc_0"].is_dist()
-                ):
-                    optim_state_dict.pop(var_name + "_beta1_pow_acc_0")
-                if (
-                    var_name + "_beta2_pow_acc_0" in optim_state_dict
-                    and not optim_state_dict[var_name + "_beta2_pow_acc_0"].is_dist()
-                ):
-                    optim_state_dict.pop(var_name + "_beta2_pow_acc_0")
+                for key in opt_state_keys:
+                    if (
+                        var_name + key in optim_state_dict
+                        and not optim_state_dict[var_name + key].is_dist()
+                    ):
+                        optim_state_dict.pop(var_name + key)
 
         state_dict = {
             MODEL_NAME: model.state_dict(),
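
For reference, here is the refactored pruning logic in isolation: a minimal, runnable sketch that uses a plain dict in place of the real optimizer state dict. `FakeTensor` and `prune_local_optim_state` are hypothetical stand-ins for this illustration, not names from the PR; only the suffix list and the key-iteration pattern come from the diff above.

```python
# Hypothetical stand-in for a tensor that knows whether it is distributed;
# not a real PaddlePaddle class, just enough to exercise the loop.
class FakeTensor:
    def __init__(self, dist: bool):
        self._dist = dist

    def is_dist(self) -> bool:
        return self._dist


# Adam accumulator suffixes, matching the opt_state_keys list in the diff.
OPT_STATE_KEYS = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]


def prune_local_optim_state(optim_state_dict: dict, var_name: str) -> None:
    """Drop non-distributed accumulator entries for a parameter that does
    not live on the current rank, mirroring the loop introduced above."""
    for key in OPT_STATE_KEYS:
        entry = optim_state_dict.get(var_name + key)
        if entry is not None and not entry.is_dist():
            optim_state_dict.pop(var_name + key)


# Usage: only w0's non-distributed entry is removed; the distributed
# moment and the unrelated parameter w1 are left untouched.
state = {
    "w0_moment1_0": FakeTensor(dist=False),
    "w0_moment2_0": FakeTensor(dist=True),
    "w1_moment1_0": FakeTensor(dist=False),
}
prune_local_optim_state(state, "w0")
assert set(state) == {"w0_moment2_0", "w1_moment1_0"}
```

Iterating over one suffix list keeps the membership check and the `is_dist()` guard in a single place, so adding or renaming an accumulator suffix later means touching only `opt_state_keys` rather than duplicating another if-block.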