Skip to content

Commit

Permalink
Add support for cohere2 (#1157)
Browse files Browse the repository at this point in the history
* add support for cohere2

* revert to act_fn to silu

* fix tests and sliding window attention

* add tests

* add to tuner

* fix sliding window

* add coauthor :)

Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>

* Add rotating kvcache to save space

* some nits

* style

* nits

---------

Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>
Co-authored-by: N8 <n8@n8programs.com>
Co-authored-by: Awni Hannun <awni@apple.com>
  • Loading branch information
4 people authored Dec 16, 2024
1 parent fc0674d commit dfa4dd6
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 3 deletions.
207 changes: 207 additions & 0 deletions llms/mlx_lm/models/cohere2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 4096
head_dim: int = 128
num_hidden_layers: int = 32
intermediate_size: int = 14336
num_attention_heads: int = 32
num_key_value_heads: int = 8
rope_theta: float = 50000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
sliding_window: int = 4096
sliding_window_pattern: int = 4


class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.args = args
self.layer_idx = layer_idx

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
if (head_dim * n_heads) != dim:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
f" and `num_heads`: {n_heads})."
)
self.scale = head_dim**-0.5

attetion_bias = args.attention_bias

self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)

self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
B, L, D = x.shape

queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

# Apply RoPE only if sliding window is enabled
if self.use_sliding_window:
if cache is None:
queries = self.rope(queries)
keys = self.rope(keys)
else:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)

if cache is not None:
keys, values = cache.update_and_fetch(keys, values)

if self.use_sliding_window and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., -key_len:]

output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)

output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads

self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x


class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=i)
for i in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)

def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)

T = h.shape[1]
if T > 1:
offset = cache[0].offset if cache else 0
mask = create_causal_mask(T, offset).astype(h.dtype)
else:
mask = None

if cache is None:
cache = [None] * len(self.layers)

for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)

return self.norm(h)


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args

def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out

def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if (
i % self.args.sliding_window_pattern
== self.args.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
)
return caches

@property
def layers(self):
return self.model.layers
1 change: 1 addition & 0 deletions llms/mlx_lm/tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def to_lora(layer):
"gemma2",
"starcoder2",
"cohere",
"cohere2",
"minicpm",
"deepseek",
"olmo2",
Expand Down
7 changes: 4 additions & 3 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
if isinstance(prompt_cache[i], cache.KVCache):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)


def generate_step(
Expand Down
16 changes: 16 additions & 0 deletions llms/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,22 @@ def test_exaone(self):
model = exaone.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)

def test_cohere2(self):
from mlx_lm.models import cohere2

args = cohere2.ModelArgs(
model_type="cohere2",
hidden_size=4096,
head_dim=128,
num_hidden_layers=40,
sliding_window=4096,
sliding_window_pattern=4,
)
model = cohere2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)


if __name__ == "__main__":
unittest.main()

0 comments on commit dfa4dd6

Please sign in to comment.