[Bugfix][Kernel] Fix compute_type for MoE kernel (vllm-project#4463)
WoosukKwon authored Apr 30, 2024
1 parent d627a3d commit fa32207
Showing 1 changed file with 4 additions and 2 deletions.
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -433,6 +433,8 @@ def fused_moe(
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
     invoke_fused_moe_kernel(hidden_states,
                             w1,
@@ -447,7 +449,7 @@ def fused_moe(
                             False,
                             topk_ids.shape[1],
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +467,7 @@ def fused_moe(
                             True,
                             1,
                             config,
-                            compute_type=tl.float16,
+                            compute_type=compute_type,
                             use_fp8=use_fp8)
 
     if inplace:
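In short: both `invoke_fused_moe_kernel` calls previously hard-coded `compute_type=tl.float16`, so bfloat16 activations were silently computed in float16 inside the Triton kernel. The fix derives the compute type from `hidden_states.dtype` instead. Below is a minimal sketch of that dispatch, assuming PyTorch and Triton are installed; the `select_compute_type` helper name is illustrative only, as the commit inlines the conditional directly in `fused_moe`:

import torch
import triton.language as tl

def select_compute_type(dtype: torch.dtype):
    """Pick the Triton compute type matching the activation dtype.

    Illustrative helper, not part of vLLM: the commit inlines this
    expression in fused_moe() rather than defining a function.
    """
    # bfloat16 and float16 have different exponent/mantissa splits, so
    # downcasting bf16 activations to fp16 changes the dynamic range and
    # can corrupt the MoE output; the compute type must track the input.
    return tl.bfloat16 if dtype == torch.bfloat16 else tl.float16

assert select_compute_type(torch.bfloat16) == tl.bfloat16
assert select_compute_type(torch.float16) == tl.float16
# Any other dtype falls through to fp16, mirroring the commit's expression.
assert select_compute_type(torch.float32) == tl.float16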
