Add support for cohere2 #1157
Conversation
Add rotating kvcache to save space
Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>
Thanks @N8python! Verified this change saves ~2GB for 4-bit.
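For context, the saving comes from capping the KV cache at the sliding-window size instead of letting it grow with the sequence. Here is a minimal sketch of the idea (not the mlx-lm `RotatingKVCache` implementation; the class name and its eager trimming are illustrative):

```python
import mlx.core as mx

class SimpleRotatingKVCache:
    """Keep at most `max_size` key/value positions.

    Simplified sketch: it concatenates and slices instead of using a
    true ring buffer, but it shows where the memory saving comes from:
    keys/values never grow past the window.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.keys = None
        self.values = None

    def update_and_fetch(self, keys, values):
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=-2)
            self.values = mx.concatenate([self.values, values], axis=-2)
        if self.keys.shape[-2] > self.max_size:
            # Drop positions that fell out of the sliding window.
            self.keys = self.keys[..., -self.max_size:, :]
            self.values = self.values[..., -self.max_size:, :]
        return self.keys, self.values
```

Unlike this eager sketch, the cache used in the PR can return more than `window_size` keys during chunked prefill (see the shapes logged in the discussion below), which is what the mask-trimming review comment is about.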
```python
if self.use_sliding_window and mask is not None:
    key_len = keys.shape[-2]
    if mask.shape[-1] != key_len:
        mask = mask[..., -key_len:]
```
This changed slightly. Otherwise you would be over-trimming the keys/values during the prefill stage.
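To illustrate the failure mode (shapes borrowed from the trace below; this is a sketch, not the PR's exact code): during chunked prefill the rotating cache can briefly hold more than `window_size` keys, so slicing the keys/values down to the window here would drop positions the current chunk still attends to. Only the mask should be sliced, and only when its last axis disagrees with the actual key length. The history length of 5120 is hypothetical, chosen just to make the slice do something:

```python
import mlx.core as mx

window_size = 4096  # sliding window
# Mid-prefill: 4096 cached positions plus a new 512-token chunk.
keys = mx.zeros((1, 8, 4608, 128))
# Hypothetical mask built over the full history seen so far.
mask = mx.zeros((512, 5120))

# Over-trimming would look like slicing keys/values to the window here,
# e.g. keys = keys[..., -window_size:, :], discarding 512 positions the
# current chunk may still attend to. The cache handles that trimming.

# The merged fix aligns only the mask with the cached key length:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
    mask = mask[..., -key_len:]

print(mask.shape)  # (512, 4608), matching the logged "mask shape after"
```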
Oh, I see.
I thought of this, but it gave me a shape error when I tried exactly that. When I checked the shapes, `make_cache` was already handling the kv slicing.
```
...
===
window_size 4096
keys shape (1, 8, 4096, 128)
values shape (1, 8, 4096, 128)
mask shape after (512, 4096)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
...
```
It seems like I should have added the mask changes (L158-163) that you added.
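For completeness, a hedged sketch of where that slice sits in the attention forward pass (the function name and surrounding code are illustrative; `mx.fast.scaled_dot_product_attention` is the standard mlx kernel):

```python
import mlx.core as mx

def sliding_window_sdpa(queries, keys, values, mask, scale):
    # keys/values come straight from the rotating cache; during chunked
    # prefill their length may not match the mask's last axis.
    key_len = keys.shape[-2]
    if mask is not None and mask.shape[-1] != key_len:
        mask = mask[..., -key_len:]
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
```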
Thanks!!
Our pleasure!
Adds support for Cohere2 with sliding-window attention.
Thanks a lot to @N8python for the inspiration!
[Benchmark results for bf16 and 4bit were attached in the original PR.]