MoE quantization support for int8 and fp8 #702

Open · wants to merge 6 commits into base: main_perf

Conversation

@Chi-Chu319 commented Jan 17, 2025

MoE int8 and fp8 quantization support

FP8_W8A8 Benchmark

| M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 64.0 | 256.0 | 128.0 | 8.0 | 2.0 | 0.100965 | 0.084315 | 3.464806 |
| 64.0 | 1792.0 | 1024.0 | 8.0 | 2.0 | 0.099205 | 4.799957 | 153.391078 |
| 64.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.237007 | 31.520107 | 992.352761 |
| 128.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.246640 | 60.403025 | 958.416176 |
| 1024.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.367420 | 323.815632 | 741.698139 |
| 4096.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.859841 | 557.812676 | 422.931343 |
| 64.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.315952 | 47.744237 | 1488.312183 |
| 128.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.325784 | 93.130120 | 1471.217250 |
| 256.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.366155 | 161.160900 | 1305.815807 |
| 512.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.419629 | 276.002783 | 1081.493338 |
| 1024.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.584223 | 410.424930 | 891.811228 |
| 2048.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.946157 | 511.846169 | 637.512235 |
| 4096.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 1.739206 | 533.736058 | 436.462257 |

Model Results:

| Model | M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| mistral-7B | 4096 | 14336 | 4096 | 8 | 2 | 1.704461 | 555.939516 | 418.153276 |
| mistral-7B | 4096 | 4096 | 7168 | 8 | 2 | 0.874459 | 551.993634 | 380.378852 |
| mistral-22B | 4096 | 16384 | 6144 | 8 | 2 | 2.839866 | 582.164950 | 377.339608 |
| mistral-22B | 4096 | 6144 | 8192 | 8 | 2 | 1.402399 | 590.175496 | 378.677584 |

INT8_W8A16 Benchmark

| M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 64.0 | 256.0 | 128.0 | 8.0 | 2.0 | 0.101910 | 0.083517 | 3.533078 |
| 64.0 | 1792.0 | 1024.0 | 8.0 | 2.0 | 0.099375 | 4.801227 | 156.784790 |
| 64.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.250642 | 29.825653 | 939.696064 |
| 128.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.251719 | 59.368124 | 967.526905 |
| 1024.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.459621 | 260.903004 | 613.914160 |
| 4096.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 1.211016 | 396.102303 | 316.464194 |
| 64.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.346470 | 43.002082 | 1380.794468 |
| 128.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.373660 | 82.176122 | 1401.441990 |
| 256.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.450378 | 148.267036 | 1187.187435 |
| 512.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.510054 | 226.883150 | 957.420567 |
| 1024.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.731200 | 323.922259 | 731.295148 |
| 2048.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 1.217325 | 391.847432 | 494.065574 |
| 4096.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 2.259041 | 426.662887 | 326.649746 |

Model Results:

| Model | M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| mistral-7B | 4096 | 14336 | 4096 | 8 | 2 | 2.243650 | 427.110245 | 327.295447 |
| mistral-7B | 4096 | 4096 | 7168 | 8 | 2 | 1.171337 | 406.360221 | 303.833718 |
| mistral-22B | 4096 | 16384 | 6144 | 8 | 2 | 3.817804 | 433.771044 | 293.303367 |
| mistral-22B | 4096 | 6144 | 8192 | 8 | 2 | 1.945052 | 423.720677 | 293.735136 |

Baseline performance:

| M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 64.0 | 256.0 | 128.0 | 8.0 | 2.0 | 0.098360 | 0.085805 | 6.343928 |
| 64.0 | 1792.0 | 1024.0 | 8.0 | 2.0 | 0.103364 | 4.751255 | 309.106483 |
| 64.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.286658 | 26.329062 | 1649.914215 |
| 128.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.295348 | 50.963734 | 1615.952538 |
| 1024.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 0.661343 | 183.451017 | 769.860934 |
| 4096.0 | 7168.0 | 4096.0 | 8.0 | 2.0 | 1.305880 | 366.719310 | 472.501474 |
| 64.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.441543 | 33.773158 | 2129.851879 |
| 128.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.436356 | 68.504384 | 2135.694681 |
| 256.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.510952 | 116.823851 | 1846.483793 |
| 512.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.589875 | 214.913268 | 1629.424089 |
| 1024.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 0.999347 | 247.870872 | 1058.416671 |
| 2048.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 1.373657 | 369.402008 | 785.701890 |
| 4096.0 | 14336.0 | 4096.0 | 8.0 | 2.0 | 2.329754 | 417.086039 | 530.091024 |

Baseline model performance:

| Model | M | N | K | E | top_k | Time (ms) | TFLOPS | Bandwidth (GB/s) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| mistral-7B | 4096 | 14336 | 4096 | 8 | 2 | 2.291119 | 405.266958 | 519.560477 |
| mistral-7B | 4096 | 4096 | 7168 | 8 | 2 | 1.226907 | 390.427737 | 490.864256 |
| mistral-22B | 4096 | 16384 | 6144 | 8 | 2 | 4.022803 | 415.061296 | 483.330016 |
| mistral-22B | 4096 | 6144 | 8192 | 8 | 2 | 1.986073 | 410.751339 | 481.708516 |
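
As a quick read on the headline numbers, the sketch below computes the speedups of the quantized kernels over the baseline for the largest shape in the tables above (M=4096, N=14336, K=4096, E=8, top_k=2); the times are copied verbatim from the tables, and the script itself is only illustrative arithmetic, not part of the PR.

```python
# Times (ms) taken from the benchmark tables above for M=4096, N=14336, K=4096, E=8, top_k=2.
baseline_ms = 2.329754
fp8_w8a8_ms = 1.739206
int8_w8a16_ms = 2.259041

print(f"fp8_w8a8 speedup over baseline:   {baseline_ms / fp8_w8a8_ms:.2f}x")    # ~1.34x
print(f"int8_w8a16 speedup over baseline: {baseline_ms / int8_w8a16_ms:.2f}x")  # ~1.03x
```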
  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)


@Chi-Chu319 Chi-Chu319 self-assigned this Jan 17, 2025
@Chi-Chu319 Chi-Chu319 marked this pull request as ready for review January 17, 2025 09:11
@Chi-Chu319 Chi-Chu319 requested a review from zhanglx13 January 21, 2025 17:21
# Scale, round, and clamp to the representable range, then cast to the target type
tensor = tensor * scale
tensor = tensor.round_()
tensor.clamp_(-max_repr_val, max_repr_val)
tensor_quantized = tensor.to(torch.int8)
Collaborator:

So this is returning the quantized tensor as int8, but the dtype can be fp8 as well, right?

Author:

You are right; it's fixed now. Apparently the torch test passed because both implementations were using the same quantized input.
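
For readers following the thread, here is a minimal, self-contained sketch of what a dtype-aware version of this quantization step could look like, covering both int8 and an fp8 format. The helper name `quantize_tensor`, the chosen fp8 flavor, and the overall structure are illustrative assumptions, not the actual diff.

```python
import torch

def quantize_tensor(tensor: torch.Tensor, dtype: torch.dtype = torch.int8):
    """Symmetric per-tensor quantization sketch (hypothetical helper, not the PR's code)."""
    if dtype == torch.int8:
        max_repr_val = 127.0
    elif dtype.is_floating_point:
        # e.g. torch.float8_e4m3fnuz on recent PyTorch; the exact fp8 flavor is an assumption.
        max_repr_val = float(torch.finfo(dtype).max)
    else:
        raise ValueError(f"unsupported target dtype: {dtype}")

    # Map the observed max magnitude onto the representable range.
    max_val = tensor.abs().amax().to(torch.float32).clamp(min=1e-8)
    scale = max_repr_val / max_val

    q = tensor.to(torch.float32) * scale
    if dtype == torch.int8:
        q = q.round_()  # integer formats need explicit rounding
    q = q.clamp_(-max_repr_val, max_repr_val).to(dtype)  # cast to the requested dtype, not hard-coded int8

    # Return the quantized tensor plus the descale needed to recover values: x ≈ q * descale.
    return q, 1.0 / scale
```

Under these assumptions, `quantize_tensor(w, torch.int8)` and `quantize_tensor(w, some_fp8_dtype)` share the same code path and differ only in the representable range and the rounding step.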

Author:

Also, thanks a lot for spotting it! Otherwise the code could have been badly wrong.

@Chi-Chu319 Chi-Chu319 requested a review from vgokhale January 24, 2025 09:10
@vgokhale (Collaborator) left a comment:

LGTM

@vgokhale (Collaborator):
@zhanglx13 would it be possible to review this? A customer is asking for it...

# Guard against division by zero for all-zero channels
max_vals[max_vals == 0] = 1e-8

# Compute scale factors for each channel
scale: torch.Tensor = max_repr_val / max_vals.to(torch.float32)

So for a tensor of shape M x K, what is the shape of the scale?

Author:

The tensor shapes are:

    a = torch.randn((M, K), dtype=dtype, device='cuda')
    b = torch.randn((E, N, K), dtype=dtype, device='cuda')

In the case of fp8_w8a8:

  • a_descale is a scalar
  • b_descale is (E,), i.e. one scale per expert

In the case of use_int8_w8a16:

  • b's scale is (E, N), i.e. per expert and per output channel n (see the shape sketch below)
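
To make those shapes concrete, here is a small, hypothetical sketch (CPU tensors, with 127.0 standing in for the target dtype's max representable value; these are not the benchmark's actual helpers):

```python
import torch

M, K, N, E = 64, 128, 256, 8
MAX_REPR = 127.0  # stand-in for the target dtype's max representable value

a = torch.randn(M, K)     # activations, shape (M, K)
b = torch.randn(E, N, K)  # expert weights, shape (E, N, K)

# fp8_w8a8: one descale for the whole activation tensor, one per expert for the weights.
a_descale = a.abs().amax() / MAX_REPR            # 0-dim scalar
b_descale = b.abs().amax(dim=(1, 2)) / MAX_REPR  # shape (E,)

# int8_w8a16: weight scales per expert and per output channel n.
b_scale_w8a16 = b.abs().amax(dim=-1) / MAX_REPR  # shape (E, N)

assert a_descale.ndim == 0
assert b_descale.shape == (E,)
assert b_scale_w8a16.shape == (E, N)
```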

@rasmith commented Feb 12, 2025

@vgokhale @Chi-Chu319 How long does it take to run these unit tests? Upstream vLLM usually asks for shorter unit tests because of CI. Would it be possible to get a set of parameterizations that can be run in upstream vLLM but won't take a long time?

@vgokhale (Collaborator):

> @vgokhale @Chi-Chu319 How long does it take to run these unit tests? Upstream vLLM usually asks for shorter unit tests because of CI. Would it be possible to get a set of parameterizations that can be run in upstream vLLM but won't take a long time?

@rasmith I don't know. How long can we take in the CI? This is just the regular MoE kernel adapted to support int8. You can use the same set of UTs as the current vLLM MoE kernel.

@rasmith commented Feb 12, 2025

> @vgokhale @Chi-Chu319 How long does it take to run these unit tests? Upstream vLLM usually asks for shorter unit tests because of CI. Would it be possible to get a set of parameterizations that can be run in upstream vLLM but won't take a long time?
>
> @rasmith I don't know. How long can we take in the CI? This is just the regular MoE kernel adapted to support int8. You can use the same set of UTs as the current vLLM MoE kernel.

30 seconds, according to upstream.
