From 7c6471b54a611907149abd7ca20617def3f718a2 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 12 Sep 2024 21:23:42 -0400 Subject: [PATCH] [BugFix] fix group_topk (#8430) --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bd13d8fecbb96..a0cb4337f9dee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -410,6 +410,7 @@ def fused_topk( if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids @@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor, if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + + return topk_weights, topk_ids.to(torch.int32) def get_config_dtype_str(dtype: torch.dtype,