
[ROCm] Fixup arch checks for ROCM #2627

Merged: 4 commits merged into vllm-project:main on Feb 5, 2024

Conversation

dllehr-amd
Contributor

The ROCm stack with PyTorch supports a wide set of gfx architectures. These can be displayed by printing the PYTORCH_ROCM_ARCH env var. In the absence of PYTORCH_ROCM_ARCH, PyTorch uses the output of rocm_agent_enumerator to choose what to compile for.

vllm supports a subset of these (gfx908, gfx90a, ...).

Because we may need to support multiple architectures at once (e.g., in a Docker image), it's important to make sure vllm is compiled for all of them unless specified otherwise.

We now gather either the PYTORCH_ROCM_ARCH env var or the rocm_agent_enumerator output and cross-reference it with ROCM_SUPPORTED_ARCHS from vllm to generate the list of arches to build for.
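
For illustration, here is a minimal Python sketch of the selection logic described above (the helper names and the exact contents of ROCM_SUPPORTED_ARCHS are illustrative; the actual setup.py code in this PR may differ):

```python
import os
import subprocess

# Illustrative value only; the discussion below narrows the supported set
# to match what ROCm flash-attention supports.
ROCM_SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942"}


def get_pytorch_rocm_archs() -> set:
    """Return the gfx targets PyTorch will compile for.

    Prefer the PYTORCH_ROCM_ARCH env var (semicolon-separated); otherwise fall
    back to rocm_agent_enumerator, which prints one gfx target per line for the
    GPUs present on the machine.
    """
    env_archs = os.environ.get("PYTORCH_ROCM_ARCH")
    if env_archs:
        return set(env_archs.split(";"))
    out = subprocess.check_output(["rocm_agent_enumerator"], text=True)
    return set(out.split())


def get_vllm_build_archs() -> set:
    """Cross-reference PyTorch's targets with the arches vllm supports."""
    archs = get_pytorch_rocm_archs() & ROCM_SUPPORTED_ARCHS
    if not archs:
        raise RuntimeError("None of the detected gfx targets are supported by vllm.")
    return archs
```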

@WoosukKwon
Collaborator

Hi @dllehr-amd, thanks for submitting the PR! This is a bit confusing. It seems like there are 3 environment variables:

  1. PyTorch uses TORCH_CUDA_ARCH_LIST env var for specifying the target NVIDIA GPUs.
  2. PyTorch w/ ROCm backend uses PYTORCH_ROCM_ARCH env var for specifying the target AMD GPUs.
  3. The ROCm flash-attn repo uses GPU_ARCHS env var for specifying the target AMD GPUs.

And here we are using PYTORCH_ROCM_ARCH. Is my understanding correct?

Flash-Attention currently supports the MI2xx and MI3xx architectures.
Modify vllm's support matrix during build to reflect this.
@dllehr-amd
Contributor Author

Hi @WoosukKwon. I made a change to vllm's arch list to match the ones supported by our Flash Attention (GPU_ARCHS), as we'd be hard pressed to guarantee support on the other gfx architectures. There are still your three variables, but ROCM_SUPPORTED_ARCHS is reduced to the set that Flash Attention supports.

The way this PR works now is that it takes all of the architectures PyTorch was built with, cross-references them with what vllm supports, and only builds for those.

As of this PR, a build will only target gfx90a and gfx942, without any additional input from the user.
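
As a hypothetical illustration of that intersection (the PyTorch arch list below is made up):

```python
# Made-up example of a PyTorch ROCm arch list; the real one comes from
# PYTORCH_ROCM_ARCH or rocm_agent_enumerator.
pytorch_archs = {"gfx900", "gfx906", "gfx908", "gfx90a", "gfx942", "gfx1030"}

# vllm's supported set after this PR narrows it to match ROCm flash-attention.
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}

print(sorted(pytorch_archs & ROCM_SUPPORTED_ARCHS))  # ['gfx90a', 'gfx942']
```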

Does that make things any more clear?

Thanks!

@WoosukKwon
Collaborator


LGTM! Thanks for the fix! BTW, it'd be nice if we could have a CI/CD pipeline for AMD GPUs :)

@WoosukKwon merged commit 2ccee3d into vllm-project:main on Feb 5, 2024
15 of 17 checks passed
@jamestwhedbee
Contributor

> I made a change to vllm's arch list to match the ones supported by our Flash Attention (GPU_ARCHS), as we'd be hard pressed to guarantee support on the other gfx architectures.

@WoosukKwon @dllehr-amd just so folks know, since this PR, ROCm Flash Attention officially supports gfx908 (it always did, it just wasn't documented until this PR).

@jamestwhedbee
Contributor

I've opened #2792

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
alexm-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Feb 13, 2024
jvmncs pushed a commit to jvmncs/vllm that referenced this pull request Feb 14, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 20, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 22, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024