-
-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prefix Caching- fix t4 triton error #2517
Merged
Merged
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
import triton | ||
import triton.language as tl | ||
|
||
TESLA = 'Tesla' in torch.cuda.get_device_name(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we set this variable in a function instead of a global variable? Setting it in global variable may lead to issues in distributed setting. |
||
|
||
if triton.__version__ >= "2.1.0": | ||
|
||
@triton.jit | ||
|
@@ -618,7 +620,8 @@ def context_attention_fwd(q, | |
b_ctx_len, | ||
max_input_len, | ||
alibi_slopes=None): | ||
BLOCK = 128 | ||
|
||
BLOCK = 128 if not TESLA else 64 | ||
# shape constraints | ||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] | ||
assert Lq == Lk and Lk == Lv | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be possible to check for compute capability instead? also, we should do this inside
context_attention_fwd
, as calling CUDA APIs before we setCUDA_VISIBLE_DEVICES
will lead to errors.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we can set
prefix_block_size
as a parameter inCacheConfig
and allow user configure inLLM
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this sort of a thing should be ideally derived automatically.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Yard1 @caoshiyi Does the block size affect the memory utilization or prefix speed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@esmeetu The block size is mainly dependent on the shared mem size for different GPU architectures. It will affect the prefix-prefill kernel speed a little bit but has nothing to do with the GPU memory utilization.