Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cohere2 #1158

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions llms/mlx_lm/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,22 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
return mask * -1e9


def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_cache_idx: Optional[int] = None) -> mx.array:
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if reference_cache_idx is not None:
c = cache[reference_cache_idx]
else:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset

mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
Expand Down
6 changes: 3 additions & 3 deletions llms/mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten


def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
Expand All @@ -33,7 +32,7 @@ def make_prompt_cache(
]
else:
return [KVCache() for _ in range(num_layers)]


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Expand Down Expand Up @@ -416,7 +415,8 @@ def trim(self, n):
return n

def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
return self
#raise NotImplementedError("RotatingKVCache Quantization NYI")


class MambaCache(_BaseCache):
Expand Down
165 changes: 165 additions & 0 deletions llms/mlx_lm/models/cohere2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
from .cache import KVCache, RotatingKVCache

@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
rope_theta: float
vocab_size: int
layer_norm_eps: float
logit_scale: float
attention_bias: bool
# Additional Cohere2-specific arguments:
# rope_type and max_position_embeddings might influence the rope setup
rope_type: str = "default"
max_position_embeddings: int = 2048
sliding_window: Optional[int] = None,
sliding_window_pattern: Optional[int] = None,
order_of_interleaved_layers: Optional[int] = None,
use_cache: bool = True



class Cohere2Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = args.num_attention_heads
self.n_kv_heads = args.num_key_value_heads
head_dim = dim // self.n_heads
self.scale = head_dim**-0.5

self.q_proj = nn.Linear(dim, self.n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, self.n_kv_heads * head_dim, bias=args.attention_bias)
self.o_proj = nn.Linear(self.n_heads * head_dim, dim, bias=args.attention_bias)

self.sliding_window = args.sliding_window # Not yet implemented :(
self.use_qk_norm = False # Assuming QK norm not used by Cohere2 (adjust if needed)

# Initialize RoPE for Cohere2
self.rope = initialize_rope(
dims=head_dim,
base=args.rope_theta,
traditional=True,
max_position_embeddings=args.max_position_embeddings,
)

def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope = True) -> mx.array:
B, L, D = x.shape
q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
# Apply RoPE
# In Cohere2, the original code applies RoPE before caching updates. We replicate that:
if cache is not None:
if rope:
q = self.rope(q, offset=cache.offset)
k = self.rope(k, offset=cache.offset)
k, v = cache.update_and_fetch(k, v)
if rope:
k = k[:, :, -self.sliding_window:, :]
v = v[:, :, -self.sliding_window:, :]
elif rope:
q = self.rope(q)
k = self.rope(k)
# Compute attention
out = scaled_dot_product_attention(
q, k, v, cache=cache, scale=self.scale, mask=mask
)

out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
return self.o_proj(out)


class Cohere2MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hdim = args.intermediate_size
self.gate_proj = nn.Linear(dim, hdim, bias=False)
self.up_proj = nn.Linear(dim, hdim, bias=False)
self.down_proj = nn.Linear(hdim, dim, bias=False)

def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Cohere2TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.self_attn = Cohere2Attention(args)
self.mlp = Cohere2MLP(args)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)

def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, rope = True) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache, rope=rope)
ff_h = self.mlp(h)
return x + attn_h + ff_h


class Cohere2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Cohere2TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.layer_norm_eps, affine=True, bias=False)
self.sliding_window = args.sliding_window
self.sliding_window_pattern = args.sliding_window_pattern
def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache, reference_cache_idx=self.sliding_window_pattern - 1)
sliding_window_mask = mask[:, -self.sliding_window:] if mask is not None else None
if cache is None:
cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)):
if self.sliding_window is not None:
index = i % self.sliding_window_pattern
if index < self.sliding_window_pattern - 1:
h = layer(h, mask=sliding_window_mask, cache=c)
else:
h = layer(h, mask=mask, cache=c, rope=False)


return self.norm(h)


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type

self.model = Cohere2Model(args)
self.args = args

def __call__(self, inputs: mx.array, cache=None):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out) * self.args.logit_scale
return out

@property
def layers(self):
return self.model.layers

def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
caches.append(KVCache())
else:
caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
return caches
7 changes: 4 additions & 3 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
if isinstance(prompt_cache[i], cache.KVCache):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)


def generate_step(
Expand Down