
[Performance] Torch.compile is slow on MoE layers when bs > 1 #2278

Open
merrymercy opened this issue Nov 30, 2024 · 6 comments

merrymercy commented Nov 30, 2024

Currently, sglang implements a torch-native MoE forward based on the MoE implementation from gpt-fast. However, this implementation combined with torch.compile only works reasonably well at batch size 1. It is slower than the Triton fused-MoE kernel when bs > 1, and sometimes even when bs = 1.
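For context, here is a minimal sketch of what a gpt-fast-style torch-native MoE forward looks like; the names and shapes are illustrative, not sglang's actual code:

import torch
import torch.nn.functional as F

def moe_forward_torch_native(x, gate, w13, w2, top_k=2):
    # x: [tokens, hidden], gate: [hidden, experts]
    # w13: [experts, 2 * intermediate, hidden] (gate/up projections), w2: [experts, hidden, intermediate]
    scores = torch.softmax(x @ gate, dim=-1)
    weights, expert_ids = torch.topk(scores, top_k, dim=-1)       # routing: [tokens, top_k]
    weights = weights / weights.sum(dim=-1, keepdim=True)

    # Gather each token's chosen expert weights and apply them as grouped matmuls.
    h = torch.einsum("th,tkoh->tko", x, w13[expert_ids])          # gate and up projections
    g, u = h.chunk(2, dim=-1)
    h = F.silu(g) * u                                             # silu-gated activation
    out = torch.einsum("tki,tkhi->tkh", h, w2[expert_ids])        # down projection
    return (out * weights.unsqueeze(-1)).sum(dim=1)               # weighted sum over the top_k experts

The gathers and small matmuls in this formulation are what torch.compile is expected to fuse, which is what the rest of this thread is about.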

Reproduce

Use this script benchmark_torch_compile_fused_moe.py. Note: The --tp 8 option is only used for shape derivation. You only need one GPU to execute it.

python3 benchmark_torch_compile_fused_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1  --tp 8

Results:
You can see that torch.compile scales very badly with the batch size.

fused-moe-performance (time per call, lower is better):
   batch_size  fused_moe_triton  fused_moe_torch_compile
         1.0          0.169472                 0.220960
         2.0          0.168448                 0.272512
         3.0          0.171456                 0.381344
         4.0          0.171520                 0.486848
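As a rough, assumed harness sketch (this is not the actual benchmark_torch_compile_fused_moe.py; moe_forward is a placeholder for whichever MoE callable is being timed), numbers like the ones above can be produced with triton.testing.do_bench:

import torch
import triton

def bench_ms(moe_forward, batch_size, hidden=4096, dtype=torch.float16):
    # Time one MoE forward at a given batch size; do_bench runs the callable
    # repeatedly on the GPU and returns the measured latency in milliseconds.
    x = torch.randn(batch_size, hidden, device="cuda", dtype=dtype)
    return triton.testing.do_bench(lambda: moe_forward(x))

# Example sweep over the batch sizes in the table, eager vs. compiled:
# compiled = torch.compile(moe_forward_eager, dynamic=False)
# for bs in (1, 2, 3, 4):
#     print(bs, bench_ms(moe_forward_eager, bs), bench_ms(compiled, bs))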

yanboliang commented Dec 2, 2024

@merrymercy This is just because in gpt-fast we can fuse the gather + gemv into one kernel when bs = 1; we didn't optimize for bs > 1. See more implementation details at: https://www.thonking.ai/p/short-supporting-mixtral-in-gpt-fast
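To illustrate that point with hypothetical shapes (this is not gpt-fast's actual code): with a single token, applying its top-k experts is just an index gather of their weight matrices followed by gemvs, which a compiler can emit as one fused kernel; with bs > 1 each token gathers different experts, so the same trick does not directly apply.

import torch

# Hypothetical Mixtral-8x7B-ish shapes at tp=8, for illustration only.
experts, top_k, hidden, inter = 8, 2, 4096, 1792
w1 = torch.randn(experts, inter, hidden, device="cuda", dtype=torch.float16)

# bs = 1: one token, so the routing result is just a pair of indices.
x = torch.randn(hidden, device="cuda", dtype=torch.float16)
expert_ids = torch.tensor([0, 3], device="cuda")   # assumed routing output for the single token

w_sel = w1[expert_ids]                             # gather: [top_k, inter, hidden]
out = torch.einsum("h,koh->ko", x, w_sel)          # two gemvs over the gathered weights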


Chillee commented Dec 3, 2024

@merrymercy I took a look and have a couple of observations. I could reproduce your results (more or less) with python3 benchmark_torch_compile_fused_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp 8

fused-moe-performance:
   batch_size  fused_moe_triton  fused_moe_torch_compile
0         1.0          0.108928                 0.168912
1         2.0          0.133792                 0.308848
2         3.0          0.152512                 0.445648
3         4.0          0.169344                 0.574368 
  1. Coordinate descent tuning isn't turned on. This prevents us from fusing the kernels together (not ideal, but we can't reliably generate faster kernels for gemv except in this case). However, turning it on (torch._inductor.config.coordinate_descent_tuning = True) makes performance much better overall, and faster than the Triton kernel at BS=1.
   batch_size  fused_moe_triton  fused_moe_torch_compile
0         1.0          0.106528                 0.080672
1         2.0          0.133520                 0.126816
2         3.0          0.151360                 0.175296
3         4.0          0.168832                 0.223424
  2. This benchmark isn't just measuring the grouped gemm, it also includes the routing logic, which isn't negligible. Normally we don't fuse topk, but for topk=2 we can actually write it in a way that does fuse. (nit: I haven't actually tested correctness, so use at your own risk haha)
import torch

def topk_decomposed(x, k=2):
    # Decompose topk into k successive max reductions so the routing can be
    # fused with the surrounding kernels by the compiler.
    mx_vals, mx_inds = [], []
    cur_x = x
    for _ in range(k):
        mx_val, mx_ind = torch.max(cur_x, dim=-1, keepdim=True)
        mx_vals.append(mx_val)
        mx_inds.append(mx_ind)
        # Mask the current max before the next pass. Note: this masks every element
        # equal to mx_val, which misbehaves when values are tied (see further down the thread).
        cur_x = torch.where(cur_x == mx_val, -float("inf"), cur_x)
    return torch.cat(mx_vals, dim=1), torch.cat(mx_inds, dim=1)
   batch_size  fused_moe_triton  fused_moe_torch_compile
0         1.0          0.104192                 0.070368
1         2.0          0.132992                 0.115040
2         3.0          0.151136                 0.163680
3         4.0          0.169184                 0.211296
  3. Automatic dynamic shapes is actually triggering here. When we recompile because of a shape change, we start compiling with dynamic shapes. By switching to torch.compile(dynamic=False), we re-tune for each shape instead (a combined code sketch of items 1 and 3 follows the table below).
fused-moe-performance:
   batch_size  fused_moe_triton  fused_moe_torch_compile
0         1.0          0.111744                 0.070656
1         2.0          0.132192                 0.116640
2         3.0          0.150528                 0.164256
3         4.0          0.167552                 0.201312
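Putting items 1 and 3 together, a minimal sketch of the two config changes (moe_forward_torch_native is a placeholder name, not an sglang symbol):

import torch
import torch._inductor.config as inductor_config

def moe_forward_torch_native(x):
    return x  # stand-in body; the real target is the torch-native MoE forward

# Item 1: enable coordinate descent tuning so inductor searches harder for fast,
# fusable kernels (this is what recovers the gemv fusion mentioned above).
inductor_config.coordinate_descent_tuning = True

# Item 3: dynamic=False keeps shapes static, so each batch size is re-tuned
# rather than falling back to dynamic shapes after the first recompile.
compiled_moe = torch.compile(moe_forward_torch_native, dynamic=False)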

I've also noticed that for some reason, our performance on the second matmul (w2) is a lot worse. I'm not totally sure what the cause is, but I'll look into it.


Chillee commented Dec 3, 2024

Also, it seems like the reason w2 is a lot worse is that we're fusing gelu into the prologue (and gelu is a fairly expensive op). First of all, is gelu even correct here? Looking at the baseline code, it seems like it's calling silu, right? https://github.com/sgl-project/sglang/blob/85e1a6f3aa5a2288ca85fe3fe922c733b6533fa7/python/sglang/srt/layers/fused_moe_triton/fused_moe.py#L748C13-L748C18
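For reference, the activation in question is a silu-gated (SwiGLU-style) MLP. A tiny illustrative sketch with made-up shapes, not the actual fused_moe.py code:

import torch
import torch.nn.functional as F

tokens, intermediate = 4, 1792
h = torch.randn(tokens, 2 * intermediate)   # concatenated gate and up projections (w1/w3 output)
gate, up = h.chunk(2, dim=-1)
out_silu = F.silu(gate) * up                # what the Triton baseline computes
out_gelu = F.gelu(gate) * up                # what the compiled path had been computing instead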

Updating the code to use silu, I now get

fused-moe-performance:
   batch_size  fused_moe_triton  fused_moe_torch_compile
0         1.0          0.108208                 0.063328
1         2.0          0.133088                 0.101248
2         3.0          0.152288                 0.139296
3         4.0          0.168736                 0.168992


Chillee commented Dec 3, 2024

My code can be found here btw: https://pastebin.com/iuzxqJrL

merrymercy (Contributor, Author) commented:

@Chillee @yanboliang Thank you so much for the quick help!

I can reproduce your results and I fixed the benchmark scripts in #2327. However, there is still something that I cannot understand.

  • When bs > 1, torch.compile is faster in this single-layer benchmark with random inputs/weights, but when I plug it into an end-to-end test, it is slower. I am not sure whether this is related to the token imbalance problem across experts. When bs > 1, it turns out that the fastest solution is applying torch.compile to all other layers but skipping the FusedMoE layer.
  • The topk_decomposed brings a significant speedup on the model when bs = 2. However, it produces wrong results, and this is not easy to reproduce. I tried another decomposition method which gives me correct results but is much slower:
import torch

@torch.compile(dynamic=False)
def topk_decomposed(x, k):
    # Alternative decomposition: use scatter to mask only the index that was just
    # selected, rather than every element equal to the max, so ties are handled.
    mx_vals = []
    mx_inds = []

    for _ in range(k):
        mx_val, mx_ind = torch.max(x, dim=-1, keepdim=True)
        mx_vals.append(mx_val)
        mx_inds.append(mx_ind)
        x = x.scatter(1, mx_ind, float("-inf"))

    return (torch.cat(mx_vals, dim=1),
            torch.cat(mx_inds, dim=1))

So my final solution in #2327 is to use torch.compile on the FusedMoE layer (with torch.topk) when the batch size (bs) is 1, and to skip it when bs > 1. I am happy with the final performance, so I will stop here and won't try to understand every mysterious aspect.
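As an aside, one generic way to express "compile everything except the MoE layer" is to mark that layer's forward with torch.compiler.disable. The toy module below is purely illustrative (hypothetical names and a naive expert loop), not the mechanism actually used in #2327:

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, hidden=256, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden, num_experts, bias=False)
        self.experts = nn.ModuleList([nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)])

    # Keep this forward eager even when the surrounding model is compiled,
    # mirroring "torch.compile everything but skip the FusedMoE layer".
    @torch.compiler.disable
    def forward(self, x):
        weights, ids = torch.topk(torch.softmax(self.gate(x), dim=-1), self.top_k, dim=-1)
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            hit = (ids == e)                      # [tokens, top_k] bool
            mask = hit.any(dim=-1)
            if mask.any():
                w = (weights * hit).sum(dim=-1, keepdim=True)[mask]
                out[mask] += w * expert(x[mask])
        return out

class ToyModel(nn.Module):
    def __init__(self, hidden=256):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        self.moe = ToyMoE(hidden)

    def forward(self, x):
        return self.moe(torch.relu(self.proj(x)))

model = torch.compile(ToyModel(), dynamic=False)  # proj/relu compile; ToyMoE.forward stays eager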


Chillee commented Dec 5, 2024

The topk_decomposed brings a significant speedup on the model when bs = 2. However, it produces wrong results, and this is not easy to reproduce. I tried another decomposition method which gives me correct results but is much slower.

Is the issue with my original implementation that it has incorrect results when two elements are identical? If so, I think this alternate implementation should be "correct" and also have good perf.

import torch

def topk_decomposed(x, k=2):
    # Same decomposition as before, but mask only the index that was just selected
    # (via the arange comparison) instead of every element equal to the max, so
    # duplicate values are no longer knocked out together.
    mx_vals, mx_inds = [], []
    cur_x = x
    for _ in range(k):
        mx_val, mx_ind = torch.max(cur_x, dim=-1, keepdim=True)
        mx_vals.append(mx_val)
        mx_inds.append(mx_ind)
        cur_x = torch.where(torch.arange(cur_x.shape[-1], device=cur_x.device) == mx_ind, -float("inf"), cur_x)
    return torch.cat(mx_vals, dim=1), torch.cat(mx_inds, dim=1)
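A quick sanity check one could run (not from the thread) against torch.topk, using the topk_decomposed just defined, including a tied row that trips the earlier version:

import torch

torch.manual_seed(0)
x = torch.randn(16, 8)
vals_ref, inds_ref = torch.topk(x, k=2, dim=-1)
vals, inds = topk_decomposed(x, k=2)
# Continuous random values essentially never tie, so values and indices both match.
assert torch.equal(vals, vals_ref) and torch.equal(inds, inds_ref)

# A tied row: the earlier version masked both 3.0 entries after the first max and
# returned 2.0 as the second value; this version masks only the selected index.
tied = torch.tensor([[1.0, 3.0, 3.0, 2.0]])
vals, inds = topk_decomposed(tied, k=2)
assert torch.equal(vals, torch.tensor([[3.0, 3.0]]))  # tie-break order of indices may differ from torch.topk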

I'm not sure why the end-to-end perf isn't as good for BS=2 even though it is better in the microbenchmark. If you have an easy setup for me to reproduce, I can take a look :)
