diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d37837a0b2ce8..b4f81527141a8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -433,6 +433,8 @@ def fused_moe( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) invoke_fused_moe_kernel(hidden_states, w1, @@ -447,7 +449,7 @@ def fused_moe( False, topk_ids.shape[1], config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -465,7 +467,7 @@ def fused_moe( True, 1, config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) if inplace: