
Will AOT compilation still be supported after JIT compilation is added? #510

Closed

danieldk opened this issue Sep 25, 2024 · 9 comments

@danieldk

We saw that support for JIT compilation will be added in #507, and we were wondering what the plans are for ahead-of-time compilation. We are happily using flashinfer in Text Generation Inference; the support for KV caches with block_size=1 has really been helpful for supporting fine-grained prefix caching.

For many of our users it's pretty important that compilation is done ahead of time. When infrastructure is scaled up, we want to avoid delaying or slowing down the processing of user requests due to JIT compilation. Since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway. So, for us it would be really useful if AOT continues to be supported going forward.

Thank you for your awesome work on flashinfer 🤗.

@yzh119
Collaborator

yzh119 commented Sep 25, 2024

Hi @danieldk , thanks for bringing this up!

> When infrastructure is scaled up, we want to avoid delaying or slowing down the processing of user requests due to JIT compilation. Since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway.

This is a reasonable concern. I think we can keep both JIT and AOT (for a set of "core" kernels, ~200 MB). We should use the "core" kernels whenever possible, and use JIT for the remaining kernels (new head dimensions, some attention variants, etc.). WDYT?

@abcdabcd987
Member

I agree with @danieldk that AOT is important. For production use, we typically build a Docker image, and Kubernetes then spawns pods running containers of that image. The containers are ephemeral, so if AOT is missing, every time a pod restarts we'll have to JIT compile. This would slow down the start time significantly.

For PyPI, we can ship an sdist and do JIT only. This keeps the PyPI package size small.

For our hosted wheels, I agree with @yzh119 that AOT-compiling "core" kernels is a good idea. I think the "core" kernels should include the kernels that popular pretrained models use (e.g., Llama, Qwen, DeepSeek).

I have a few additional suggestions --

First, for a better user experience, output a log when JIT-compiling a kernel (maybe also include the elapsed time). This way, if we experience an unexpectedly long start time, we know it comes from JIT-compiling FlashInfer kernels. Logging the JIT kernel names can also help us decide what to include in the "core" kernels.
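
Something along these lines is what I have in mind (just a sketch; `compile_fn` stands in for whatever the actual JIT entry point ends up being):

```python
import logging
import time

logger = logging.getLogger("flashinfer.jit")

def jit_compile_with_logging(kernel_name, compile_fn):
    # Log before and after JIT compilation so an unexpectedly long start time
    # can be attributed to FlashInfer kernel compilation.
    logger.info("JIT compiling %s ...", kernel_name)
    start = time.perf_counter()
    module = compile_fn()
    logger.info("Finished JIT compiling %s in %.1f s",
                kernel_name, time.perf_counter() - start)
    return module
```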

Second, wheels shouldn't be pinned to PyTorch versions. We can compile kernels that link against a particular CUDA version and expose a C ABI. We write a separate .cpp file that extracts PyTorch Tensor metadata and calls the kernel. When installing the wheel, we compile the Python binding only. This ensures that compilation takes minimal time when pip installing the wheel (only the time for the pybind).

Shipping wheels tied to PyTorch versions takes time and storage, and it might even be wrong: I don't think PyTorch explicitly guarantees that the torch.Tensor ABI remains the same, even across minor versions.
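
To make the separation concrete, here is a rough sketch of the idea in Python with ctypes (the real binding layer would be a small .cpp/pybind file compiled at install time; `libflashinfer_kernels.so` and `decode_f16` are made-up names, not an actual FlashInfer API):

```python
import ctypes
import torch

# Precompiled kernel library: built against one specific CUDA version and
# exposing only a C ABI (raw pointers and sizes), never torch::Tensor.
_lib = ctypes.CDLL("libflashinfer_kernels.so")  # made-up library name
_lib.decode_f16.argtypes = [
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,  # q, kv, out pointers
    ctypes.c_int64, ctypes.c_int64,                      # num_heads, head_dim
]
_lib.decode_f16.restype = None

def decode(q: torch.Tensor, kv: torch.Tensor, out: torch.Tensor) -> None:
    # Only this thin layer knows about torch.Tensor: it extracts device
    # pointers and shape metadata and forwards them across the C ABI.
    _lib.decode_f16(q.data_ptr(), kv.data_ptr(), out.data_ptr(),
                    q.shape[0], q.shape[-1])
```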

Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels.

Fourth, as for the wheel size, I think even 2 GB is acceptable, because CUDA + PyTorch already take up maybe 10 GB of container size. It's already huge even without FlashInfer, so we shouldn't worry about FlashInfer taking up additional space.

@yzh119
Collaborator

yzh119 commented Sep 25, 2024

Thanks @abcdabcd987 for your thoughts; here is my tentative plan:

Maintain two packages: flashinfer_aot, which ships pre-built binaries in its sdist, and flashinfer, an sdist package that runs JIT by default.

  • flashinfer will be hosted on PyPI.
  • flashinfer_aot uses a self-hosted index (https://flashinfer.ai/whl), which ships binary kernels; when users pip install it, it will be linked against their PyTorch installation. It's not mandatory to install flashinfer_aot to use flashinfer.

The two packages share version numbers. The flashinfer package will first check whether flashinfer_aot is installed; if so, flashinfer will prioritize the pre-compiled kernels in flashinfer_aot and only use JIT when a kernel configuration is not found there. Otherwise, it will always compile kernels with JIT.
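
Roughly, the lookup inside flashinfer could look like this (just a sketch; `flashinfer_aot.lookup` and the `jit_compile` callback are illustrative names, not a finalized API):

```python
import importlib.util

def load_kernel(config, jit_compile):
    # Prefer a precompiled kernel shipped by flashinfer_aot when that package
    # is installed and contains this configuration; otherwise fall back to JIT.
    if importlib.util.find_spec("flashinfer_aot") is not None:
        import flashinfer_aot  # the proposed AOT companion package
        kernel = flashinfer_aot.lookup(config)  # illustrative lookup helper
        if kernel is not None:
            return kernel
    return jit_compile(config)  # JIT fallback supplied by the caller
```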

@danieldk how does this plan sound to you?

ping @comaniac @zhyncs @Ying1123 @merrymercy @WoosukKwon

@comaniac
Collaborator

Sounds good to me. It would be even better if we could allow flashinfer to install flashinfer_aot; otherwise most users would probably suffer from long compile times (and require nvcc in the environment). Ideally something like the following (not sure if it's achievable or makes sense):

pip install "flashinfer[aot]" # Implicitly call `pip install flashinfer_aot`
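
In principle this would just be an extra in flashinfer's setup.py, e.g. (a sketch; the version pin is illustrative, and users would presumably still need to point pip at the self-hosted index, which is part of why I'm unsure it works):

```python
# setup.py for the flashinfer sdist (sketch only)
from setuptools import setup

setup(
    name="flashinfer",
    version="0.2.0",
    extras_require={
        # `pip install "flashinfer[aot]"` would then also resolve the AOT
        # companion package; users would still need --extra-index-url pointing
        # at the self-hosted wheel index for it to be found.
        "aot": ["flashinfer_aot==0.2.0"],
    },
)
```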

@abcdabcd987
Member

I'd push back against a separate flashinfer_aot package name. It's possible that users install both and observe confusing behavior, especially when they upgrade one but not the other. Having a single package name at least ensures that only one copy is installed.

@danieldk
Author

danieldk commented Sep 26, 2024

> Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels.

This sounds great! I don't think we mind compiling flashinfer ourselves to get all the kernels AOT. For development we are caching builds through Nix anyway, and for production Docker containers we are also looking to improve build caching.

I think even outside applications like TGI, AOT-compiling the most-used kernels and JIT-compiling the rest sounds like a good strategy.

@abcdabcd987
Member

After offline discussion, I think the updated plan is as follows (@yzh119 please confirm):

  1. For PyPI, we publish the flashinfer package as an sdist. It does not contain any precompiled kernels. Users will JIT compile when a kernel is invoked.
  2. For local development, it's the same as PyPI. JIT-only.
  3. Perhaps most importantly, we will provide a "precompiled sdist" under pip index url https://flashinfer.ai/whl/cu???/.
    • The "precompiled sdist" will contain precompiled kernels for common uses.
    • These precompiled kernels (.so files) are linked against a specific CUDA version.
    • We don't want to link the kernels against PyTorch because PyTorch might change its ABI.
    • When users run pip install flashinfer -i https://flashinfer.ai/whl/cu???/, they will compile a PyTorch extension that links against the precompiled kernel .so files. Since this only compiles the pybind, not the kernels, the compilation will be fast.
    • Kernels that are not precompiled should still be able to JIT.
  4. The script for producing the precompiled sdist should be customizable, so if users have a different set of kernels they want AOT-compiled, they can produce their own precompiled sdist; see the sketch after this list. (I think this would satisfy @danieldk 's need.)
  5. There will be no bdist / wheel anymore. It's replaced by the precompiled sdist.
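
For item 4, something like the following is what I imagine (a sketch only; `compile_kernel` is a placeholder for whatever the real AOT build step ends up being, and the config fields are only examples):

```python
# Sketch of a customizable AOT build script.
DEFAULT_CONFIGS = [
    {"head_dim": 128, "dtype": "float16", "page_size": 1},
    {"head_dim": 128, "dtype": "bfloat16", "page_size": 16},
]

def build_aot_kernels(configs, compile_kernel, output_dir="build/aot"):
    # Users who need more than the "core" set simply pass a longer config list.
    for config in configs:
        compile_kernel(config, output_dir)
```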

@danieldk
Author

That sounds awesome. Thank you for taking our use case into account!

@yzh119
Collaborator

yzh119 commented Oct 9, 2024

Both JIT mode and AOT mode are supported in #507.
