From d7c64c621a96ad77a39eb90a8c13fa3fda5f9c07 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 16:15:24 +0100 Subject: [PATCH 01/11] add support for cohere2 --- llms/mlx_lm/models/cohere2.py | 207 ++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 llms/mlx_lm/models/cohere2.py diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py new file mode 100644 index 000000000..ae19f4d8a --- /dev/null +++ b/llms/mlx_lm/models/cohere2.py @@ -0,0 +1,207 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 4096 + head_dim: int = 128 + num_hidden_layers: int = 32 + intermediate_size: int = 14336 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + rope_theta: float = 50000.0 + vocab_size: int = 256000 + layer_norm_eps: float = 1e-05 + logit_scale: float = 0.0625 + attention_bias: bool = False + layer_norm_bias: bool = False + sliding_window: int = 4096 + sliding_window_pattern: int = 4 + + +class LayerNorm2D(nn.Module): + + def __init__(self, d1, d2, eps): + super().__init__() + self.weight = mx.zeros((d1, d2)) + self.eps = eps + + def __call__(self, x): + return self.weight * mx.fast.layer_norm(x, None, None, self.eps) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.layer_idx = layer_idx + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // args.num_attention_heads + self.scale = head_dim**-0.5 + + attetion_bias = args.attention_bias + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) + + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + self.sliding_window = ( + args.sliding_window + if (layer_idx + 1) % args.sliding_window_pattern != 0 + else None + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1) + keys = keys.reshape(B, L, self.n_kv_heads, -1) + + queries = queries.transpose(0, 2, 1, 3) + keys = keys.transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # sliding window attention + if self.sliding_window is not None: + keys = keys[:, : -self.sliding_window :, :] + values = values[:, : -self.sliding_window :, :] + if mask is not None: + mask = mask[:, : -self.sliding_window, :] + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x): + return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args, layer_idx) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + h = self.input_layernorm(x) + attn_h = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x + + +class CohereModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = CohereModel(args) + self.args = args + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + out = self.model.embed_tokens.as_linear(out) + out = out * self.model.args.logit_scale + return out + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads From 5d8b36ce7c208d1e722216da0003c6c48921e406 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 16:22:00 +0100 Subject: [PATCH 02/11] revert to act_fn to silu --- llms/mlx_lm/models/cohere2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index ae19f4d8a..b40686794 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -115,7 +115,7 @@ def __init__(self, dim, hidden_dim): self.down_proj = nn.Linear(hidden_dim, dim, bias=False) def __call__(self, x): - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): From 52595dafae140960beeb3d6a0135eb7247725df1 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 16:39:22 +0100 Subject: [PATCH 03/11] fix tests and sliding window attention --- llms/mlx_lm/models/cohere2.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index b40686794..6b46ddc61 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -48,8 +48,12 @@ def __init__(self, args: ModelArgs, layer_idx: int): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - head_dim = args.hidden_size // args.num_attention_heads + self.head_dim = head_dim = args.head_dim + if (head_dim * n_heads) != dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}" + f" and `num_heads`: {n_heads})." + ) self.scale = head_dim**-0.5 attetion_bias = args.attention_bias @@ -77,11 +81,8 @@ def __call__( queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - queries = queries.reshape(B, L, self.n_heads, -1) - keys = keys.reshape(B, L, self.n_kv_heads, -1) - - queries = queries.transpose(0, 2, 1, 3) - keys = keys.transpose(0, 2, 1, 3) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) if cache is not None: @@ -94,10 +95,10 @@ def __call__( # sliding window attention if self.sliding_window is not None: - keys = keys[:, : -self.sliding_window :, :] - values = values[:, : -self.sliding_window :, :] + keys = keys[:, :, -self.sliding_window :, :] + values = values[:, :, -self.sliding_window :, :] if mask is not None: - mask = mask[:, : -self.sliding_window, :] + mask = mask[:, -self.sliding_window :] output = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask @@ -200,7 +201,7 @@ def layers(self): @property def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads + return self.args.head_dim @property def n_kv_heads(self): From 2f443cc6d73ffadf2484d4367345fcf415eed350 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 16:39:46 +0100 Subject: [PATCH 04/11] add tests --- llms/tests/test_models.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 374a51137..d6decb3fa 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -851,6 +851,19 @@ def test_exaone(self): model = exaone.Model(args) self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + def test_cohere2(self): + from mlx_lm.models import cohere2 + + args = cohere2.ModelArgs( + model_type="cohere2", + hidden_size=4096, + head_dim=128, + num_hidden_layers=40, + sliding_window=4096, + sliding_window_pattern=4, + ) + model = cohere2.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.num_hidden_layers) if __name__ == "__main__": unittest.main() From 0337646b4efc819d5ef79934041ff2e8be1bbc7f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 16:39:55 +0100 Subject: [PATCH 05/11] add to tuner --- llms/mlx_lm/tuner/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 6821f4343..3986952a7 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -96,6 +96,7 @@ def to_lora(layer): "gemma2", "starcoder2", "cohere", + "cohere2", "minicpm", "deepseek", "olmo2", From d7d70487eb2ee5389840ebf414bd0d696d1aabc1 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 17:06:57 +0100 Subject: [PATCH 06/11] fix sliding window --- llms/mlx_lm/models/cohere2.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index 6b46ddc61..c6f3e8852 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -85,20 +85,25 @@ def __call__( keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + # Apply RoPE only if sliding window is enabled + if self.sliding_window is not None: + if cache is None: + queries = self.rope(queries) + keys = self.rope(keys) + else: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - # sliding window attention - if self.sliding_window is not None: - keys = keys[:, :, -self.sliding_window :, :] - values = values[:, :, -self.sliding_window :, :] - if mask is not None: - mask = mask[:, -self.sliding_window :] + + # Apply sliding window attention if enabled + if self.sliding_window is not None: + window_size = self.sliding_window + keys = keys[..., -window_size:, :] + values = values[..., -window_size:, :] + if mask is not None: + mask = mask[..., -window_size:] output = mx.fast.scaled_dot_product_attention( queries, keys, values, scale=self.scale, mask=mask From 406c7f300f411345a42841ee44f284383e6c9fdb Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 14 Dec 2024 17:12:49 +0100 Subject: [PATCH 07/11] add coauthor :) Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com> --- llms/mlx_lm/models/cohere2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index c6f3e8852..a078409b6 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -29,7 +29,6 @@ class ModelArgs(BaseModelArgs): class LayerNorm2D(nn.Module): - def __init__(self, d1, d2, eps): super().__init__() self.weight = mx.zeros((d1, d2)) From ac58a95fbd48c66a42af810bc00d5bdbab737e40 Mon Sep 17 00:00:00 2001 From: N8 Date: Sat, 14 Dec 2024 17:08:06 -0500 Subject: [PATCH 08/11] Add rotating kvcache to save space --- llms/mlx_lm/models/base.py | 4 +++- llms/mlx_lm/models/cohere2.py | 20 ++++++++++++++------ llms/mlx_lm/utils.py | 9 ++++----- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index f02f49b1a..3b5ddcb02 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non return mask * -1e9 -def create_attention_mask(h: mx.array, cache: Optional[Any] = None): +def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None): T = h.shape[1] if T > 1: window_size = None offset = 0 if cache is not None and cache[0] is not None: c = cache[0] + if reference_idx is not None: + c = cache[reference_idx] if hasattr(c, "max_size"): offset = min(c.max_size, c.offset) window_size = c.max_size diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index a078409b6..a2854d190 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -6,8 +6,8 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask - +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache @dataclass class ModelArgs(BaseModelArgs): @@ -95,7 +95,6 @@ def __call__( if cache is not None: keys, values = cache.update_and_fetch(keys, values) - # Apply sliding window attention if enabled if self.sliding_window is not None: window_size = self.sliding_window @@ -104,8 +103,8 @@ def __call__( if mask is not None: mask = mask[..., -window_size:] - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -171,7 +170,7 @@ def __call__( ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1) if cache is None: cache = [None] * len(self.layers) @@ -198,6 +197,15 @@ def __call__( out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1: + caches.append(KVCache()) + else: + caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0)) + return caches @property def layers(self): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b87f5a241..10292d753 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ and prompt_cache[0].offset > quantized_kv_start ): for i in range(len(prompt_cache)): - prompt_cache[i] = prompt_cache[i].to_quantized( - group_size=kv_group_size, bits=kv_bits - ) - - + if isinstance(prompt_cache[i], cache.KVCache): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) def generate_step( prompt: mx.array, model: nn.Module, From 4aee86243ec98d9fc88560f6b388417501f9e709 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Dec 2024 07:45:15 -0800 Subject: [PATCH 09/11] some nits --- llms/mlx_lm/models/base.py | 4 +-- llms/mlx_lm/models/cohere2.py | 51 ++++++++++++++++------------------- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3b5ddcb02..f02f49b1a 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -34,15 +34,13 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non return mask * -1e9 -def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None): +def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: window_size = None offset = 0 if cache is not None and cache[0] is not None: c = cache[0] - if reference_idx is not None: - c = cache[reference_idx] if hasattr(c, "max_size"): offset = min(c.max_size, c.offset) window_size = c.max_size diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index a2854d190..392c7633d 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -6,9 +6,10 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention from .cache import KVCache, RotatingKVCache + @dataclass class ModelArgs(BaseModelArgs): model_type: str @@ -28,16 +29,6 @@ class ModelArgs(BaseModelArgs): sliding_window_pattern: int = 4 -class LayerNorm2D(nn.Module): - def __init__(self, d1, d2, eps): - super().__init__() - self.weight = mx.zeros((d1, d2)) - self.eps = eps - - def __call__(self, x): - return self.weight * mx.fast.layer_norm(x, None, None, self.eps) - - class Attention(nn.Module): def __init__(self, args: ModelArgs, layer_idx: int): super().__init__() @@ -64,11 +55,7 @@ def __init__(self, args: ModelArgs, layer_idx: int): self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) - self.sliding_window = ( - args.sliding_window - if (layer_idx + 1) % args.sliding_window_pattern != 0 - else None - ) + self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0 def __call__( self, @@ -85,7 +72,7 @@ def __call__( values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) # Apply RoPE only if sliding window is enabled - if self.sliding_window is not None: + if self.use_sliding_window: if cache is None: queries = self.rope(queries) keys = self.rope(keys) @@ -95,13 +82,11 @@ def __call__( if cache is not None: keys, values = cache.update_and_fetch(keys, values) - # Apply sliding window attention if enabled - if self.sliding_window is not None: - window_size = self.sliding_window - keys = keys[..., -window_size:, :] - values = values[..., -window_size:, :] - if mask is not None: - mask = mask[..., -window_size:] + + if self.use_sliding_window and mask is not None: + key_len = keys.shape[-2] + if mask.shape[-1] != key_len: + mask = mask[..., -key_len:] output = scaled_dot_product_attention( queries, keys, values, cache=cache, scale=self.scale, mask=mask @@ -170,7 +155,12 @@ def __call__( ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1) + T = h.shape[1] + if T > 1: + offset = cache[0].offset if cache else 0 + mask = create_causal_mask(T, offset).astype(h.dtype) + else: + mask = None if cache is None: cache = [None] * len(self.layers) @@ -197,14 +187,19 @@ def __call__( out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out - + def make_cache(self): caches = [] for i in range(self.args.num_hidden_layers): - if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1: + if ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ): caches.append(KVCache()) else: - caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0)) + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) return caches @property From dec2acfaceaf4d24aa8989cac2d9d079d5e8f353 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Dec 2024 07:53:07 -0800 Subject: [PATCH 10/11] style --- llms/mlx_lm/utils.py | 2 ++ llms/tests/test_models.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 10292d753..4d69115e0 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -191,6 +191,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ prompt_cache[i] = prompt_cache[i].to_quantized( group_size=kv_group_size, bits=kv_bits ) + + def generate_step( prompt: mx.array, model: nn.Module, diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index d6decb3fa..3097c5225 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -863,7 +863,10 @@ def test_cohere2(self): sliding_window_pattern=4, ) model = cohere2.Model(args) - self.model_test_runner(model, args.model_type, args.vocab_size, args.num_hidden_layers) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() From 799dfde8a38e34a8da2e6858dbd6531f792a3d12 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 16 Dec 2024 07:55:06 -0800 Subject: [PATCH 11/11] nits --- llms/mlx_lm/models/cohere2.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index 392c7633d..fcb4061b1 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -205,11 +205,3 @@ def make_cache(self): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads