* add support for cohere2
* revert act_fn to silu
* fix tests and sliding window attention
* add tests
* add to tuner
* fix sliding window
* add coauthor :)
* add rotating KV cache to save space
* some nits
* style
* nits

Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com>
Co-authored-by: N8 <n8@n8programs.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Commit dfa4dd6 · 1 parent fc0674d
Showing 4 changed files with 228 additions and 3 deletions.
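With this change, a cohere2 checkpoint should load through the usual mlx_lm entry points. A minimal sketch follows; the Hugging Face repo id is only an illustrative assumption (any model whose config reports model_type == "cohere2" applies), and the exact generate keyword arguments can vary between mlx-lm versions.

# Sketch (assumptions noted above): load a cohere2-architecture model and
# run a short generation with the standard mlx_lm helpers.
from mlx_lm import load, generate

model, tokenizer = load("CohereForAI/c4ai-command-r7b-12-2024")  # assumed repo id
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about sliding windows.",
    max_tokens=64,
)
print(text)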
@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 4096
    head_dim: int = 128
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    rope_theta: float = 50000.0
    vocab_size: int = 256000
    layer_norm_eps: float = 1e-05
    logit_scale: float = 0.0625
    attention_bias: bool = False
    layer_norm_bias: bool = False
    sliding_window: int = 4096
    sliding_window_pattern: int = 4


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim
        if (head_dim * n_heads) != dim:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
                f" and `num_heads`: {n_heads})."
            )
        self.scale = head_dim**-0.5

        attention_bias = args.attention_bias

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)

        # Every `sliding_window_pattern`-th layer attends globally; all other
        # layers use sliding-window attention.
        self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Apply RoPE only if sliding window is enabled
        if self.use_sliding_window:
            if cache is None:
                queries = self.rope(queries)
                keys = self.rope(keys)
            else:
                queries = self.rope(queries, offset=cache.offset)
                keys = self.rope(keys, offset=cache.offset)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        # Trim the causal mask to the number of keys actually kept by the
        # rotating cache on sliding-window layers.
        if self.use_sliding_window and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.n_heads = args.num_attention_heads

        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        # Cohere-style parallel block: attention and MLP both read the same
        # normalized input, and their outputs are added to the residual.
        h = self.input_layernorm(x)
        attn_h = self.self_attn(h, mask, cache)
        ff_h = self.mlp(h)
        return attn_h + ff_h + x


class CohereModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=i)
            for i in range(args.num_hidden_layers)
        ]
        self.norm = nn.LayerNorm(
            args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
        )

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        # A causal mask is only needed when processing more than one token;
        # single-token decoding steps attend to everything in the cache.
        T = h.shape[1]
        if T > 1:
            offset = cache[0].offset if cache else 0
            mask = create_causal_mask(T, offset).astype(h.dtype)
        else:
            mask = None

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = CohereModel(args)
        self.args = args

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.embed_tokens.as_linear(out)
        out = out * self.model.args.logit_scale
        return out

    def make_cache(self):
        # Every `sliding_window_pattern`-th layer keeps its full KV history;
        # the remaining layers use a rotating cache bounded by the sliding
        # window, which is what saves memory for long generations.
        caches = []
        for i in range(self.args.num_hidden_layers):
            if (
                i % self.args.sliding_window_pattern
                == self.args.sliding_window_pattern - 1
            ):
                caches.append(KVCache())
            else:
                caches.append(
                    RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                )
        return caches

    @property
    def layers(self):
        return self.model.layers
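As a side note, the cache layout that make_cache produces can be visualized with a small standalone sketch. It simply mirrors the modulo logic above; the layer count is an arbitrary illustrative value, not the model default.

# Standalone illustration of the make_cache pattern; not part of the diff.
sliding_window_pattern = 4
num_hidden_layers = 8  # illustrative only; the real default is 32

for i in range(num_hidden_layers):
    if i % sliding_window_pattern == sliding_window_pattern - 1:
        kind = "KVCache (global attention, full history kept)"
    else:
        kind = "RotatingKVCache (sliding window, memory bounded by max_size)"
    print(f"layer {i}: {kind}")

# With a pattern of 4, layers 3 and 7 keep the full KV history while the
# other layers only ever hold the last `sliding_window` tokens.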
@@ -96,6 +96,7 @@ def to_lora(layer):
         "gemma2",
         "starcoder2",
         "cohere",
+        "cohere2",
         "minicpm",
         "deepseek",
         "olmo2",