[ROCm] Fixup arch checks for ROCM #2627
Conversation
The ROCm stack with PyTorch supports a wide set of gfx architectures; the list can be displayed by printing the PYTORCH_ROCM_ARCH env variable. In the absence of PYTORCH_ROCM_ARCH, PyTorch uses the output of rocm_agent_enumerator to choose what to compile for. vllm supports a subset of these architectures (gfx908, gfx90a, ...). Because a single build may need to support multiple architectures at once (e.g. a docker image), it's important that vllm is compiled for all of them unless specified otherwise. We now gather either the PYTORCH_ROCM_ARCH env variable or the rocm_agent_enumerator output and cross-reference it with vllm's ROCM_SUPPORTED_ARCHS to generate the list of arches to build for.
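A minimal sketch of that selection logic, assuming a setup.py-style Python build script; the helper names and the exact contents of ROCM_SUPPORTED_ARCHS below are illustrative, not the PR's actual code:

```python
import os
import subprocess

# Architectures vllm builds for (illustrative set, not the PR's literal list).
ROCM_SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942"}


def get_pytorch_rocm_archs():
    """Return the set of gfx archs PyTorch targets.

    Prefer the PYTORCH_ROCM_ARCH env var (semicolon-separated);
    otherwise fall back to rocm_agent_enumerator's output.
    """
    env_arch = os.environ.get("PYTORCH_ROCM_ARCH")
    if env_arch:
        return set(env_arch.split(";"))
    out = subprocess.check_output(["rocm_agent_enumerator"], text=True)
    # rocm_agent_enumerator prints one agent per line and includes the
    # CPU agent as gfx000, which is not a GPU build target.
    return {line.strip() for line in out.splitlines()
            if line.strip() and line.strip() != "gfx000"}


def get_build_archs():
    """Cross-reference PyTorch's arch list with vllm's supported set."""
    archs = get_pytorch_rocm_archs() & ROCM_SUPPORTED_ARCHS
    if not archs:
        raise RuntimeError(
            "None of the detected ROCm architectures are supported by vllm.")
    return sorted(archs)


if __name__ == "__main__":
    print("Building vllm for:", ";".join(get_build_archs()))
```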
Hi @dllehr-amd, thanks for submitting the PR! This is a bit confusing. It seems like there are 3 environment variables involved here.
Flash-Attention currently supports the MI2xx and MI3xx architectures. Modify vllm's support matrix during build to reflect this.
Hi @WoosukKwon. I made a change to vllm's supported arches to match the ones supported by our Flash Attention build (GPU_ARCHS), as we'd be hard pressed to guarantee support on the other gfx targets. There are still the three variables you mentioned, but ROCM_SUPPORTED_ARCHS is reduced to the set that Flash Attention supports. The way this PR works now, it takes all of the architectures PyTorch was built with, cross-references them with what vllm supports, and only builds for those. As of this PR, a build without any additional input from the user will target only gfx90a and gfx942. Does that make things any clearer? Thanks!
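To illustrate the narrowing described above (a sketch, not the PR's literal diff), the supported set shrinks to the Flash-Attention targets, so an unconfigured build intersects PyTorch's arch list with just these two entries:

```python
# Narrowed to the archs ROCm Flash Attention covers (MI2xx -> gfx90a,
# MI3xx -> gfx942); values come from the discussion above, layout is a sketch.
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
```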
LGTM! Thanks for the fix! BTW, it'd be nice if we could have a CI/CD pipeline for AMD GPUs :)
@WoosukKwon @dllehr-amd just so folks know, since this PR, the ROCm Flash Attention officially supports
I've opened #2792.