diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index 1e273c55481b..8f2ff2fda9fc 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -1072,7 +1072,7 @@ def greedy_search( probs = F.softmax(next_tokens_scores) probs = paddle.log(probs) next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1) - next_scores = paddle.index_sample(probs.astype("float32"), next_tokens) + next_scores = paddle.index_sample(probs, next_tokens) if eos_token_id is not None: next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)) @@ -1171,12 +1171,8 @@ def sample( if top_p is not None and top_p < 1.0: probs = TopPProcess(probs, top_p, min_tokens_to_keep) - # multinomial not support fp16 and bf16 currently, issue: https://github.com/PaddlePaddle/Paddle/issues/51852 - if probs.dtype == paddle.bfloat16 and top_k == 1: - probs = probs.astype("float32") - next_tokens = paddle.unsqueeze(paddle.argmax(probs, axis=-1), -1) - else: - next_tokens = paddle.multinomial(probs) + # multinomial now supports fp16 and bf16; fixes issue: https://github.com/PaddlePaddle/Paddle/issues/51852 + next_tokens = paddle.multinomial(probs) if self.config.tensor_parallel_degree > 1: paddle.distributed.broadcast(next_tokens, 0)