Polish sequence parallel to improve performance (#861)
sneaxiy authored Nov 1, 2022
1 parent 5eceb07 commit cb2b926
Showing 3 changed files with 18 additions and 9 deletions.
11 changes: 7 additions & 4 deletions ppfleetx/models/language_model/gpt/dygraph/hybrid_model.py
@@ -795,7 +795,6 @@ def forward(self,

         if self.sequence_parallel:
             encoder_outputs = GatherOp.apply(encoder_outputs)
-            encoder_outputs = paddle.transpose(encoder_outputs, [1, 0, 2])

         return encoder_outputs
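Note: with this change the gathered activations are kept in sequence-major [seq_len, batch, hidden] layout instead of being transposed back to [batch, seq_len, hidden]. A minimal sketch of the gather this relies on, assuming GatherOp.apply all-gathers the per-rank sequence shards over the model-parallel group and concatenates them on axis 0 (names below are illustrative, not the project's implementation):

import paddle
import paddle.distributed as dist

def gather_sequence_shards(local_chunk, group):
    """All-gather [s/p, b, h] shards into a full [s, b, h] tensor; no transpose back to [b, s, h]."""
    shards = []
    dist.all_gather(shards, local_chunk, group=group)
    return paddle.concat(shards, axis=0)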

@@ -851,10 +850,11 @@ class GPTPretrainingCriterionHybird(nn.Layer):
     Criterion for GPT. It calculates the final loss.
     """

-    def __init__(self, topo=None):
+    def __init__(self, topo=None, sequence_parallel=False):
         super(GPTPretrainingCriterionHybird, self).__init__()
         self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
         self.parallel_loss_func = fleet.meta_parallel.ParallelCrossEntropy()
+        self.sequence_parallel = sequence_parallel

     def forward(self, prediction_scores, masked_lm_labels, loss_mask):
         """
@@ -877,6 +877,9 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask):
         """
         hcg = fleet.get_hybrid_communicate_group()
         mp_size = hcg.get_model_parallel_world_size()
+        if self.sequence_parallel:
+            masked_lm_labels = masked_lm_labels.transpose([1, 0])
+            loss_mask = loss_mask.transpose([1, 0])
         if mp_size > 1:
             masked_lm_loss = self.parallel_loss_func(
                 prediction_scores, masked_lm_labels.unsqueeze(2))
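Note: since the logits now stay sequence-major under sequence parallel, the labels and loss mask (produced by the dataloader as [batch, seq_len]) are transposed to [seq_len, batch] so they line up with prediction_scores. A toy shape check with hypothetical sizes:

import paddle

seq_len, batch, vocab = 8, 2, 128                                 # hypothetical sizes
prediction_scores = paddle.randn([seq_len, batch, vocab])         # sequence-major logits
masked_lm_labels = paddle.randint(0, vocab, [batch, seq_len])     # dataloader layout [b, s]
masked_lm_labels = masked_lm_labels.transpose([1, 0])             # -> [s, b], matches the logits
loss = paddle.nn.CrossEntropyLoss(reduction="none")(
    prediction_scores, masked_lm_labels.unsqueeze(2))
print(loss.shape)                                                 # [8, 2, 1]: one loss per (position, sample)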
@@ -943,7 +946,6 @@ def forward(self, input):
         output = self.norm(input)
         if self.sequence_parallel and self.is_last:
             output = GatherOp.apply(output)
-            output = paddle.transpose(output, [1, 0, 2])
         return output


@@ -1066,7 +1068,8 @@ def _logits_helper(embedding, output):

         super().__init__(
             layers=self.descs,
-            loss_fn=GPTPretrainingCriterionPipe(),
+            loss_fn=GPTPretrainingCriterionPipe(
+                sequence_parallel=sequence_parallel),
             topology=fleet.get_hybrid_communicate_group().topology(),
             seg_method="layer:TransformerDecoderLayer",
             recompute_interval=recompute_interval,
13 changes: 9 additions & 4 deletions
@@ -39,10 +39,15 @@ def scatter(input):
     group = hcg.get_model_parallel_group()
     parallelism = group.nranks
     rank = group.rank
-    assert input.shape[
-        0] % parallelism == 0, "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
-            input.shape[0], parallelism)
-    input = paddle.split(input, num_or_sections=parallelism, axis=0)[rank]
+    seq_len = input.shape[0]
+    assert seq_len % parallelism == 0, "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
+        seq_len, parallelism)
+    interval = seq_len // parallelism
+    input = paddle.slice(
+        input,
+        axes=[0],
+        starts=[interval * rank],
+        ends=[interval * (rank + 1)])
     return input
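Note: the rewritten scatter produces the same shard but avoids paddle.split materializing every rank's chunk only to keep one; a single paddle.slice reads just the local shard. A quick equivalence check under assumed toy shapes:

import paddle

seq_len, parallelism, rank = 12, 4, 1        # toy values
x = paddle.arange(seq_len * 3, dtype="float32").reshape([seq_len, 3])

old_way = paddle.split(x, num_or_sections=parallelism, axis=0)[rank]      # builds all chunks
interval = seq_len // parallelism
new_way = paddle.slice(
    x, axes=[0], starts=[interval * rank], ends=[interval * (rank + 1)])  # reads the local chunk only
assert bool(paddle.all(old_way == new_way))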


3 changes: 2 additions & 1 deletion ppfleetx/models/language_model/language_module.py
@@ -175,7 +175,8 @@ def get_loss_fn(self):
         if self.nranks == 1:
             loss_fn = gpt.GPTPretrainingCriterion()
         else:
-            loss_fn = gpt.GPTPretrainingCriterionHybird()
+            loss_fn = gpt.GPTPretrainingCriterionHybird(
+                sequence_parallel=self.configs.Model.sequence_parallel)
         return loss_fn

     def pretreating_batch(self, batch):