Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove astype for probs, all dtype should same with scores. #7577

Merged
merged 2 commits into from
Dec 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions paddlenlp/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def greedy_search(
probs = F.softmax(next_tokens_scores)
probs = paddle.log(probs)
next_tokens = paddle.argmax(probs, axis=-1).unsqueeze(-1)
next_scores = paddle.index_sample(probs.astype("float32"), next_tokens)
next_scores = paddle.index_sample(probs, next_tokens)

if eos_token_id is not None:
next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))
Expand Down Expand Up @@ -1171,12 +1171,8 @@ def sample(
if top_p is not None and top_p < 1.0:
probs = TopPProcess(probs, top_p, min_tokens_to_keep)

# multinomial not support fp16 and bf16 currently, issue: https://github.com/PaddlePaddle/Paddle/issues/51852
if probs.dtype == paddle.bfloat16 and top_k == 1:
probs = probs.astype("float32")
next_tokens = paddle.unsqueeze(paddle.argmax(probs, axis=-1), -1)
else:
next_tokens = paddle.multinomial(probs)
# multinomial already support fp16 and bf16 currently, fix issue: https://github.com/PaddlePaddle/Paddle/issues/51852
next_tokens = paddle.multinomial(probs)

if self.config.tensor_parallel_degree > 1:
paddle.distributed.broadcast(next_tokens, 0)
Expand Down