Add support for cohere2 #1157

Merged
merged 12 commits on Dec 16, 2024
207 changes: 207 additions & 0 deletions llms/mlx_lm/models/cohere2.py
@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 4096
head_dim: int = 128
num_hidden_layers: int = 32
intermediate_size: int = 14336
num_attention_heads: int = 32
num_key_value_heads: int = 8
rope_theta: float = 50000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
sliding_window: int = 4096
sliding_window_pattern: int = 4


class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.args = args
self.layer_idx = layer_idx

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
if (head_dim * n_heads) != dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
f" and `num_heads`: {n_heads})."
)
self.scale = head_dim**-0.5

        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

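        # Every `sliding_window_pattern`-th layer (1-indexed) is a global-attention
        # layer and skips both RoPE and the sliding window; all other layers use
        # sliding-window attention with RoPE.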
self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape

queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

# Apply RoPE only if sliding window is enabled
if self.use_sliding_window:
if cache is None:
queries = self.rope(queries)
keys = self.rope(keys)
else:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)

if cache is not None:
keys, values = cache.update_and_fetch(keys, values)

if self.use_sliding_window and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., -key_len:]
Comment on lines +86 to +89
Member

This changed slightly. You would be over trimming the keys/values during the prefill stage otherwise.

Contributor Author
@Blaizzy Dec 16, 2024

Oh, I see.

I thought of this, but it was giving me a shape error when I tried exactly that, since make_cache was already handling the KV slicing; these are the shapes I checked:

...
===
window_size 4096
keys shape (1, 8, 4096, 128)
values shape (1, 8, 4096, 128)
mask shape after (512, 4096)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
window_size 4608
keys shape (1, 8, 4608, 128)
values shape (1, 8, 4608, 128)
mask shape after (512, 4608)
===
...

It seems like I should have added the mask changes (L158-163) that you added.
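For reference, a minimal sketch of the trimming being discussed; the shapes are made up to mirror the debug output above and are not part of the diff:

    import mlx.core as mx

    # Hypothetical prefill chunk: 512 query tokens against a rotating cache that
    # currently holds 4608 keys, while the causal mask was built for the full
    # prompt width.
    keys = mx.zeros((1, 8, 4608, 128))
    mask = mx.zeros((512, 8192))

    key_len = keys.shape[-2]           # 4608
    if mask.shape[-1] != key_len:
        mask = mask[..., -key_len:]    # (512, 4608): mask columns now match the cached keys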


output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)

output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads

self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
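        # Parallel block: attention and MLP both consume the same normalized input,
        # and their outputs are summed with the residual stream.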
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x


class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=i)
for i in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)

def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)

T = h.shape[1]
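        # One full causal mask is built here; sliding-window attention layers trim
        # it to their cached key length inside Attention.__call__.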
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None

if cache is None:
cache = [None] * len(self.layers)

for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)

return self.norm(h)


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args

def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

def make_cache(self):
caches = []
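        # Global-attention layers (every `sliding_window_pattern`-th layer) get a
        # standard KVCache; the rest use a RotatingKVCache capped at the window size.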
for i in range(self.args.num_hidden_layers):
if (
i % self.args.sliding_window_pattern
== self.args.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
)
return caches

@property
def layers(self):
return self.model.layers
1 change: 1 addition & 0 deletions llms/mlx_lm/tuner/utils.py
@@ -96,6 +96,7 @@ def to_lora(layer):
"gemma2",
"starcoder2",
"cohere",
"cohere2",
"minicpm",
"deepseek",
"olmo2",
7 changes: 4 additions & 3 deletions llms/mlx_lm/utils.py
@@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )


def generate_step(
16 changes: 16 additions & 0 deletions llms/tests/test_models.py
@@ -851,6 +851,22 @@ def test_exaone(self):
model = exaone.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)

def test_cohere2(self):
from mlx_lm.models import cohere2

args = cohere2.ModelArgs(
model_type="cohere2",
hidden_size=4096,
head_dim=128,
num_hidden_layers=40,
sliding_window=4096,
sliding_window_pattern=4,
)
model = cohere2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)


if __name__ == "__main__":
unittest.main()
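For anyone who wants to exercise the new model type end to end, a minimal usage sketch assuming mlx_lm's top-level load/generate helpers; the Hugging Face model id below is illustrative and not part of this diff:

    from mlx_lm import load, generate

    # Any checkpoint whose config declares model_type == "cohere2" is routed to
    # the new llms/mlx_lm/models/cohere2.py implementation.
    model, tokenizer = load("CohereForAI/c4ai-command-r7b-12-2024")
    print(generate(model, tokenizer, prompt="Hello, world", max_tokens=32))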