Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jit: further accelerate compilation by spliting files and multi-threading #628

Merged
merged 2 commits into from
Nov 23, 2024

Conversation

yzh119
Copy link
Collaborator

@yzh119 yzh119 commented Nov 23, 2024

This PR accelerates JIT compilation by:

  • Add a parallel_load_modules function to load necessary modules for a model in parallel using python multi-threading.
  • Splitting sampling.cu into renorm.cu and sampling.cu

The batch prefill attention template could be further split into multiple instances to accelerate compilation, we leave that for future work.

@yzh119 yzh119 merged commit f5842b8 into main Nov 23, 2024
yzh119 added a commit that referenced this pull request Nov 24, 2024
Currently unittests are slow when using flashinfer jit because we only
compile kernels the first time we run it, it's blocking and didn't
compile multiple ops in parallel. This PR add a warmup pre-hook to
kernel unittests, so that we compile all necessary kernels before
running the unittests in JIT mode, which greatly accelerate the
unittests.

This PR also fixes the several issues with #628 :
1. using thread-safe `make_dirs(..., exist_ok=True)` instead of relying
on `os.path.exists`
2. change the signature of `parallel_load_modules` to lists of
`(jit_module_creation_func, args)` instead of lambda function, because
lambda function captures variable by ref instead of value, which may
cause some unexpected errors.
@yzh119 yzh119 deleted the further-speedup-compilation branch November 24, 2024 07:58
yzh119 added a commit that referenced this pull request Nov 24, 2024
Followup of #628, this
PR splits prefill attention jit templates so that we compile different
mask modes in different files.

JIT compilation time of a prefill kernels of a certain configuration
(shape, dtype etc) could be reduced to 10 seconds after this PR.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant