Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cohere2 #1157

Merged
merged 12 commits into from
Dec 16, 2024
Merged

Add support for cohere2 #1157

merged 12 commits into from
Dec 16, 2024

Conversation

Blaizzy
Copy link
Contributor

@Blaizzy Blaizzy commented Dec 14, 2024

Adds support for Cohere2 with sliding attention.

Thanks a lot to @N8python for the inspiration!

Bf16
Screenshot 2024-12-14 at 4 53 12 PM

4bit
Screenshot 2024-12-14 at 5 10 12 PM

@Blaizzy Blaizzy changed the title add support for cohere2 Add support for cohere2 Dec 14, 2024
@Blaizzy Blaizzy marked this pull request as draft December 14, 2024 15:17
@Blaizzy Blaizzy marked this pull request as ready for review December 14, 2024 16:13
@Blaizzy
Copy link
Contributor Author

Blaizzy commented Dec 14, 2024

Thanks @N8python!

Verified this change saves ~2GB for 4bit:

Screenshot 2024-12-14 at 11 35 09 PM

and ~1.2GB for bf16:
Screenshot 2024-12-14 at 10 20 56 PM

@Blaizzy Blaizzy mentioned this pull request Dec 14, 2024
Comment on lines +86 to +89
if self.use_sliding_window and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., -key_len:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changed slightly. You would be over trimming the keys/values during the prefill stage otherwise.

Copy link
Contributor Author

@Blaizzy Blaizzy Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see.

I thought of this but it was giving me a shape error when I tried exactly this. Because I knew the make_cache was already handling the kv slicing when I checked the shapes.

...
===
window_size 4096
keys shape (1, 8, 4096, 128)
values shape (1, 8, 4096, 128)
mask shape after (512, 4096)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
...

It seems like I should added the changes in mask (L158-163) that you added.

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!!

@awni awni merged commit dfa4dd6 into ml-explore:main Dec 16, 2024
4 checks passed
@Blaizzy
Copy link
Contributor Author

Blaizzy commented Dec 16, 2024

Our pleasure!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants