Update register_sequence_parallel_allreduce_hooks #9782

Merged
merged 9 commits on Jan 15, 2025
12 changes: 12 additions & 0 deletions paddlenlp/peft/lora/loraga_utils.py
@@ -16,6 +16,13 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
+try:
+    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
+        register_sequence_parallel_allreduce_hooks,
+    )
+except:
+    pass
+
 from paddlenlp.peft import LoRAModel
 from paddlenlp.peft.lora.lora_layers import (
     ColumnParallelLoRALinear,
@@ -83,6 +90,11 @@
     def _wrap_model(self, model):
         """Wrap Model without optimizer, support dp, tp and sharding"""
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         in_pipeline_parallel_mode = self.args.pipeline_parallel_degree > 1
         in_sharding_parallel_mode = self.sharding is not None
         in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
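Taken together, the two hunks above make the hook registration self-contained in LoRA-GA: the guarded import keeps the module loadable on Paddle builds that do not ship sequence_parallel_utils, and the hooks are only attached when sequence parallelism can actually take effect. A minimal standalone sketch of the same pattern follows; build_model() and args are hypothetical stand-ins (not PaddleNLP APIs), and the import guard is a slightly stricter variant of the bare try/except used in the diff:

    # Sketch only: mirrors the guarded-import + conditional-registration pattern above.
    try:
        from paddle.distributed.fleet.utils.sequence_parallel_utils import (
            register_sequence_parallel_allreduce_hooks,
        )
    except ImportError:  # older Paddle builds may not provide this utility
        register_sequence_parallel_allreduce_hooks = None

    model = build_model()  # hypothetical constructor returning a paddle.nn.Layer

    # Register the gradient allreduce hooks only when they can have an effect:
    # a tensor-parallel group exists and sequence parallelism is switched on.
    if (
        register_sequence_parallel_allreduce_hooks is not None
        and args.tensor_parallel_degree > 1  # args: hypothetical training arguments object
        and args.sequence_parallel
    ):
        register_sequence_parallel_allreduce_hooks(
            model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
        )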
10 changes: 5 additions & 5 deletions paddlenlp/trainer/trainer.py
@@ -434,11 +434,6 @@
                 "We do not support skip_save_model_weight in peft model when using unified checkpoint, remove this config."
             )
 
-        if args.sequence_parallel:
-            register_sequence_parallel_allreduce_hooks(
-                self.model, args.gradient_accumulation_steps, args.fuse_sequence_parallel_allreduce
-            )
-
         self.do_grad_scaling = False
         self.enable_autocast_context_manager = False
         if args.fp16 or args.bf16:
@@ -2054,6 +2049,11 @@
             else:
                 model, self.optimizer = decorated
 
+        if self.args.tensor_parallel_degree > 1 and self.args.sequence_parallel:
+            register_sequence_parallel_allreduce_hooks(
+                model, self.args.gradient_accumulation_steps, self.args.fuse_sequence_parallel_allreduce
+            )
+
         if self.args.world_size == 1:
             if self.args.amp_master_grad:
                 mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype)
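On the trainer.py side the same call moves out of Trainer.__init__ and into _wrap_model, so the hooks are registered on the decorated model produced by the mixed-precision/distributed wrapping rather than on the raw self.model, and it picks up the same tensor_parallel_degree > 1 guard as the LoRA-GA path. Since sequence parallelism splits activations across the tensor-parallel group, a degree of 1 leaves nothing to all-reduce, and skipping the registration in that case avoids attaching no-op hooks.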