Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Jul 25, 2024
1 parent a1b9580 commit 2d0f836
Showing 1 changed file with 7 additions and 21 deletions.
28 changes: 7 additions & 21 deletions paddlenlp/trainer/auto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,30 +547,16 @@ def _save_checkpoint(self, model, metrics=None):
else:
optim_state_dict = self.optimizer.state_dict()
optim_state_dict.pop("LR_Scheduler", None)

opt_state_keys = ["_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0"]
for p_name, p in model.state_dict().items():
if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
var_name = p.name
if (
var_name + "_moment1_0" in optim_state_dict
and not optim_state_dict[var_name + "_moment1_0"].is_dist()
):
optim_state_dict.pop(var_name + "_moment1_0")
if (
var_name + "_moment2_0" in optim_state_dict
and not optim_state_dict[var_name + "_moment2_0"].is_dist()
):
optim_state_dict.pop(var_name + "_moment2_0")
if (
var_name + "_beta1_pow_acc_0" in optim_state_dict
and not optim_state_dict[var_name + "_beta1_pow_acc_0"].is_dist()
):
optim_state_dict.pop(var_name + "_beta1_pow_acc_0")
if (
var_name + "_beta2_pow_acc_0" in optim_state_dict
and not optim_state_dict[var_name + "_beta2_pow_acc_0"].is_dist()
):
optim_state_dict.pop(var_name + "_beta2_pow_acc_0")
for key in opt_state_keys:
if (

Check warning on line 555 in paddlenlp/trainer/auto_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/auto_trainer.py#L548-L555

Added lines #L548 - L555 were not covered by tests
var_name + key in optim_state_dict
and not optim_state_dict[var_name + key].is_dist()
):
optim_state_dict.pop(var_name + key)

Check warning on line 559 in paddlenlp/trainer/auto_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/auto_trainer.py#L559

Added line #L559 was not covered by tests

state_dict = {

Check warning on line 561 in paddlenlp/trainer/auto_trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/auto_trainer.py#L561

Added line #L561 was not covered by tests
MODEL_NAME: model.state_dict(),
Expand Down

0 comments on commit 2d0f836

Please sign in to comment.