Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix Caching- fix t4 triton error #2517

Merged
merged 5 commits into from
Feb 16, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/triton_kernel/prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import triton
import triton.language as tl

TESLA = 'Tesla' in torch.cuda.get_device_name(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to check for compute capability instead? also, we should do this inside context_attention_fwd, as calling CUDA APIs before we set CUDA_VISIBLE_DEVICES will lead to errors.

Copy link
Collaborator

@esmeetu esmeetu Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can set prefix_block_size as a parameter in CacheConfig and allow user configure in LLM?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this sort of a thing should be ideally derived automatically.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 @caoshiyi Does the block size affect the memory utilization or prefix speed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@esmeetu The block size is mainly dependent on the shared mem size for different GPU architectures. It will affect the prefix-prefill kernel speed a little bit but has nothing to do with the GPU memory utilization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set this variable in a function instead of a global variable? Setting it in global variable may lead to issues in distributed setting.


if triton.__version__ >= "2.1.0":

@triton.jit
Expand Down Expand Up @@ -618,7 +620,8 @@ def context_attention_fwd(q,
b_ctx_len,
max_input_len,
alibi_slopes=None):
BLOCK = 128

BLOCK = 128 if not TESLA else 64
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
Expand Down
Loading