Add support for cohere2 #1157
Merged · +228 −3 · 12 commits
d7c64c6  add support for cohere2 (Blaizzy)
5d8b36c  revert to act_fn to silu (Blaizzy)
52595da  fix tests and sliding window attention (Blaizzy)
2f443cc  add tests (Blaizzy)
0337646  add to tuner (Blaizzy)
d7d7048  fix sliding window (Blaizzy)
406c7f3  add coauthor :) (Blaizzy)
ac58a95  Add rotating kvcache to save space
20d7925  Merge pull request #1 from N8python/add-cohere2-arch-rotating-kv-cache (Blaizzy)
4aee862  some nits (awni)
dec2acf  style (awni)
799dfde  nits (awni)
@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 4096
    head_dim: int = 128
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    rope_theta: float = 50000.0
    vocab_size: int = 256000
    layer_norm_eps: float = 1e-05
    logit_scale: float = 0.0625
    attention_bias: bool = False
    layer_norm_bias: bool = False
    sliding_window: int = 4096
    sliding_window_pattern: int = 4

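Because ModelArgs extends BaseModelArgs, it can be built straight from a model's config.json dictionary; as implemented in mlx-lm, BaseModelArgs.from_dict filters out keys the dataclass doesn't declare. A minimal sketch with placeholder config values:

config = {
    "model_type": "cohere2",
    "sliding_window": 4096,
    "sliding_window_pattern": 4,
    "some_unrelated_key": 123,  # silently dropped by from_dict
}
args = ModelArgs.from_dict(config)
print(args.sliding_window, args.sliding_window_pattern)  # 4096 4
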
class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim
        if (head_dim * n_heads) != dim:
            raise ValueError(
                f"hidden_size must equal num_heads * head_dim (got `hidden_size`:"
                f" {dim}, `num_heads`: {n_heads}, `head_dim`: {head_dim})."
            )
        self.scale = head_dim**-0.5

        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

        # Every `sliding_window_pattern`-th layer attends globally; the rest
        # use sliding-window attention.
        self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0

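The `use_sliding_window` rule is easiest to see spelled out. A small sketch (the 8-layer count is arbitrary, for illustration only):

pattern = 4  # mirrors ModelArgs.sliding_window_pattern
for layer_idx in range(8):
    use_sliding = (layer_idx + 1) % pattern != 0
    print(layer_idx, "sliding" if use_sliding else "global")
# layers 0-2 and 4-6 -> sliding window (+ RoPE); layers 3 and 7 -> global
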
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Apply RoPE only on sliding-window layers; global layers use no
        # positional encoding.
        if self.use_sliding_window:
            if cache is None:
                queries = self.rope(queries)
                keys = self.rope(keys)
            else:
                queries = self.rope(queries, offset=cache.offset)
                keys = self.rope(keys, offset=cache.offset)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        # The rotating cache may hold fewer keys than the mask covers, so
        # trim the mask to the actual key length.
        if self.use_sliding_window and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)

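The mask trimming above is easiest to check with concrete shapes. A minimal sketch, assuming create_causal_mask from mlx_lm.models.base and hypothetical sizes where more queries arrive than the rotating cache kept keys for:

import mlx.core as mx
from mlx_lm.models.base import create_causal_mask

L, key_len = 8, 5  # 8 query positions, but only 5 keys survived in the cache
mask = create_causal_mask(L, 0)
print(mask.shape)  # (8, 8)
mask = mask[..., -key_len:]
print(mask.shape)  # (8, 5), aligned with the cached keys
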
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_heads = args.num_attention_heads

        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        # Parallel block: the attention and MLP branches share a single
        # pre-norm and are summed with the residual.
        h = self.input_layernorm(x)
        attn_h = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x

class CohereModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=i)
            for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        # A causal mask is only needed when processing more than one token
        # (prefill); single-token decode steps need none.
        T = h.shape[1]
        if T > 1:
            offset = cache[0].offset if cache else 0
            mask = create_causal_mask(T, offset).astype(h.dtype)
        else:
            mask = None

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)

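A quick forward-pass sanity check of the backbone, using deliberately tiny, hypothetical dimensions and the classes defined above:

args = ModelArgs(
    model_type="cohere2", hidden_size=64, head_dim=8,
    num_attention_heads=8, num_key_value_heads=4,
    intermediate_size=128, num_hidden_layers=4, vocab_size=100,
)
backbone = CohereModel(args)
h = backbone(mx.array([[1, 2, 3, 4]]))
print(h.shape)  # (1, 4, 64): normalized hidden states, not logits
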
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = CohereModel(args)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        # Tied embeddings: reuse the input embedding matrix as the output
        # projection, then apply Cohere's logit scaling.
        out = self.model.embed_tokens.as_linear(out)
        out = out * self.model.args.logit_scale
        return out

    def make_cache(self):
        caches = []
        for i in range(self.args.num_hidden_layers):
            if (
                i % self.args.sliding_window_pattern
                == self.args.sliding_window_pattern - 1
            ):
                # Global-attention layers need the full history.
                caches.append(KVCache())
            else:
                # Sliding-window layers only ever see the last
                # `sliding_window` tokens, so a bounded rotating cache
                # saves memory.
                caches.append(
                    RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                )
        return caches

    @property
    def layers(self):
        return self.model.layers
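And a sketch of the cache layout make_cache produces, reusing the tiny hypothetical config from the backbone sketch above: with sliding_window_pattern = 4, every fourth layer gets an unbounded KVCache and the others get a RotatingKVCache capped at the window size.

model = Model(args)  # `args` as in the tiny-config sketch above
print([type(c).__name__ for c in model.make_cache()])
# ['RotatingKVCache', 'RotatingKVCache', 'RotatingKVCache', 'KVCache']
logits = model(mx.array([[1, 2, 3, 4]]), cache=model.make_cache())
print(logits.shape)  # (1, 4, 100), scaled by logit_scale
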
@@ -96,6 +96,7 @@ def to_lora(layer):
     "gemma2",
     "starcoder2",
     "cohere",
+    "cohere2",
     "minicpm",
     "deepseek",
     "olmo2",
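With "cohere2" added to the tuner's supported model types, LoRA fine-tuning should go through mlx-lm's standard entry point. A hypothetical invocation, with placeholder paths:

python -m mlx_lm.lora --model <path-or-hf-repo-of-a-cohere2-model> --train --data ./data
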
This changed slightly. You would be over-trimming the keys/values during the prefill stage otherwise.
Oh, I see. I thought of this, but it gave me a shape error when I tried exactly this, since make_cache was already handling the KV slicing when I checked the shapes. It seems I should have added the mask changes (L158-163) that you added.