From fd21917083f36df0ecaf3f938f14527b2faebcf9 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Mon, 4 Dec 2023 07:23:01 +0000 Subject: [PATCH 1/2] dtype should same with scores. --- paddlenlp/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index 1e273c55481b..6d71b2079571 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -515,7 +515,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder def update_scores_for_generation(scores, next_scores, length, unfinished_flag): # update scores - unfinished_scores = (scores * length + next_scores) / (length + 1) + unfinished_scores = (scores * length + next_scores.astype(scores.dtype)) / (length + 1) scores = paddle.where(unfinished_flag, unfinished_scores, scores) return scores From 57c6136c3e2709dd62ea50697bc42d564436ca3f Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Mon, 4 Dec 2023 08:50:34 +0000 Subject: [PATCH 2/2] fix --- paddlenlp/generation/utils.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index 6d71b2079571..8f2ff2fda9fc 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -515,7 +515,7 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder def update_scores_for_generation(scores, next_scores, length, unfinished_flag): # update scores - unfinished_scores = (scores * length + next_scores.astype(scores.dtype)) / (length + 1) + unfinished_scores = (scores * length + next_scores) / (length + 1) scores = paddle.where(unfinished_flag, unfinished_scores, scores) return scores @@ -1072,7 +1072,7 @@ def greedy_search( probs = F.softmax(next_tokens_scores) probs = paddle.log(probs) next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1) - next_scores = paddle.index_sample(probs.astype("float32"), next_tokens) + next_scores = paddle.index_sample(probs, next_tokens) if eos_token_id is not None: next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) @@ -1171,12 +1171,8 @@ def sample( if top_p is not None and top_p < 1.0: probs = TopPProcess(probs, top_p, min_tokens_to_keep) - # multinomial not support fp16 and bf16 currently, issue: https://github.com/PaddlePaddle/Paddle/issues/51852 - if probs.dtype == paddle.bfloat16 and top_k == 1: - probs = probs.astype("float32") - next_tokens = paddle.unsqueeze(paddle.argmax(probs, axis=-1), -1) - else: - next_tokens = paddle.multinomial(probs) + # multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852 + next_tokens = paddle.multinomial(probs) if self.config.tensor_parallel_degree > 1: paddle.distributed.broadcast(next_tokens, 0)