[Distributed] metric calculation supports tp logits
SylarTiaNII committed May 10, 2024
1 parent 9146c1e commit 176891c
Showing 3 changed files with 14 additions and 3 deletions.
llm/finetune_generation.py (6 changes: 3 additions & 3 deletions)
@@ -140,7 +140,7 @@ def main():
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLMPipe.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             use_flash_attention=model_args.use_flash_attention,
@@ -152,7 +152,7 @@ def main():
     # NOTE(gongenlei): new add autotuner_benchmark
     model_config = AutoConfig.from_pretrained(
         model_args.model_name_or_path,
-        tensor_parallel_output=False,
+        tensor_parallel_output=training_args.tensor_parallel_output,
         tensor_parallel_degree=training_args.tensor_parallel_degree,
         tensor_parallel_rank=training_args.tensor_parallel_rank,
         dtype=dtype,
@@ -163,7 +163,7 @@ def main():
     else:
         model_config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
-            tensor_parallel_output=False,
+            tensor_parallel_output=training_args.tensor_parallel_output,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             tensor_parallel_rank=training_args.tensor_parallel_rank,
             dtype=dtype,
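These hunks only thread the new training argument through to model and config construction. For intuition, a minimal illustrative sketch (vocab_size, tp_degree, batch and seq_len are made-up numbers, not values from the commit): when tensor_parallel_output is enabled, each tensor-parallel rank keeps only its vocab shard of the logits instead of a fully gathered tensor, which is why the evaluation path in llm/utils.py (next diff) has to gather the shards back before computing token-level metrics.

# Illustrative shapes only; none of these numbers come from the commit.
vocab_size = 32000   # hypothetical vocabulary size
tp_degree = 4        # hypothetical tensor_parallel_degree
batch, seq_len = 2, 128

full_shape = (batch, seq_len, vocab_size)                   # tensor_parallel_output=False
per_rank_shape = (batch, seq_len, vocab_size // tp_degree)  # tensor_parallel_output=True, per rank
print(full_shape, per_rank_shape)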
llm/utils.py (7 changes: 7 additions & 0 deletions)
@@ -211,6 +211,13 @@ def prediction_step(
             # keepdim in order to maintain the same shape as logits
             if isinstance(logits, (list, tuple)):
                 logits = logits[0]
+            # all gather logits when enabling tensor_parallel_output
+            if self.args.tensor_parallel_degree > 1 and self.args.tensor_parallel_output:
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                gathered_logits = []
+                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
+                logits = paddle.concat(gathered_logits, axis=-1)
             return (loss, logits.argmax(axis=-1, keepdim=True), labels)
 
         loss = None
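For context, a standalone sketch of the gather pattern this new branch relies on (assuming a fleet hybrid-parallel run has already been initialized, as the trainer does; the logits tensor and its shape here are made up): all_gather collects one vocab shard per rank from the model-parallel group, concatenation along the last axis restores the full vocabulary dimension, and argmax then yields the token ids used for metrics.

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

# Assumes fleet has been initialized elsewhere with tensor_parallel_degree > 1;
# this will not run on a single, non-distributed process.
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()

# Hypothetical per-rank logits: [batch, seq_len, vocab_size // tp_degree]
local_logits = paddle.randn([2, 128, 32000 // mp_group.nranks])

gathered_logits = []
dist.all_gather(gathered_logits, local_logits, group=mp_group)  # one shard per rank, in rank order
full_logits = paddle.concat(gathered_logits, axis=-1)           # [batch, seq_len, vocab_size]
pred_ids = full_logits.argmax(axis=-1, keepdim=True)            # same keepdim trick as prediction_step

The gather happens only in this evaluation branch, so the extra memory cost of holding full logits is confined to the prediction step.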
paddlenlp/trainer/training_args.py (4 changes: 4 additions & 0 deletions)
@@ -793,6 +793,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to run distributed training in auto parallel mode"},
     )
+    tensor_parallel_output: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to output logits in distributed status"},
+    )
 
     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
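A usage sketch for the new flag (output_dir and the parallel degree below are placeholders, not values from the commit): it defaults to False, so sharded logits only reach prediction_step when it is switched on together with a tensor-parallel degree greater than one.

from paddlenlp.trainer import TrainingArguments

# Placeholder values; only tensor_parallel_output is the subject of this commit.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    tensor_parallel_degree=2,
    tensor_parallel_output=True,  # keep per-rank logits; evaluation gathers them
)

Since it is a regular TrainingArguments field, the flag can also be supplied on the fine-tuning script's command line like the other tensor-parallel options.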
