Skip to content

Commit

Permalink
update trainer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Dec 18, 2024
1 parent 90bc68e commit bc86896
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,9 +1904,14 @@ def _load_rng_state(self, checkpoint):
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
if self.args.tensor_parallel_degree <= 1:
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
)
try:
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(

Check warning on line 1908 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1907-L1908

Added lines #L1907 - L1908 were not covered by tests
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
)
except:
logger.warning(

Check warning on line 1912 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L1911-L1912

Added lines #L1911 - L1912 were not covered by tests
"Hybrid paralell rng states change when training environment differs, so we dot not set state tracker here."
)
else:
logger.warning("Not found hybrid parallel RNG state.")

Expand Down

0 comments on commit bc86896

Please sign in to comment.