[Performance] Torch.compile is slow on MoE layers when bs > 1 #2278
Comments
@merrymercy This is just because in gpt-fast we can fuse the …
@merrymercy I took a look, and a couple of observations. I could reproduce your results (more or less), with …
I've also noticed that for some reason, our performance on the second matmul (…) …
Also, it seems like the reason why … Updating the code to use …
My code can be found here btw: https://pastebin.com/iuzxqJrL
@Chillee @yanboliang Thank you so much for the quick help! I can reproduce your results and I fixed the benchmark scripts in #2327. However, there is still something that I cannot understand.
```python
import torch

@torch.compile(dynamic=False)
def topk_decomposed(x, k):
    # Take the max k times, masking out each selected entry with -inf
    # so the next iteration finds the next-largest value.
    mx_vals = []
    mx_inds = []
    for i in range(k):
        mx_val, mx_ind = torch.max(x, dim=-1, keepdim=True)
        mx_vals.append(mx_val)
        mx_inds.append(mx_ind)
        x = x.scatter(1, mx_ind, float("-inf"))
    return (torch.cat(mx_vals, dim=1),
            torch.cat(mx_inds, dim=1))
```

So my final solution in #2327 is using …
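As a side note (not from the thread): a minimal sketch of how one might check that the decomposed loop above agrees with `torch.topk` on tie-free input. The helper name and shapes are arbitrary assumptions, and it presumes a working `torch.compile` backend.

```python
import torch

def check_topk_decomposition(num_tokens=16, num_experts=8, k=2):
    # Distinct values per row, so tie-breaking cannot make the two versions diverge.
    x = torch.randperm(num_tokens * num_experts, dtype=torch.float32)
    x = x.reshape(num_tokens, num_experts)

    ref_vals, ref_inds = torch.topk(x, k, dim=-1)  # reference, sorted descending
    dec_vals, dec_inds = topk_decomposed(x, k)     # decomposed max/scatter loop

    assert torch.equal(ref_vals, dec_vals)
    assert torch.equal(ref_inds, dec_inds)
```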
Is the issue with my original implementation that it has incorrect results when two elements are identical? If so, I think this alternate implementation should be "correct" and also have good perf.
I'm not sure why the perf isn't that good for BS=2 but is better in the microbenchmark. If you have an easy setup for me to reproduce I can take a look :)
Currently, sglang implements a torch-native MoE forward based on the MoE implementation from gpt-fast. However, this implementation + torch.compile only works relatively well for batch size = 1. It is slower than the non-compiled version when bs > 1, and sometimes even when bs = 1.
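For readers unfamiliar with the gpt-fast-style MoE forward, here is a rough sketch of the idea: route each token to its top-k experts, gather those experts' weights, and apply the gated FFN per token with einsums. The names, shapes, and SiLU-gated layout below are my own simplification, not sglang's actual code.

```python
import torch
import torch.nn.functional as F

def moe_forward_sketch(x, gate_w, w1, w2, w3, top_k=2):
    """Simplified torch-native MoE forward in the gpt-fast style.

    x:      [T, D]     token hidden states
    gate_w: [E, D]     router weights
    w1, w3: [E, I, D]  per-expert up/gate projections
    w2:     [E, D, I]  per-expert down projections
    """
    logits = x @ gate_w.t()                                      # [T, E]
    weights = F.softmax(logits, dim=-1)
    topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)  # [T, k]
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Gather each selected expert's weights: [T, k, I, D] / [T, k, D, I].
    w1_sel, w3_sel, w2_sel = w1[topk_ids], w3[topk_ids], w2[topk_ids]

    x1 = F.silu(torch.einsum("td,tkid->tki", x, w1_sel))
    x3 = torch.einsum("td,tkid->tki", x, w3_sel)
    expert_out = torch.einsum("tki,tkdi->tkd", x1 * x3, w2_sel)  # [T, k, D]

    # Weighted sum over the k selected experts per token.
    return (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)  # [T, D]
```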
Reproduce
Use this script: `benchmark_torch_compile_fused_moe.py`. Note: the `--tp 8` option is only used for shape derivation; you only need one GPU to execute it.

Results:
You can see that torch.compile scales very badly with the batch size.
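If someone wants to see the scaling behavior without the full benchmark script, a rough microbenchmark pattern like the one below is enough: time the same forward eagerly and compiled at a few batch sizes. This is my own sketch, assuming a CUDA GPU and a working `torch.compile` backend; the shapes in the usage comment are hypothetical.

```python
import time
import torch

def time_fn(fn, *args, warmup=5, iters=50):
    # Warm up (and trigger compilation), then average wall-clock time per call,
    # synchronizing around the timed region so GPU work is fully counted.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# Example usage (hypothetical shapes), comparing eager vs compiled per batch size:
# compiled = torch.compile(moe_forward_sketch, dynamic=False)
# for bs in (1, 2, 4, 8):
#     x = torch.randn(bs, 4096, device="cuda", dtype=torch.bfloat16)
#     print(bs, time_fn(moe_forward_sketch, x, gate_w, w1, w2, w3),
#           time_fn(compiled, x, gate_w, w1, w2, w3))
```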