Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Add w8a8 CUTLASS kernels #4749

Merged

Conversation

tlrmchlsmth
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth commented May 10, 2024

This PR adds fp8_e4m3fn and int8 GEMM kernels, using NVIDIA CUTLASS and unit tests for them. The kernels are not used in this present PR, but are planned to be used in #4525.

The main contributions of this PR is the function cutlass_scaled_mm_dq:

  • Supports symmetric quantized activations and weights
  • The activations may be either per-tensor or per-token
  • The weights may be either per-tensor or per output channel
  • int8 is supported on Turing, Ampere, Lovelace, or Hopper
  • fp8_e4m3 is supported on Ampere or Lovelace.
  • Outputs can be either bfloat16 or fp16.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@jeejeelee
Copy link
Contributor

Great work~

Has the vllm community begun integrating Cutlass? Is this PR part of the official roadmap?

Additionally, For the integration of Cutlass, is it based on the python module(#4525) or the method outlined in your PR?

@tlrmchlsmth
Copy link
Collaborator Author

tlrmchlsmth commented May 13, 2024

Thanks @jeejeelee -- this PR is part of a larger project to add support for w8a8 quantization (which is on the Q2 roadmap #3861). We ran into several issues with the Python interface in #4525, and it's really not supposed to be used this way, so we plan to replace the python cutlass code with these C++ kernels.

The main reason for using CUTLASS here is its ability to do operator fusion via its epilogue operations. For int8 quantization, especially in the asymmetric case, there are a variety of small operations that we'd like to fuse onto GEMMs to avoid the cost of sweeping over the outputs multiple times (see #3975).

@tlrmchlsmth tlrmchlsmth marked this pull request as ready for review May 13, 2024 15:32
@tlrmchlsmth
Copy link
Collaborator Author

@pcmoritz @comaniac There are a couple of issues to iron out still (CMakeLists changes and kernel dispatching for sure) but this should be ready to look at.

@youkaichao do you have any advice on how to handle the SM90a issues? (I know you were looking into this -- unfortunate that pytorch/pytorch@6e99f73 didn't make it into 2.3)

@jeejeelee
Copy link
Contributor

Thank you for your patient explanation. May I ask another question?

Why isn't SM75 supported? We should be able to utilize the m8n8k16

CMakeLists.txt Outdated Show resolved Hide resolved
@tlrmchlsmth
Copy link
Collaborator Author

Thank you for your patient explanation. May I ask another question?

Why isn't SM75 supported? We should be able to utilize the m8n8k16

I'll grab a T4 and see if I can get it working there

@tlrmchlsmth
Copy link
Collaborator Author

@jeejeelee I just added SM75 support as well. I didn't spent a ton of time tuning it but it's maybe 50% faster than fp16 GEMM

vllm/_custom_ops.py Outdated Show resolved Hide resolved
vllm/_custom_ops.py Outdated Show resolved Hide resolved
@pcmoritz
Copy link
Collaborator

Btw, while I was trying out this PR, I got the following error:

import torch
from vllm import _custom_ops as ops

A = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(4096, 8192, dtype=torch.float16, device="cuda")

A *= 500
B *= 500

def per_tensor_quantize(tensor: torch.Tensor,
                       inv_scale: float) -> torch.Tensor:
   finfo = torch.finfo(torch.float8_e4m3fn)
   qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
   return qweight.to(torch.float8_e4m3fn)

A_scale = 448.0 / A.abs().max()
B_scale = 448.0 / B.abs().max()

Aquant = per_tensor_quantize(A, 1.0 / A_scale)
Bquant = per_tensor_quantize(B, 1.0 / B_scale)

scale_a = A_scale * torch.ones((1, 1), device="cuda")
scale_b = B_scale * torch.ones((1, 1), device="cuda")

out = ops.cutlass_scaled_mm_dq(Aquant, Bquant, scale_a, scale_b, out_dtype=torch.float16)

out

I'm getting

RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Can you have a look if you know what is happening here?

@pcmoritz
Copy link
Collaborator

pcmoritz commented May 15, 2024

Ah, I think I know -- I didn't transpose B appropriately / it was not in column major order :)

Can you add a check in cutlass_scaled_mm_dq to make sure the dimensions are compatible and the matrices are in the right format?

Otherwise the PR looks good to me! Happy to stamp after the above comments are addressed :)

@tlrmchlsmth
Copy link
Collaborator Author

I'll add some asserts :)

@pcmoritz
Copy link
Collaborator

Thanks! We should also assert that the tensors are contiguous :)

@pcmoritz
Copy link
Collaborator

Thanks for the fixes, I have a few more comments!

As a mental picture, it should never be possible to crash the python interpreter from python code. Asserts in the C++ level should only be used for consistency checks with previously already established invariants, never for input validations :)

@pcmoritz
Copy link
Collaborator

Otherwise the PR looks good to me now :)

@tlrmchlsmth
Copy link
Collaborator Author

Should be ready now, thanks! @pcmoritz

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 2060e93 into vllm-project:main May 16, 2024
55 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
@brisker
Copy link

brisker commented May 30, 2024

@tlrmchlsmth
Any plan on w4a8 quantization support?

@tlrmchlsmth tlrmchlsmth deleted the tms/w8a8_cutlass_kernels branch June 14, 2024 17:20
@shesung
Copy link

shesung commented Jul 12, 2024

Is there some benchmark results for w8a8 speedup?

@MuYu-zhi
Copy link

MuYu-zhi commented Jul 22, 2024

@tlrmchlsmth hi, I'm invoking the cutlass_scaled_mm_dq kernel with enforce_eager=Falsemode, and raising an error,

[rank0]: File "/vllm/vllm/_custom_ops.py", line 189, in cutlass_scaled_mm_dq
[rank0]: vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
[rank0]: RuntimeError: CUDA error: operation not permitted when stream is capturing
[rank0]: Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

It seems the implementation of this kernel is not compatible with cuda graph.

Do you have any advice for this? 3q

@tlrmchlsmth
Copy link
Collaborator Author

Hey @shesung, if you are looking for end-to-end results for w8a8, we do use the CUTLASS kernels for the fp8 results here https://twitter.com/neuralmagic/status/1812863986330910816

@tlrmchlsmth
Copy link
Collaborator Author

@MuYu-zhi are you using the kernels from this PR directly? In that case, yes they did not initially support CUDA graphs. They were also completely untuned and slow in their initial version, so I'd recommend looking at the ones from vLLM main

@MuYu-zhi
Copy link

@tlrmchlsmth I pulled from vllm main, but not the latest main, it's version 0.4.2. Does the kernel have any updates after 0.4.2? If I want to support cuda graph by myself, how? I don't have extensive experience in cuda graph.

@robertgshaw2-neuralmagic
Copy link
Collaborator

@MuYu-zhi yes, you need to upgrade to a newer version of vllm

Is there a reason you need to use 0.4.2?

@MuYu-zhi
Copy link

@robertgshaw2-neuralmagic No specific reason, it's just that the latest version was 0.4.2 when I pulled it at that time. I'll try updating it. Thanks.

Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants