From ec690122aa7e422dade76d9fd4134a26f510b3df Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 21:48:29 +0200 Subject: [PATCH 01/40] initial commit --- llms/mlx_lm/models/mamba.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 llms/mlx_lm/models/mamba.py diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py new file mode 100644 index 000000000..e69de29bb From 09fc3aeaec34a03162312afc2b130e99266da59a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 21:49:55 +0200 Subject: [PATCH 02/40] initial commit --- llms/mlx_lm/tuner/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index fe9740f5e..e8666c0a0 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -99,6 +99,7 @@ def to_lora(layer): "starcoder2", "cohere", "minicpm", + "mamba", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type == "mixtral": From fecfd1cb72e07cffc7c8832813f699523720d7bc Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 22:32:34 +0200 Subject: [PATCH 03/40] Adding first lines --- llms/mlx_lm/models/mamba.py | 80 +++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index e69de29bb..aaecd091c 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from base import BaseModelArgs, KVCache, create_additive_causal_mask + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "mamba" + d_model: int = 12 # hidden_size + d_inner: int = 2 + vocab_size: int = 623 + n_layer: int = 3# num_hidden_layers + tie_word_embeddings: bool = False + use_bias: bool = False + use_conv_bias: bool = False + conv_kernel: int = 4 + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) + self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache=None): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] + self.norm_f = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache=None): + tokens = self.embedding(inputs) + + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.backbone = Mamba(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + out = self.backbone(inputs, cache) + + if 
self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + return out + + +model = Model(ModelArgs()) +print(model) From abab1d0969faabbacc8b0163aeb89545989290e4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 22:42:29 +0200 Subject: [PATCH 04/40] adding x, and dt projection layers --- llms/mlx_lm/models/mamba.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index aaecd091c..8c80c9163 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union +import math + import mlx.core as mx import mlx.nn as nn @@ -10,6 +12,7 @@ @dataclass class ModelArgs(BaseModelArgs): model_type: str = "mamba" + dt_rank: Union[int, str] = "auto" d_model: int = 12 # hidden_size d_inner: int = 2 vocab_size: int = 623 @@ -18,6 +21,14 @@ class ModelArgs(BaseModelArgs): use_bias: bool = False use_conv_bias: bool = False conv_kernel: int = 4 + state_size: int = 16 + expand: int = 2 + + def __post_init__(self): + self.d_inner = self.expand * self.d_model + + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) class DepthWiseConv1d(nn.Module): @@ -30,6 +41,10 @@ def __init__(self, args: ModelArgs): super().__init__() self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size class ResidualBlock(nn.Module): From 80d6b4d158dc71f6108e20142474604daa47842a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 22:53:10 +0200 Subject: [PATCH 05/40] adding the clamping mechanism --- llms/mlx_lm/models/mamba.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 8c80c9163..848a045fb 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -23,6 +23,10 @@ class ModelArgs(BaseModelArgs): conv_kernel: int = 4 state_size: int = 16 expand: int = 2 + time_step_init_scheme: str = "random" + time_step_max: float = 0.1 + time_step_min: float = 0.001 + time_step_floor: float = 0.0001 def __post_init__(self): self.d_inner = self.expand * self.d_model @@ -31,6 +35,19 @@ def __post_init__(self): self.dt_rank = math.ceil(self.d_model / 16) +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + + return mx.where(mask_upper, max, x) + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias, padding): super().__init__() @@ -45,6 +62,16 @@ def __init__(self, args: ModelArgs): self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise 
NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt class ResidualBlock(nn.Module): From 5e4489faae34f853124f1a584c76fcf106dab005 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 23:33:17 +0200 Subject: [PATCH 06/40] First succesful inference --- llms/mlx_lm/models/mamba.py | 131 ++++++++++++++++++++++++++++++++++-- 1 file changed, 127 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 848a045fb..dbcc3ce5f 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): d_model: int = 12 # hidden_size d_inner: int = 2 vocab_size: int = 623 - n_layer: int = 3# num_hidden_layers + n_layer: int = 3 # num_hidden_layers tie_word_embeddings: bool = False use_bias: bool = False use_conv_bias: bool = False @@ -27,6 +27,7 @@ class ModelArgs(BaseModelArgs): time_step_max: float = 0.1 time_step_min: float = 0.001 time_step_floor: float = 0.0001 + pscan: bool = False def __post_init__(self): self.d_inner = self.expand * self.d_model @@ -48,14 +49,39 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, x) +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias, padding): super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + + self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) + + indices = mx.arange(channels) + mask = mx.zeros_like(self.conv1d.weight) + mask[indices, :, indices] = 1 + self.conv1d.weight *= mask + + def __call__(self, x): + return self.conv1d(x) + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.args = args self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) @@ -70,8 +96,98 @@ def __init__(self, args: ModelArgs): raise NotImplementedError dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([args.d_inner]) + + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + + + def ssm_step(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + deltaBC = self.x_proj(x) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = 
mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def ssm(self, x): + A = -mx.exp(self.A_log) # (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, L, ED) + if self.args.pscan: + y = self.selective_scan(x, delta, A, B, C, D) + else: + y = self.selective_scan_seq(x, delta, A, B, C, D) + return y + + + def selective_scan(self, x, delta, A, B, C, D): + deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) + + BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) + + hs = pscan(deltaA, BX) + + y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + + y = y + D * x + + return y + + def selective_scan_seq(self, x, delta, A, B, C, D): + _, L, _ = x.shape + + deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) + + BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) + + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) # (B, ED, N) + hs = [] + + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) + + hs = mx.stack(hs, axis=1) + + y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + + y = y + D * x + + return y + + + def __call__(self, inputs: mx.array, cache = None): + _, L, _ = inputs.shape + + if cache is not None: + h, inputs = cache + + x, z = self.in_proj(inputs).split(indices_or_sections=2, axis=2) + x_cache = unsqueeze(x, 1) + # x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] + x = self.conv1d(x)[:, :L, :] + # y, h = self.ssm_step(nn.silu(x), h) + output = self.ssm(nn.silu(x)) * nn.silu(z) + # inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + return self.out_proj(output), None # (h, inputs) class ResidualBlock(nn.Module): @@ -80,7 +196,7 @@ def __init__(self, args: ModelArgs): self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.d_model) - def __call__(self, inputs: mx.array, cache=None): + def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): output, cache = self.mixer(self.norm(inputs), cache) output = output + inputs return output, cache @@ -96,6 +212,9 @@ def __init__(self, args: ModelArgs): def __call__(self, inputs: mx.array, cache=None): tokens = self.embedding(inputs) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): h, cache[i] = layer(tokens, cache[i]) @@ -106,6 +225,7 @@ def __call__(self, inputs: mx.array, cache=None): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.args = args self.backbone = Mamba(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) @@ -120,3 +240,6 @@ def __call__(self, inputs: mx.array, cache=None): model = Model(ModelArgs()) print(model) + +logits = model(mx.array([[3, 3, 3]])) +print(logits) From aa23983b2099764953420c8998369436aec7dcc9 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 19 Jul 2024 23:44:52 +0200 Subject: [PATCH 07/40] last commit for today - added custom geenrate function and it works as expected, will try training and then with loading a model from 
the hub --- llms/mlx_lm/models/mamba.py | 67 ++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index dbcc3ce5f..7ea6507d8 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -115,7 +115,7 @@ def ssm_step(self, x, h): deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) BX = deltaB * unsqueeze(x, -1) if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) h = deltaA * h + BX y = (h @ unsqueeze(C, -1)).squeeze(2) y = y + D * x @@ -189,6 +189,31 @@ def __call__(self, inputs: mx.array, cache = None): # inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) return self.out_proj(output), None # (h, inputs) + def step(self, x, cache): + h, inputs = cache + + xz = self.in_proj(x) # (B, 2*ED) + x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + + # x branch + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) + + x = nn.silu(x) + y, h = self.ssm_step(x, h) + + # z branch + z = nn.silu(z) + + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) + cache = (h, inputs) + + return output, cache + class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -197,7 +222,7 @@ def __init__(self, args: ModelArgs): self.norm = nn.RMSNorm(args.d_model) def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): - output, cache = self.mixer(self.norm(inputs), cache) + output, cache = self.mixer.step(self.norm(inputs), cache) output = output + inputs return output, cache @@ -231,15 +256,47 @@ def __init__(self, args: ModelArgs): self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): - out = self.backbone(inputs, cache) + out, cache = self.backbone(inputs, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) - return out + return out, cache + + + def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + input_ids = mx.array([[3, 3, 3]]) # mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, tokens_prompt) # (1, num_tokens) + + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) # (1, vocab_size), caches + + # sample (no sampling when the prompt is being processed) + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + # output = [tokenizer.decode(output.tolist()) for output in input_ids][0] + + self.train() + + return next_token # output model = Model(ModelArgs()) 
print(model) -logits = model(mx.array([[3, 3, 3]])) +logits = model.generate() print(logits) From ec5503b64ed965232f01d4c8e5a617b63bd5df0e Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 21 Jul 2024 17:12:54 +0200 Subject: [PATCH 08/40] clean up --- llms/mlx_lm/models/mamba.py | 293 +++++++++++++++++++++--------------- llms/mlx_lm/tuner/utils.py | 2 +- llms/mlx_lm/utils.py | 2 +- 3 files changed, 172 insertions(+), 125 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 7ea6507d8..0ac8e7a44 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -3,52 +3,102 @@ import math +import torch + import mlx.core as mx import mlx.nn as nn -from base import BaseModelArgs, KVCache, create_additive_causal_mask +from .base import BaseModelArgs @dataclass class ModelArgs(BaseModelArgs): - model_type: str = "mamba" - dt_rank: Union[int, str] = "auto" - d_model: int = 12 # hidden_size - d_inner: int = 2 - vocab_size: int = 623 - n_layer: int = 3 # num_hidden_layers - tie_word_embeddings: bool = False - use_bias: bool = False - use_conv_bias: bool = False - conv_kernel: int = 4 - state_size: int = 16 - expand: int = 2 - time_step_init_scheme: str = "random" - time_step_max: float = 0.1 - time_step_min: float = 0.001 - time_step_floor: float = 0.0001 + model_type: str + dt_rank: Union[int, str] + d_model: int + d_inner: int + vocab_size: int + n_layer: int + use_bias: bool + use_conv_bias: bool + conv_kernel: int + state_size: int + expand: int + time_step_init_scheme: str + time_step_max: float + time_step_min: float + time_step_floor: float pscan: bool = False + tie_word_embeddings: bool = False + num_hidden_layers: int = None + hidden_size: int = None def __post_init__(self): self.d_inner = self.expand * self.d_model - + if self.n_layer is None: + self.n_layer = self.num_hidden_layers + if self.d_model is None: + self.d_model = self.hidden_size if self.dt_rank == 'auto': self.dt_rank = math.ceil(self.d_model / 16) +def pscan_f(A, X): + Aa = A + Xa = X + B, D, L, _ = A.shape + num_steps = int(math.log2(L)) + + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + + +def pscan(A_in, X_in): + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + pscan_f(A, X) + return X.transpose(0, 2, 1, 3) + + def clamp(x, min=None, max=None): if min is not None: mask_lower = x < min if max is not None: mask_upper = 
x > max - if min is not None: if max is not None: return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) + def unsqueeze(x, axis): assert axis <= len(x.shape) if axis >= 0: @@ -61,14 +111,11 @@ def unsqueeze(x, axis): class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias, padding): super().__init__() - self.channels = channels self.kernel_size = kernel_size self.bias = bias self.padding = padding - self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) - indices = mx.arange(channels) mask = mx.zeros_like(self.conv1d.weight) mask[indices, :, indices] = 1 @@ -97,14 +144,12 @@ def __init__(self, args: ModelArgs): dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) self.A_log = mx.log(A) self.D = mx.ones([args.d_inner]) self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) - def ssm_step(self, x, h): A = -mx.exp(self.A_log) D = self.D @@ -121,98 +166,70 @@ def ssm_step(self, x, h): y = y + D * x return y, h - def ssm(self, x): - A = -mx.exp(self.A_log) # (ED, N) + def ssm(self, x): # DONE + A = -mx.exp(self.A_log) D = self.D - - deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, L, ED) + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) if self.args.pscan: y = self.selective_scan(x, delta, A, B, C, D) else: y = self.selective_scan_seq(x, delta, A, B, C, D) return y - - def selective_scan(self, x, delta, A, B, C, D): - deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) - - BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) - + def selective_scan(self, x, delta, A, B, C, D): # DONE + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) hs = pscan(deltaA, BX) - - y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - - y = y + D * x - - return y + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x def selective_scan_seq(self, x, delta, A, B, C, D): _, L, _ = x.shape - - deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) - - BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) - - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) # (B, ED, N) + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) hs = [] - for t in range(0, L): h = deltaA[:, t] * h + BX[:, t] hs.append(h) - hs = mx.stack(hs, axis=1) + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x - y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - + def step(self, x, cache): # Done + h, inputs = cache + x, z = 
self.in_proj(x).split(indices_or_sections=2, axis=1) + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] + y, h = self.ssm_step(nn.silu(x), h) + output = y * nn.silu(z) + output = self.out_proj(output) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + return output, (h, inputs) + + def ssm_step(self, x, h): # Done + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) y = y + D * x + return y, h - return y - - - def __call__(self, inputs: mx.array, cache = None): - _, L, _ = inputs.shape - - if cache is not None: - h, inputs = cache - - x, z = self.in_proj(inputs).split(indices_or_sections=2, axis=2) - x_cache = unsqueeze(x, 1) - # x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] + def __call__(self, x): # DONE + _, L, _ = x.shape + x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) x = self.conv1d(x)[:, :L, :] - # y, h = self.ssm_step(nn.silu(x), h) output = self.ssm(nn.silu(x)) * nn.silu(z) - # inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - return self.out_proj(output), None # (h, inputs) - - def step(self, x, cache): - h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - - # x branch - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm_step(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) - cache = (h, inputs) - - return output, cache + return self.out_proj(output) class ResidualBlock(nn.Module): @@ -236,13 +253,10 @@ def __init__(self, args: ModelArgs): def __call__(self, inputs: mx.array, cache=None): tokens = self.embedding(inputs) - if cache is None: cache = [None] * len(self.layers) - for i, layer in enumerate(self.layers): h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) return h, cache @@ -262,41 +276,74 @@ def __call__(self, inputs: mx.array, cache=None): out = self.model.embed_tokens.as_linear(out) return out, cache + def torch_to_mlx_depthwise_weights(self, torch_weights): + torch_weights = torch_weights.transpose(2, 1) + channels, kernel_size, _ = torch_weights.shape + + mlx_weights = torch.zeros(channels, kernel_size, channels) + + indices = torch.arange(channels) + if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': + mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() + else: + mlx_weights[indices, :, indices] = torch_weights[:, :, 0] + + return mlx_weights + + def sanitize(self, torch_state_dict): + new_state_dict = {} + for key, value in torch_state_dict.items(): + if 'conv1d.weight' in key: + value = self.torch_to_mlx_depthwise_weights(value) + + if 'conv1d' in key: + key = key.replace('conv1d', 'conv1d.conv1d') + + if value.type() == 'torch.BFloat16Tensor': + new_state_dict[key] = value.half().numpy() + else: + new_state_dict[key] = 
value.numpy() + + return new_state_dict + + @property + def layers(self): + return self.model.layers - def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - input_ids = mx.array([[3, 3, 3]]) # mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, tokens_prompt) # (1, num_tokens) +# def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): +# self.eval() - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] +# input_ids = mx.array([[3, 3, 3]]) # mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, tokens_prompt) # (1, num_tokens) - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) # (1, vocab_size), caches +# caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] - # sample (no sampling when the prompt is being processed) - if i+1 >= input_ids.shape[1]: +# for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): +# next_token_logits, caches = self(input_ids[:, i], caches) # (1, vocab_size), caches - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now +# # sample (no sampling when the prompt is being processed) +# if i+1 >= input_ids.shape[1]: - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] +# if top_k is not None: +# values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest +# mask = next_token_logits < (values[:, 0, None]) +# next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - input_ids = mx.concatenate([input_ids, next_token], axis=1) +# if sample and temperature > 0: +# next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) +# else: +# next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - # output = [tokenizer.decode(output.tolist()) for output in input_ids][0] +# input_ids = mx.concatenate([input_ids, next_token], axis=1) - self.train() +# # output = [tokenizer.decode(output.tolist()) for output in input_ids][0] - return next_token # output +# self.train() +# return next_token # output -model = Model(ModelArgs()) -print(model) +# model = Model(ModelArgs()) +# print(model) -logits = model.generate() -print(logits) +# logits = model.generate() +# print(logits) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index fbfa35dfb..6ed888664 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -99,7 +99,7 @@ def to_lora(layer): "starcoder2", "cohere", "minicpm", - "mamba", + "mamba" ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type == "mixtral": diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 229ee2381..8ad4e0f28 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -28,7 +28,7 @@ # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama - "phi-msft": "phixtral", + "phi-msft": "phixtral" } MAX_FILE_SIZE_GB = 5 
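
Note on the recurrence used above: the step()/ssm_step() path introduced in PATCH 07 and kept in PATCH 08 reduces Mamba's selective scan to a per-token recurrence, h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * x_t and y_t = C_t . h_t + D * x_t, with the tuple (h, last conv_kernel-1 inputs) carried as the cache between calls. Below is a minimal, self-contained sketch of just that recurrence in mlx; the sizes (B=1, d_inner=4, N=16, L=5) and the random stand-ins for the x_proj/dt_proj outputs are illustrative assumptions, not values or code taken from the patches.

import mlx.core as mx
import mlx.nn as nn

B, ED, N, L = 1, 4, 16, 5   # batch, expanded dim (d_inner), SSM state size, sequence length (assumed)

A = -mx.repeat(mx.arange(1.0, N + 1.0).reshape(1, N), repeats=ED, axis=0)  # (ED, N), matching the A_log init
D = mx.ones([ED])                                                          # per-channel skip weights

h = mx.zeros([B, ED, N])                  # recurrent SSM state, the part of the cache ssm_step updates
xs = mx.random.uniform(shape=[L, B, ED])  # stand-in for the conv1d + silu branch, one token per step

ys = []
for t in range(L):
    x = xs[t]                                              # (B, ED)
    delta = nn.softplus(mx.random.uniform(shape=[B, ED]))  # stand-in for dt_proj(x_proj(x)) step size
    Bt = mx.random.uniform(shape=[B, N])                   # stand-in for the input-dependent B
    Ct = mx.random.uniform(shape=[B, N])                   # stand-in for the input-dependent C

    deltaA = mx.exp(delta.reshape(B, ED, 1) * A)           # (B, ED, N) discretized state transition
    deltaBx = delta.reshape(B, ED, 1) * Bt.reshape(B, 1, N) * x.reshape(B, ED, 1)

    h = deltaA * h + deltaBx                               # state update, as in ssm_step
    y = (h @ Ct.reshape(B, N, 1)).squeeze(2) + D * x       # readout plus skip, shape (B, ED)
    ys.append(y)

print(mx.stack(ys, axis=1).shape)  # (B, L, ED)

Feeding the sequence through this loop one token at a time yields the same result as selective_scan_seq applied to the whole sequence at once; the pscan variant only changes how the h updates are parallelized, not the recurrence itself.
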
From 93c23497a92d01470622782692a434cb4f1eaa38 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 23 Jul 2024 10:56:53 +0200 Subject: [PATCH 09/40] save up --- llms/mlx_lm/models/mamba.py | 50 ++++++++++++++++++------------------- llms/mlx_lm/utils.py | 3 ++- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 0ac8e7a44..a5038d30e 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union import math import torch +# import tokenizer + import mlx.core as mx import mlx.nn as nn @@ -13,7 +15,7 @@ @dataclass class ModelArgs(BaseModelArgs): - model_type: str + model_type: str = "mamba" dt_rank: Union[int, str] d_model: int d_inner: int @@ -43,7 +45,7 @@ def __post_init__(self): self.dt_rank = math.ceil(self.d_model / 16) -def pscan_f(A, X): +def pscan_main(A, X): Aa = A Xa = X B, D, L, _ = A.shape @@ -83,7 +85,7 @@ def pscan_f(A, X): def pscan(A_in, X_in): A = A_in[:].transpose(0, 2, 1, 3) X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) + pscan_main(A, X) return X.transpose(0, 2, 1, 3) @@ -310,37 +312,35 @@ def sanitize(self, torch_state_dict): def layers(self): return self.model.layers + def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() -# def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): -# self.eval() - -# input_ids = mx.array([[3, 3, 3]]) # mx.array(tokenizer(prompt, return_tensors='np').input_ids) # (1, tokens_prompt) # (1, num_tokens) + input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) -# caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] -# for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): -# next_token_logits, caches = self(input_ids[:, i], caches) # (1, vocab_size), caches + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) -# # sample (no sampling when the prompt is being processed) -# if i+1 >= input_ids.shape[1]: + if i+1 >= input_ids.shape[1]: -# if top_k is not None: -# values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest -# mask = next_token_logits < (values[:, 0, None]) -# next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now -# if sample and temperature > 0: -# next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) -# else: -# next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] -# input_ids = mx.concatenate([input_ids, next_token], axis=1) + input_ids = mx.concatenate([input_ids, next_token], axis=1) -# # 
output = [tokenizer.decode(output.tolist()) for output in input_ids][0] + output = [tokenizer.decode(output.tolist()) for output in input_ids][0] -# self.train() + self.train() -# return next_token # output + return output # model = Model(ModelArgs()) # print(model) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 8ad4e0f28..0541a881b 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -28,7 +28,8 @@ # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama - "phi-msft": "phixtral" + "phi-msft": "phixtral", + "mamba": "mamba" } MAX_FILE_SIZE_GB = 5 From ba194e31b71724757113dd25d5322f51db6af685 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 23 Jul 2024 16:20:14 +0200 Subject: [PATCH 10/40] almost --- llms/mlx_lm/models/base.py | 13 + llms/mlx_lm/models/mamba-old.py | 399 +++++++++++++++++++++++++++ llms/mlx_lm/models/mamba-tiny-pld.py | 154 +++++++++++ llms/mlx_lm/models/mamba-torch.py | 147 ++++++++++ llms/mlx_lm/models/mamba.py | 304 ++++++-------------- llms/mlx_lm/utils.py | 48 ++-- 6 files changed, 827 insertions(+), 238 deletions(-) create mode 100644 llms/mlx_lm/models/mamba-old.py create mode 100644 llms/mlx_lm/models/mamba-tiny-pld.py create mode 100644 llms/mlx_lm/models/mamba-torch.py diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 8c3ecc788..9189a5ef8 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -11,6 +11,19 @@ def create_additive_causal_mask(N: int, offset: int = 0): return mask * -1e9 +class MambaCache: + def __init__(self, batch_size, intermediate_size, ssm_state_size, conv_kernel_size): + self.h = mx.zeros((batch_size, intermediate_size, ssm_state_size)) + self.conv_states = mx.zeros((batch_size, conv_kernel_size - 1, intermediate_size)) + + def update(self, new_h, new_conv_state): + self.h = new_h + self.conv_states = mx.concatenate([self.conv_states[:, 1:, :], new_conv_state], axis=1) + + @classmethod + def init_cache(cls, batch_size, intermediate_size, ssm_state_size, conv_kernel_size): + return cls(batch_size, intermediate_size, ssm_state_size, conv_kernel_size) + class KVCache: def __init__(self, head_dim, n_kv_heads): diff --git a/llms/mlx_lm/models/mamba-old.py b/llms/mlx_lm/models/mamba-old.py new file mode 100644 index 000000000..844d4fb7d --- /dev/null +++ b/llms/mlx_lm/models/mamba-old.py @@ -0,0 +1,399 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import torch + +# import tokenizer + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "mamba" + dt_rank: Union[int, str] # time_step_rank + d_model: int + d_inner: int + vocab_size: int + n_layer: int + use_bias: bool + use_conv_bias: bool + rms_norm: bool + conv_kernel: int + state_size: int + expand: int + time_step_init_scheme: str + time_step_max: float + time_step_min: float + time_step_floor: float + pscan: bool = False + tie_word_embeddings: bool = False + num_hidden_layers: int = None + hidden_size: int = None + # time_step_scale + + def __post_init__(self): + self.d_inner = self.expand * self.d_model + if self.n_layer is None: + self.n_layer = self.num_hidden_layers + if self.d_model is None: + self.d_model = self.hidden_size + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) + + +def pscan_main(A, X): + Aa = A + Xa = X + B, D, L, _ = A.shape + num_steps = int(math.log2(L)) + + for k in range(num_steps): + T = 2 * 
(Xa.shape[2] // 2) + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + + +def pscan(A_in, X_in): + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + pscan_main(A, X) + return X.transpose(0, 2, 1, 3) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) + indices = mx.arange(channels) + mask = mx.zeros_like(self.conv1d.weight) + mask[indices, :, indices] = 1 + self.conv1d.weight *= mask + + def __call__(self, x): + return self.conv1d(x) + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) + # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.conv_bias, + kernel_size=args.d_conv, + groups=args.d_inner, + padding=args.d_conv - 1, + ) + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), 
min=args.time_step_floor) + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([args.d_inner]) + + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + + self.norm = nn.RMSNorm(args.d_model) + + def ssm_step(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + deltaBC = self.x_proj(self.norm(x)) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def ssm(self, x): # DONE + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + if self.args.pscan: + y = self.selective_scan(x, delta, A, B, C, D) + else: + y = self.selective_scan_seq(x, delta, A, B, C, D) + return y + + def ssm_new(self, x): + d_in, N = self.A_log.shape + A = -mx.exp(self.A_log.float()) + D = self.D.float() + delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + return self.selective_scan_new(x, delta, A, B, C, D) + + def selective_scan(self, x, delta, A, B, C, D): # DONE + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) + hs = pscan(deltaA, BX) + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x + + def selective_scan_new(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + + # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) + x = mx.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = mx.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + def selective_scan_seq(self, x, delta, A, B, C, D): + _, L, _ = x.shape + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) + hs = [] + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) + hs = mx.stack(hs, axis=1) + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x + + def step(self, x, cache): # Done + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] + y, h = self.ssm_step(nn.silu(x), h) + output = y * nn.silu(z) + output = self.out_proj(output) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + return output, (h, inputs) + + def ssm_step(self, x, h): # Done + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), 
(B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def __call__(self, x): # DONE + _, L, _ = x.shape + x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) + x = self.conv1d(x)[:, :L, :] + output = self.ssm(nn.silu(x)) * nn.silu(z) + return self.out_proj(output) + + def new(self, x): + _, L, _ = x.shape + x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) + + x = mx.reshape(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :L] + x = mx.rearrange(x, 'b d_in l -> b l d_in') + out = self.ssm_new(nn.silu(x)) * nn.silu(r) + return self.out_proj(out) + x + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): + output, cache = self.mixer.step(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] + self.norm_f = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache=None): + tokens = self.embedding(inputs) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + return out, cache + + # def torch_to_mlx_depthwise_weights(self, torch_weights): + # torch_weights = torch_weights.transpose(2, 1) + # channels, kernel_size, _ = torch_weights.shape + + # mlx_weights = torch.zeros(channels, kernel_size, channels) + + # indices = torch.arange(channels) + # if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': + # mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() + # else: + # mlx_weights[indices, :, indices] = torch_weights[:, :, 0] + + # return mlx_weights + + def sanitize(self, torch_state_dict): + new_state_dict = {} + for key, value in torch_state_dict.items(): + if 'conv1d.weight' in key: + value = self.torch_to_mlx_depthwise_weights(value) + + if 'conv1d' in key: + key = key.replace('conv1d', 'conv1d.conv1d') + + if value.type() == 'torch.BFloat16Tensor': + new_state_dict[key] = value.half().numpy() + else: + new_state_dict[key] = value.numpy() + + return new_state_dict + + @property + def layers(self): + return self.model.layers + + def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) + + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in 
range(self.args.n_layer)] + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + output = [tokenizer.decode(output.tolist()) for output in input_ids][0] + + self.train() + + return output + +# model = Model(ModelArgs()) +# print(model) + +# logits = model.generate() +# print(logits) diff --git a/llms/mlx_lm/models/mamba-tiny-pld.py b/llms/mlx_lm/models/mamba-tiny-pld.py new file mode 100644 index 000000000..8713978d5 --- /dev/null +++ b/llms/mlx_lm/models/mamba-tiny-pld.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import torch + +# import tokenizer + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + n_layer: int + use_conv_bias: bool + expand: int + pad_vocab_size_multiple: int + conv_kernel: int + d_model: int + state_size: int + d_inner: int + initializer_range: float + use_bias: bool + time_step_init_scheme: str + time_step_max: float + time_step_min: float + time_step_floor: float + dt_rank: Union[int, str] = "auto" + + def __post_init__(self): + self.d_inner = self.expand * self.d_model + if self.n_layer is None: + self.n_layer = self.num_hidden_layers + if self.d_model is None: + self.d_model = self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.d_model / 16) + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) + # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.use_conv_bias, + kernel_size=args.conv_kernel, + # groups=args.d_inner, + padding=args.conv_kernel - 1, + ) + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + A = mx.repeat(mx.arange(1, args.state_size + 1).reshape([1, 16]), repeats=args.d_inner) + + + self.A_log = mx.log(A) + self.D = mx.ones([args.d_inner]) + + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + + self.norm = nn.RMSNorm(args.d_model) + + def ssm(self, x): + d_in, N = self.A_log.shape + A = -mx.exp(self.A_log.float()) + D = self.D.float() + delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + return self.selective_scan(x, delta, A, B, C, D) + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + + # Perform selective scan (see 
scan_SSM() in The Annotated S4 [2]) + x = mx.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = mx.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + def __call__(self, x): + _, L, _ = x.shape + x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) + + x = mx.reshape(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :L] + x = mx.rearrange(x, 'b d_in l -> b l d_in') + out = self.ssm(nn.silu(x)) * nn.silu(r) + return self.out_proj(out) + x + +class MambaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [MambaBlock(args) for _ in range(args.n_layer)] + self.norm_f = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array_equal): + tokens = self.embedding(inputs) + for i, layer in enumerate(self.layers): + h = layer(tokens) + h = self.norm_f(h) + return h + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = self.backbone = MambaModel(args) + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + self.lm_head.weight = self.model.embedding.weight + + def __call__(self, inputs: mx.array): + h = self.backbone(inputs) + return self.lm_head(h) + + @property + def layers(self): + return self.backbone.layers + + # def sanitize(self, weights): + # exclude_patterns = [ + # 'backbone.layers.mixer.A_log', + # 'backbone.layers.mixer.conv1d.weight', + # 'backbone.layers.mixer.dt_proj.weight', + # 'backbone.layers.mixer.in_proj.weight', + # 'backbone.layers.mixer.dt_proj.bias', + # 'backbone.layers.mixer.conv1d.bias', + # 'backbone.layers.mixer.D' + # ] + # return { + # k: v for k, v in weights.items() + # if not any(pattern in k for pattern in exclude_patterns) + # } \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-torch.py b/llms/mlx_lm/models/mamba-torch.py new file mode 100644 index 000000000..ee9e286ff --- /dev/null +++ b/llms/mlx_lm/models/mamba-torch.py @@ -0,0 +1,147 @@ +import torch.nn as nn +import torch +from configuration_mamba import MambaConfig +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +import math +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from einops import rearrange, repeat, einsum +from typing import Optional , Union ,Tuple + +# Dear contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, special thanks to Albert Gu and Tri Dao for their articles. 
(https://arxiv.org/abs/2312.00752) + + +class MambaRMSNorm(nn.Module): + def __init__(self, + d_model: int, + eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + return output + + +class MambaBlock(nn.Module): + def __init__(self, config: MambaConfig): + """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" + super().__init__() + self.config = config + + self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) + + self.conv1d = nn.Conv1d( + in_channels=config.d_inner, + out_channels=config.d_inner, + bias=config.conv_bias, + kernel_size=config.d_conv, + groups=config.d_inner, + padding=config.d_conv - 1, + ) + + # x_proj takes in `x` and outputs the input-specific Δ, B, C + self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False) + + # dt_proj projects Δ from dt_rank to d_in + self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) + + A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(config.d_inner)) + self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) + self.norm = MambaRMSNorm(config.d_model) + + def forward(self, x): + (b, l, d) = x.shape + x_copy = x # There was a separate class for residual, I deleted that part and added it here. + x = self.norm(x) + x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) + (x, res) = x_and_res.split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) + + x = rearrange(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, 'b d_in l -> b l d_in') + + x = F.silu(x) + + y = self.ssm(x) + + y = y * F.silu(res) + + output = self.out_proj(y) + x_copy + + return output + + + def ssm(self, x): + (d_in, n) = self.A_log.shape + + A = -torch.exp(self.A_log.float()) # shape (d_in, n) + D = self.D.float() + + x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) + + (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). 
B, C: (b, l, n) + delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) + + y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] + + return y + + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + x = torch.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = torch.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config: MambaConfig): + super().__init__(config) + self.config = config + + self.embedding = nn.Embedding(config.vocab_size, config.d_model) + self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)]) + self.norm_f = MambaRMSNorm(config.d_model) + + def forward(self, input_ids: torch.LongTensor = None): + x = self.embedding(input_ids) + all_hidden_states = list() + for layer in self.layers: + x = layer(x) + all_hidden_states.append(x) + return self.norm_f(x) + + +class MambaForCausalLM(MambaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.lm_head.weight = self.model.embedding.weight + + + def forward(self, input_ids: torch.LongTensor = None): + hidden_states = self.model(input_ids=input_ids) + return self.lm_head(hidden_states) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index a5038d30e..19fd3694c 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -10,83 +10,38 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, MambaCache @dataclass class ModelArgs(BaseModelArgs): - model_type: str = "mamba" - dt_rank: Union[int, str] - d_model: int - d_inner: int + model_type: str vocab_size: int - n_layer: int - use_bias: bool - use_conv_bias: bool - conv_kernel: int - state_size: int + hidden_size: int # d_model + intermediate_size: int # d_inner + state_size: int # d_state + num_hidden_layers: int # n_layer + layer_norm_epsilon: float expand: int - time_step_init_scheme: str - time_step_max: float + conv_kernel: int + use_bias: bool # bias + use_conv_bias: bool # conv_bias + initializer_range: float + time_step_rank: int + time_step_scale: float time_step_min: float + time_step_max: float + time_step_init_scheme: str time_step_floor: float - pscan: bool = False - tie_word_embeddings: bool = False - num_hidden_layers: int = None - hidden_size: int = None + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False # pscan + dt_rank: str = "auto" def __post_init__(self): - self.d_inner = self.expand * self.d_model - if self.n_layer is None: - self.n_layer = self.num_hidden_layers - if self.d_model is None: - self.d_model = self.hidden_size - if self.dt_rank == 'auto': - self.dt_rank = math.ceil(self.d_model / 16) - - -def pscan_main(A, X): - Aa = A - Xa = X - B, D, L, _ = A.shape - num_steps = int(math.log2(L)) - - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, 
T//2, 2, -1) - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_main(A, X) - return X.transpose(0, 2, 1, 3) + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) def clamp(x, min=None, max=None): @@ -101,40 +56,30 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, x) -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) - indices = mx.arange(channels) - mask = mx.zeros_like(self.conv1d.weight) - mask[indices, :, indices] = 1 - self.conv1d.weight *= mask - - def __call__(self, x): - return self.conv1d(x) - - class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) - self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) - self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) - self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = int(args.conv_kernel) # Ensure it's an int + self.intermediate_size = int(args.intermediate_size) + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) dt_init_std = args.dt_rank**-0.5 * args.state_size if 
args.time_step_init_scheme == "constant": @@ -144,104 +89,53 @@ def __init__(self, args: ModelArgs): else: raise NotImplementedError - dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=self.intermediate_size, axis=0) self.A_log = mx.log(A) - self.D = mx.ones([args.d_inner]) + self.D = mx.ones([self.intermediate_size]) - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm_step(self, x, h): + def ssm(self, x, h): A = -mx.exp(self.A_log) D = self.D - deltaBC = self.x_proj(x) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.time_step_rank], axis=-1) delta = nn.softplus(self.dt_proj(delta)) deltaA = mx.exp(unsqueeze(delta, -1) * A) deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) BX = deltaB * unsqueeze(x, -1) if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) + h = mx.zeros([x.shape[0], self.config.d_inner, self.config.d_state]) h = deltaA * h + BX y = (h @ unsqueeze(C, -1)).squeeze(2) y = y + D * x return y, h - def ssm(self, x): # DONE - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - if self.args.pscan: - y = self.selective_scan(x, delta, A, B, C, D) - else: - y = self.selective_scan_seq(x, delta, A, B, C, D) - return y - - def selective_scan(self, x, delta, A, B, C, D): # DONE - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - hs = pscan(deltaA, BX) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - return y + D * x - - def selective_scan_seq(self, x, delta, A, B, C, D): - _, L, _ = x.shape - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) - hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) - hs = mx.stack(hs, axis=1) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - return y + D * x - - def step(self, x, cache): # Done + def __call__(self, x, cache: MambaCache): h, inputs = cache x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] - y, h = self.ssm_step(nn.silu(x), h) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.config.d_conv-1, :] # (B, ED) + y, h = self.ssm(nn.silu(x), h) output = y * nn.silu(z) - output = self.out_proj(output) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - return output, (h, inputs) - - def ssm_step(self, 
x, h): # Done - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x): # DONE - _, L, _ = x.shape - x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) - x = self.conv1d(x)[:, :L, :] - output = self.ssm(nn.silu(x)) * nn.silu(z) - return self.out_proj(output) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) + # cache = (h, inputs) + cache.update(h, inputs) + return self.out_proj(output), cache class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.d_model) + self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): - output, cache = self.mixer.step(self.norm(inputs), cache) + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) output = output + inputs return output, cache @@ -249,12 +143,12 @@ def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): class Mamba(nn.Module): def __init__(self, args: ModelArgs): super().__init__() - self.embedding = nn.Embedding(args.vocab_size, args.d_model) - self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] - self.norm_f = nn.RMSNorm(args.d_model) + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) def __call__(self, inputs: mx.array, cache=None): - tokens = self.embedding(inputs) + tokens = self.embeddings(inputs) if cache is None: cache = [None] * len(self.layers) for i, layer in enumerate(self.layers): @@ -267,57 +161,34 @@ class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args + self.model_type = args.model_type self.backbone = Mamba(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): out, cache = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) + out = self.backbone.embeddings.as_linear(out) return out, cache - def torch_to_mlx_depthwise_weights(self, torch_weights): - torch_weights = torch_weights.transpose(2, 1) - channels, kernel_size, _ = torch_weights.shape - - mlx_weights = torch.zeros(channels, kernel_size, channels) - - indices = torch.arange(channels) - if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': - mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() - else: - mlx_weights[indices, :, indices] = torch_weights[:, :, 0] - - return mlx_weights - - def sanitize(self, torch_state_dict): - new_state_dict = {} - for key, value in torch_state_dict.items(): - if 'conv1d.weight' in key: - value = self.torch_to_mlx_depthwise_weights(value) - - if 'conv1d' in key: - key = key.replace('conv1d', 'conv1d.conv1d') - - if value.type() == 'torch.BFloat16Tensor': - new_state_dict[key] = 
value.half().numpy() - else: - new_state_dict[key] = value.numpy() - - return new_state_dict - @property def layers(self): - return self.model.layers + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers - def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): self.eval() - input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): next_token_logits, caches = self(input_ids[:, i], caches) @@ -336,14 +207,5 @@ def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 5 input_ids = mx.concatenate([input_ids, next_token], axis=1) - output = [tokenizer.decode(output.tolist()) for output in input_ids][0] - self.train() - - return output - -# model = Model(ModelArgs()) -# print(model) - -# logits = model.generate() -# print(logits) + return input_ids diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0541a881b..df88c38e2 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -141,7 +141,7 @@ def generate_step( Args: prompt (mx.array): The input prompt. - model (nn.Module): The model to use for generation. + model: The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. 
repetition_penalty (float, optional): The penalty factor for repeating @@ -199,22 +199,28 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: def _step(y): nonlocal repetition_context - logits = model(y[None], cache=cache) - logits = logits[:, -1, :] - - if repetition_penalty: - logits = apply_repetition_penalty( - logits, repetition_context, repetition_penalty - ) - y, logprobs = sample(logits) - repetition_context.append(y.item()) + if model.args.model_type == "mamba": + output_ids = model.generate(input_ids=y, n_tokens_to_gen=1, sample=temp > 0, temperature=temp) + next_token = output_ids[:, -1:] # Get the last generated token + logprobs = mx.zeros_like(next_token) # Dummy logprobs as we don't have actual logprobs from generate method else: - y, logprobs = sample(logits) + logits = model(y[None], cache=cache) + logits = logits[:, -1, :] - if repetition_context_size: - if len(repetition_context) > repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - return y, logprobs.squeeze(0) + if repetition_penalty: + logits = apply_repetition_penalty( + logits, repetition_context, repetition_penalty + ) + next_token, logprobs = sample(logits) + repetition_context.append(next_token.item()) + else: + next_token, logprobs = sample(logits) + + if repetition_context_size: + if len(repetition_context) > repetition_context_size: + repetition_context = repetition_context[-repetition_context_size:] + + return next_token, logprobs.squeeze(0) y, logprobs = _step(y) @@ -249,7 +255,11 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + if model.args.model_type == "mamba": + prompt_tokens = mx.array(tokenizer.encode(prompt, return_tensors='np')) + else: + prompt_tokens = mx.array(tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer detokenizer.reset() @@ -299,7 +309,11 @@ def generate( print("=" * 10) print("Prompt:", prompt) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + if model.args.model_type == "mamba": + prompt_tokens = mx.array(tokenizer.encode(prompt, return_tensors='np')) + else: + prompt_tokens = mx.array(tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer tic = time.perf_counter() From ab28b44fd2f478d2ab7d4b9d421b220f3c641ae4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 23 Jul 2024 19:55:18 +0200 Subject: [PATCH 11/40] update --- llms/mlx_lm/models/mamba-save.py | 234 ++++++++++++++++++++++++++++++ llms/mlx_lm/models/mamba-torch.py | 8 +- llms/mlx_lm/models/mamba.py | 118 +++++++-------- 3 files changed, 296 insertions(+), 64 deletions(-) create mode 100644 llms/mlx_lm/models/mamba-save.py diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py new file mode 100644 index 000000000..49418b520 --- /dev/null +++ b/llms/mlx_lm/models/mamba-save.py @@ -0,0 +1,234 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, MambaCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int # d_model + intermediate_size: int # d_inner + state_size: int # d_state + num_hidden_layers: int # n_layer + layer_norm_epsilon: float + expand: int + conv_kernel: int # d_conv + use_bias: bool # bias + use_conv_bias: bool # conv_bias + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + 
time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False # pscan + dt_rank: str = "auto" + + def __post_init__(self): + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) + scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x): + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) + 
deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) + h = deltaA * h + BX + y = (h @ mx.unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def __call__(self, x, cache: Optional[MambaCache]): + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) + y, h = self.ssm(nn.silu(x), h) + output = y * nn.silu(z) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + cache.update(h, inputs) + return self.out_proj(output), cache + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache=None): + tokens = self.embeddings(inputs) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids diff --git a/llms/mlx_lm/models/mamba-torch.py b/llms/mlx_lm/models/mamba-torch.py index ee9e286ff..84deb4d3f 100644 --- a/llms/mlx_lm/models/mamba-torch.py +++ b/llms/mlx_lm/models/mamba-torch.py @@ -12,7 +12,7 @@ from dataclasses import 
dataclass from einops import rearrange, repeat, einsum from typing import Optional , Union ,Tuple - +l # Dear contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, special thanks to Albert Gu and Tri Dao for their articles. (https://arxiv.org/abs/2312.00752) @@ -55,14 +55,12 @@ def __init__(self, config: MambaConfig): self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(config.d_inner)) self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) - self.norm = MambaRMSNorm(config.d_model) + # self.norm = MambaRMSNorm(config.d_model) def forward(self, x): (b, l, d) = x.shape x_copy = x # There was a separate class for residual, I deleted that part and added it here. - x = self.norm(x) - x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in) - (x, res) = x_and_res.split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) + x, res = self.in_proj(self.norm(x)).split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) x = rearrange(x, 'b l d_in -> b d_in l') x = self.conv1d(x)[:, :, :l] diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 19fd3694c..bf6c5ad1b 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -2,10 +2,7 @@ from typing import Optional, Union import math - -import torch - -# import tokenizer +import einsum import mlx.core as mx import mlx.nn as nn @@ -23,7 +20,7 @@ class ModelArgs(BaseModelArgs): num_hidden_layers: int # n_layer layer_norm_epsilon: float expand: int - conv_kernel: int + conv_kernel: int # d_conv use_bias: bool # bias use_conv_bias: bool # conv_bias initializer_range: float @@ -55,7 +52,6 @@ def clamp(x, min=None, max=None): return mx.where(mask_lower, min, x) return mx.where(mask_upper, max, x) - class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -63,8 +59,8 @@ def __init__(self, args: ModelArgs): self.hidden_size = args.hidden_size self.ssm_state_size = args.state_size - self.conv_kernel_size = int(args.conv_kernel) # Ensure it's an int - self.intermediate_size = int(args.intermediate_size) + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size self.time_step_rank = int(args.time_step_rank) self.use_conv_bias = args.use_conv_bias @@ -78,54 +74,61 @@ def __init__(self, args: ModelArgs): padding=self.conv_kernel_size-1 ) - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=self.intermediate_size, axis=0) + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", 
repeats=self.intermediate_size) self.A_log = mx.log(A) self.D = mx.ones([self.intermediate_size]) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm(self, x, h): - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.time_step_rank], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.config.d_inner, self.config.d_state]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x, cache: MambaCache): - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.config.d_conv-1, :] # (B, ED) - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - # cache = (h, inputs) - cache.update(h, inputs) - return self.out_proj(output), cache + def ssm(self, x): + (d_in, n) = self.A_log.shape + + A = -mx.exp(self.A_log.float()) # shape (d_in, n) + D = self.D.float() + + x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) + + (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). B, C: (b, l, n) + delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) + + y = self.selective_scan(x, delta, A, B, C, D) + + return y + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + x = mx.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = mx.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + return y + + def __call__(self, x): + (b, l, d) = x.shape + x_copy = x + x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) + + x = mx.rearrange(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :l] + x = mx.rearrange(x, 'b d_in l -> b l d_in') + + x = nn.silu(x) + + y = self.ssm(x) + + y = y * nn.silu(res) + return self.out_proj(y) + x_copy class ResidualBlock(nn.Module): @@ -134,10 +137,10 @@ def __init__(self, args: ModelArgs): self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) + def __call__(self, inputs: mx.array): + output = self.mixer(self.norm(inputs)) output = output + inputs - return output, cache + return output class Mamba(nn.Module): @@ -147,14 +150,11 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache=None): + def __call__(self, inputs: mx.array): tokens = self.embeddings(inputs) - if cache is None: - cache = [None] * len(self.layers) for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache + h, = layer(tokens) + return self.norm_f(h) class Model(nn.Module): 
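A note on the selective-scan rewrite above: it borrows einops/PyTorch-style helpers (a bare einsum import, rearrange, .float(), dim= and device= keywords) that do not map one-to-one onto mlx.core, so it is unlikely to run unmodified. For reference, a sequential selective scan needs nothing beyond broadcasting and a Python loop over time. The sketch below is illustrative rather than part of the patch; its name and signature are assumptions, delta is assumed to have already gone through softplus, and shapes follow the comments in the diff (u, delta: (B, L, ED); A: (ED, N); B_, C: (B, L, N); D: (ED,)).

import mlx.core as mx

def selective_scan_seq_sketch(u, delta, A, B_, C, D):
    # Discretize the SSM: deltaA[b, t, d, n] = exp(delta[b, t, d] * A[d, n]).
    deltaA = mx.exp(mx.expand_dims(delta, -1) * A)                   # (B, L, ED, N)
    deltaB_u = (mx.expand_dims(delta, -1)
                * mx.expand_dims(B_, 2)
                * mx.expand_dims(u, -1))                             # (B, L, ED, N)
    batch, seq_len, d_inner = u.shape
    h = mx.zeros((batch, d_inner, A.shape[1]))                       # hidden state, (B, ED, N)
    ys = []
    for t in range(seq_len):
        h = deltaA[:, t] * h + deltaB_u[:, t]                        # h[t] = Abar[t]*h[t-1] + Bbar[t]*u[t]
        ys.append(mx.sum(h * mx.expand_dims(C[:, t], 1), axis=-1))   # y[t] = C[t] . h[t], shape (B, ED)
    y = mx.stack(ys, axis=1)                                         # (B, L, ED)
    return y + u * D                                                 # skip connection

The loop costs L sequential steps; the pscan path elsewhere in the series trades this for roughly 2*log2(L) parallel sweeps over the same recurrence, and the cached single-token path applies the identical update for one t per generated token.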
@@ -166,7 +166,7 @@ def __init__(self, args: ModelArgs): # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) + out = self.backbone(inputs) out = self.backbone.embeddings.as_linear(out) return out, cache From 2cb3dc27d70d0c9df50e87ef82128e4ba7e5117a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 23 Jul 2024 20:03:56 +0200 Subject: [PATCH 12/40] update --- llms/mlx_lm/models/mamba.py | 30 +---------------------- llms/mlx_lm/utils.py | 47 ++++++++++++++----------------------- 2 files changed, 18 insertions(+), 59 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index bf6c5ad1b..ad59e1051 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -180,32 +180,4 @@ def head_dim(self): @property def n_kv_heads(self): - return self.args.num_hidden_layers - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids + return self.args.num_hidden_layers \ No newline at end of file diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index df88c38e2..1c8233eba 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -28,8 +28,7 @@ # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama - "phi-msft": "phixtral", - "mamba": "mamba" + "phi-msft": "phixtral" } MAX_FILE_SIZE_GB = 5 @@ -199,27 +198,22 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: def _step(y): nonlocal repetition_context - if model.args.model_type == "mamba": - output_ids = model.generate(input_ids=y, n_tokens_to_gen=1, sample=temp > 0, temperature=temp) - next_token = output_ids[:, -1:] # Get the last generated token - logprobs = mx.zeros_like(next_token) # Dummy logprobs as we don't have actual logprobs from generate method - else: - logits = model(y[None], cache=cache) - logits = logits[:, -1, :] + logits = model(y[None], cache=cache) + logits = logits[:, -1, :] - if repetition_penalty: - logits = apply_repetition_penalty( - logits, repetition_context, repetition_penalty - ) - next_token, logprobs = sample(logits) - repetition_context.append(next_token.item()) - else: - next_token, logprobs = sample(logits) + if repetition_penalty: + logits = apply_repetition_penalty( + logits, repetition_context, repetition_penalty + ) + next_token, logprobs = sample(logits) + repetition_context.append(next_token.item()) + else: + next_token, logprobs = sample(logits) - if repetition_context_size: - if len(repetition_context) > 
repetition_context_size: - repetition_context = repetition_context[-repetition_context_size:] - + if repetition_context_size: + if len(repetition_context) > repetition_context_size: + repetition_context = repetition_context[-repetition_context_size:] + return next_token, logprobs.squeeze(0) y, logprobs = _step(y) @@ -255,11 +249,7 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - if model.args.model_type == "mamba": - prompt_tokens = mx.array(tokenizer.encode(prompt, return_tensors='np')) - else: - prompt_tokens = mx.array(tokenizer.encode(prompt)) - + prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer detokenizer.reset() @@ -309,10 +299,7 @@ def generate( print("=" * 10) print("Prompt:", prompt) - if model.args.model_type == "mamba": - prompt_tokens = mx.array(tokenizer.encode(prompt, return_tensors='np')) - else: - prompt_tokens = mx.array(tokenizer.encode(prompt)) + prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer From 17e52f6c9f7f352ace4adaf56f88207cd5d12ca4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 24 Jul 2024 11:47:41 +0200 Subject: [PATCH 13/40] fixed cache handeling --- llms/mlx_lm/models/mamba-save.py | 191 ++++++++++++++++++++++++++++++ llms/mlx_lm/models/mamba.py | 194 +++++++++++++++++++++---------- llms/mlx_lm/utils.py | 2 +- 3 files changed, 324 insertions(+), 63 deletions(-) diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py index 49418b520..f9174e187 100644 --- a/llms/mlx_lm/models/mamba-save.py +++ b/llms/mlx_lm/models/mamba-save.py @@ -232,3 +232,194 @@ def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool self.train() return input_ids + + + + + + + + +# from dataclasses import dataclass +# from typing import Optional, Union + +# import math +# import einsum + +# import mlx.core as mx +# import mlx.nn as nn + +# from .base import BaseModelArgs, MambaCache + + +# @dataclass +# class ModelArgs(BaseModelArgs): +# model_type: str +# vocab_size: int +# hidden_size: int # d_model +# intermediate_size: int # d_inner +# state_size: int # d_state +# num_hidden_layers: int # n_layer +# layer_norm_epsilon: float +# expand: int +# conv_kernel: int # d_conv +# use_bias: bool # bias +# use_conv_bias: bool # conv_bias +# initializer_range: float +# time_step_rank: int +# time_step_scale: float +# time_step_min: float +# time_step_max: float +# time_step_init_scheme: str +# time_step_floor: float +# rescale_prenorm_residual: bool +# use_cache: bool +# use_mambapy: bool = False # pscan +# dt_rank: str = "auto" + +# def __post_init__(self): +# self.intermediate_size = self.expand * self.hidden_size +# if self.dt_rank == "auto": +# self.dt_rank = math.ceil(self.hidden_size / 16) + + +# def clamp(x, min=None, max=None): +# if min is not None: +# mask_lower = x < min +# if max is not None: +# mask_upper = x > max +# if min is not None: +# if max is not None: +# return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) +# return mx.where(mask_lower, min, x) +# return mx.where(mask_upper, max, x) + +# class MambaBlock(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.args = args + +# self.hidden_size = args.hidden_size +# self.ssm_state_size = args.state_size +# self.conv_kernel_size = args.conv_kernel +# self.intermediate_size = args.intermediate_size +# self.time_step_rank = int(args.time_step_rank) +# self.use_conv_bias = args.use_conv_bias + +# 
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + +# self.conv1d = nn.Conv1d( +# in_channels=self.intermediate_size, +# out_channels=self.intermediate_size, +# kernel_size=self.conv_kernel_size, +# bias=self.use_conv_bias, +# padding=self.conv_kernel_size-1 +# ) + +# self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) +# self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + +# A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", repeats=self.intermediate_size) +# self.A_log = mx.log(A) +# self.D = mx.ones([self.intermediate_size]) + +# self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + +# def ssm(self, x): +# (d_in, n) = self.A_log.shape + +# A = -mx.exp(self.A_log.float()) # shape (d_in, n) +# D = self.D.float() + +# x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) + +# (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). B, C: (b, l, n) +# delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) + +# y = self.selective_scan(x, delta, A, B, C, D) + +# return y + +# def selective_scan(self, u, delta, A, B, C, D): +# (b, l, d_in) = u.shape +# n = A.shape[1] +# deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) +# deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') +# x = mx.zeros((b, d_in, n), device=deltaA.device) +# ys = [] +# for i in range(l): +# x = deltaA[:, :, i] * x + deltaB_u[:, :, i] +# y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') +# ys.append(y) +# y = mx.stack(ys, dim=1) # shape (b, l, d_in) + +# y = y + u * D +# return y + +# def __call__(self, x): +# (b, l, d) = x.shape +# x_copy = x +# x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) + +# x = mx.rearrange(x, 'b l d_in -> b d_in l') +# x = self.conv1d(x)[:, :, :l] +# x = mx.rearrange(x, 'b d_in l -> b l d_in') + +# x = nn.silu(x) + +# y = self.ssm(x) + +# y = y * nn.silu(res) +# return self.out_proj(y) + x_copy + + +# class ResidualBlock(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.mixer = MambaBlock(args) +# self.norm = nn.RMSNorm(args.hidden_size) + +# def __call__(self, inputs: mx.array): +# output = self.mixer(self.norm(inputs)) +# output = output + inputs +# return output + + +# class Mamba(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) +# self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] +# self.norm_f = nn.RMSNorm(args.hidden_size) + +# def __call__(self, inputs: mx.array): +# tokens = self.embeddings(inputs) +# for i, layer in enumerate(self.layers): +# h, = layer(tokens) +# return self.norm_f(h) + + +# class Model(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.args = args +# self.model_type = args.model_type +# self.backbone = Mamba(args) + +# def __call__(self, inputs: mx.array, cache=None): +# out = self.backbone(inputs) +# out = self.backbone.embeddings.as_linear(out) +# return out, cache + +# @property +# def layers(self): +# return self.backbone.layers + +# @property +# def head_dim(self): +# return self.args.hidden_size // self.args.num_hidden_layers + +# @property +# def n_kv_heads(self): +# return self.args.num_hidden_layers + diff --git 
a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index ad59e1051..32f000cce 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,13 +1,12 @@ from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional import math -import einsum import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, MambaCache +from .base import BaseModelArgs @dataclass @@ -36,6 +35,21 @@ class ModelArgs(BaseModelArgs): dt_rank: str = "auto" def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + self.intermediate_size = self.expand * self.hidden_size if self.dt_rank == "auto": self.dt_rank = math.ceil(self.hidden_size / 16) @@ -52,6 +66,36 @@ def clamp(x, min=None, max=None): return mx.where(mask_lower, min, x) return mx.where(mask_upper, max, x) + +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) + scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x): + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -66,69 +110,60 @@ def __init__(self, args: ModelArgs): self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - self.conv1d = nn.Conv1d( - in_channels=self.intermediate_size, - out_channels=self.intermediate_size, + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, kernel_size=self.conv_kernel_size, bias=self.use_conv_bias, padding=self.conv_kernel_size-1 ) - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", repeats=self.intermediate_size) + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, 
self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) self.A_log = mx.log(A) self.D = mx.ones([self.intermediate_size]) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm(self, x): - (d_in, n) = self.A_log.shape - - A = -mx.exp(self.A_log.float()) # shape (d_in, n) - D = self.D.float() - - x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) - - (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). B, C: (b, l, n) - delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) - - y = self.selective_scan(x, delta, A, B, C, D) - - return y - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - x = mx.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = mx.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - return y - - def __call__(self, x): - (b, l, d) = x.shape - x_copy = x - x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) - - x = mx.rearrange(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :l] - x = mx.rearrange(x, 'b d_in l -> b l d_in') - - x = nn.silu(x) - - y = self.ssm(x) - - y = y * nn.silu(res) - return self.out_proj(y) + x_copy + def ssm(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) + h = deltaA * h + BX + y = (h @ mx.unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def __call__(self, x, cache = None): + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) + y, h = self.ssm(nn.silu(x), h) + output = y * nn.silu(z) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + cache.update(h, inputs) + return self.out_proj(output), cache class ResidualBlock(nn.Module): @@ -137,10 +172,10 @@ def __init__(self, args: ModelArgs): self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array): - output = self.mixer(self.norm(inputs)) + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) output = output + inputs - return output + return output, cache class Mamba(nn.Module): @@ -150,11 +185,14 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in 
range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array): + def __call__(self, inputs: mx.array, cache=None): tokens = self.embeddings(inputs) + if cache is None: + cache = [None] * len(self.layers) for i, layer in enumerate(self.layers): - h, = layer(tokens) - return self.norm_f(h) + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache class Model(nn.Module): @@ -166,7 +204,7 @@ def __init__(self, args: ModelArgs): # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) def __call__(self, inputs: mx.array, cache=None): - out = self.backbone(inputs) + out, cache = self.backbone(inputs, cache) out = self.backbone.embeddings.as_linear(out) return out, cache @@ -180,4 +218,36 @@ def head_dim(self): @property def n_kv_heads(self): - return self.args.num_hidden_layers \ No newline at end of file + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + caches = self.make_cache() + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1c8233eba..a06c9d2d7 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -19,7 +19,7 @@ from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache +from .models.base import KVCache, MambaCache from .sample_utils import top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import apply_lora_layers From a11563a30b928ffcfa8c18d1507a80b9829af05d Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 24 Jul 2024 13:07:29 +0200 Subject: [PATCH 14/40] fixed loading --- llms/mlx_lm/models/mamba.py | 178 ++++++++++++++++++++++++++++++++++-- llms/mlx_lm/utils.py | 10 +- 2 files changed, 176 insertions(+), 12 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 32f000cce..716e7c33a 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -76,22 +76,105 @@ def unsqueeze(x, axis): return x.reshape(new_shape) +def pscan_f(A, X): + # A : (B, D, L, N) + # X : (B, D, L, N) + + # modifies X in place by doing a parallel scan. 
+ # more formally, X will be populated by these values : + # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 + # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) + + Aa = A + Xa = X + + B, D, L, _ = A.shape + + num_steps = int(math.log2(L)) + + # up sweep + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + # down sweep + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + +# main function, used in the Mamba model (mamba_mlx.py) +def pscan(A_in, X_in): + """ + Applies the parallel scan operation, as defined above. Returns a new tensor. + + Args: + A_in : (B, L, ED, N) + X_in : (B, L, ED, N) + + Returns: + H : (B, L, ED, N) + """ + + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + + pscan_f(A, X) + + return X.transpose(0, 2, 1, 3) + + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias, padding): super().__init__() - self.channels = channels - self.kernel_size = kernel_size + self.channels = int(channels) + self.kernel_size = int(kernel_size) self.bias = bias self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) + scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) self.weight *= scale # Ensure scaling is applied correctly if bias: - self.bias = mx.zeros((channels,)) + self.bias = mx.zeros((self.channels,)) else: self.bias = None def __call__(self, x): + B, D, L = x.shape + assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." 
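# Reference sketch (assumption-labeled, not part of the patch): the call below applies
# MLX's nn.Conv1d module class as if it were a function, which is not how the module is
# used; the depthwise, causal convolution this layer is meant to compute can be written
# with plain padding and broadcasting instead. Assumes channels-last input x of shape
# (B, L, C) and a per-channel kernel w of shape (C, K); the name is illustrative only.
def _depthwise_conv1d_reference(x, w, bias=None):
    B, L, C = x.shape
    _, K = w.shape
    x_pad = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])       # causal left padding
    out = mx.zeros((B, L, C))
    for k in range(K):                                     # weighted sum over the window
        out = out + x_pad[:, k:k + L, :] * w[:, k]
    return out + bias if bias is not None else out
# The patch stores its weight as (channels, 1, kernel_size), so a reshape or transposed
# layout would be needed before comparing against a helper shaped like the one above.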
+ print("FORWARD PASS THROUGH CONV") + print(self.kernel_size) + print(self.weight) out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) return out @@ -153,12 +236,58 @@ def ssm(self, x, h): y = (h @ mx.unsqueeze(C, -1)).squeeze(2) y = y + D * x return y, h + + def ssm_old(self, x): + # x : (B, L, ED) - def __call__(self, x, cache = None): + # y : (B, L, ED) + + A = -mx.exp(self.A_log) # (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) + + if self.config.pscan: + y = self.selective_scan(x, delta, A, B, C, D) + else: + y = self.selective_scan_seq(x, delta, A, B, C, D) + + return y + + def selective_scan(self, x, delta, A, B, C, D): + deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) + BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) + hs = pscan(deltaA, BX) + y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + y = y + D * x + return y + + def selective_scan_seq(self, x, delta, A, B, C, D): + _, L, _ = x.shape + deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) + BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + hs = [] + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) + hs = mx.stack(hs, axis=1) + y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + y = y + D * x + return y + + def __call__(self, x, cache=None): h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + xz = self.in_proj(x) + split_dim = xz.shape[-1] // 2 # Correct the axis for splitting, ensuring the dimensions can be split equally + x, z = mx.split(xz, indices_or_sections=[split_dim], axis=-1) x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] y, h = self.ssm(nn.silu(x), h) output = y * nn.silu(z) inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) @@ -238,9 +367,9 @@ def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool if i+1 >= input_ids.shape[1]: if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + values = mx.topk(next_token_logits, k=top_k) mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + next_token_logits = mx.where(mask, -5000, next_token_logits) if sample and temperature > 0: next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) @@ -251,3 +380,32 @@ def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool self.train() return input_ids + + def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = unsqueeze(input_ids, 0) + + caches = self.make_cache() + + # Generate the next token logits + next_token_logits, caches = self(input_ids, caches) + + # Apply top_k 
filtering if specified + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest + mask = next_token_logits < (values[:, -1, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now + + # Sample the next token or take the argmax based on the temperature + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + # Concatenate the next token to the input_ids + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids, caches diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index a06c9d2d7..cf72e709d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -216,11 +216,17 @@ def _step(y): return next_token, logprobs.squeeze(0) - y, logprobs = _step(y) + if hasattr(model, 'generate_step'): + y, logprobs = model.generate_step(prompt) + else: + y, logprobs = _step(y) mx.async_eval(y) while True: - next_y, next_logprobs = _step(y) + if hasattr(model, 'generate_step'): + next_y, next_logprobs = model.generate_step(y) + else: + next_y, next_logprobs = _step(y) mx.async_eval(next_y) yield y.item(), logprobs y, logprobs = next_y, next_logprobs From fd566a9a196d0a5d06d5e30cf4699fd295fc208b Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 24 Jul 2024 13:16:29 +0200 Subject: [PATCH 15/40] added seperate generat_step method in the model and also in the utils to automaticaly use the generate step mthod in the model class --- llms/mlx_lm/models/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 716e7c33a..6776efcbc 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -408,4 +408,4 @@ def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: f input_ids = mx.concatenate([input_ids, next_token], axis=1) self.train() - return input_ids, caches + return input_ids, caches \ No newline at end of file From f872a4b350c6d448a3462a163eedea04b70da4b7 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 24 Jul 2024 13:43:13 +0200 Subject: [PATCH 16/40] quick update --- llms/mlx_lm/models/mamba.py | 48 ++++++++++--------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 6776efcbc..0b8004d22 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -16,7 +16,7 @@ class ModelArgs(BaseModelArgs): hidden_size: int # d_model intermediate_size: int # d_inner state_size: int # d_state - num_hidden_layers: int # n_layer + num_hidden_layers: int # n_layer, n_layer layer_norm_epsilon: float expand: int conv_kernel: int # d_conv @@ -43,6 +43,8 @@ def __post_init__(self): self.state_size = self.d_state if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): self.conv_kernel = self.d_conv if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): @@ -77,14 +79,6 @@ def unsqueeze(x, axis): def pscan_f(A, X): - # A : (B, D, L, N) - # X : (B, D, L, N) - - # modifies X in place by doing a parallel scan. 
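# Illustrative sketch, not part of the patch: the recurrence the removed comments
# above describe, H[t] = A[t] * H[t-1] + X[t] with H[-1] = 0, written as a plain
# loop.  `scan_reference` is a hypothetical helper in the (B, D, L, N) layout that
# pscan_f works in; it can serve as a correctness reference for the blocked
# parallel scan.
import mlx.core as mx

def scan_reference(A, X):
    B, D, L, N = A.shape
    h = mx.zeros((B, D, N))
    hs = []
    for t in range(L):
        h = A[:, :, t] * h + X[:, :, t]   # one scan step
        hs.append(h)
    return mx.stack(hs, axis=2)           # (B, D, L, N), H[t] stored at index t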
- # more formally, X will be populated by these values : - # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 - # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) - Aa = A Xa = X @@ -133,24 +127,10 @@ def pscan_f(A, X): A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) -# main function, used in the Mamba model (mamba_mlx.py) def pscan(A_in, X_in): - """ - Applies the parallel scan operation, as defined above. Returns a new tensor. - - Args: - A_in : (B, L, ED, N) - X_in : (B, L, ED, N) - - Returns: - H : (B, L, ED, N) - """ - A = A_in[:].transpose(0, 2, 1, 3) X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) @@ -211,12 +191,13 @@ def __init__(self, args: ModelArgs): else: raise NotImplementedError - dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) inv_dt = dt + mx.log1p(-mx.exp(-dt)) self.dt_proj.bias = inv_dt - A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) self.A_log = mx.log(A) self.D = mx.ones([self.intermediate_size]) @@ -250,7 +231,7 @@ def ssm_old(self, x): delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) - if self.config.pscan: + if self.args.use_mambapy: y = self.selective_scan(x, delta, A, B, C, D) else: y = self.selective_scan_seq(x, delta, A, B, C, D) @@ -284,8 +265,8 @@ def selective_scan_seq(self, x, delta, A, B, C, D): def __call__(self, x, cache=None): h, inputs = cache xz = self.in_proj(x) - split_dim = xz.shape[-1] // 2 # Correct the axis for splitting, ensuring the dimensions can be split equally - x, z = mx.split(xz, indices_or_sections=[split_dim], axis=-1) + split_dim = xz.shape[-1] // 2 # Correct the axis for splitting, ensuring the dimensions can be split equally + x, z = mx.split(xz, indices_or_sections=[split_dim], axis=1) # [split_dim] instead of 2 x_cache = unsqueeze(x, 1) x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] y, h = self.ssm(nn.silu(x), h) @@ -314,10 +295,8 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache=None): + def __call__(self, inputs: mx.array, cache): tokens = self.embeddings(inputs) - if cache is None: - cache = [None] * len(self.layers) for i, layer in enumerate(self.layers): h, cache[i] = layer(tokens, cache[i]) h = self.norm_f(h) @@ -330,7 +309,6 @@ def __init__(self, args: ModelArgs): self.args = args self.model_type = args.model_type self.backbone = Mamba(args) - # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) 
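# Illustrative sketch, not part of the patch: why dt_proj.bias is set to
# dt + log1p(-exp(-dt)) above.  That expression is the inverse of softplus for
# dt > 0, so softplus(dt_proj(...)) in the ssm step starts out near the sampled
# dt.  `inverse_softplus` is a hypothetical helper name; a quick numeric check:
import mlx.core as mx
import mlx.nn as nn

def inverse_softplus(dt):
    return dt + mx.log1p(-mx.exp(-dt))       # defined for dt > 0

dt = mx.array([0.001, 0.01, 0.1])
print(nn.softplus(inverse_softplus(dt)))     # ~ [0.001, 0.01, 0.1] up to float error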
def __call__(self, inputs: mx.array, cache=None): out, cache = self.backbone(inputs, cache) @@ -350,8 +328,8 @@ def n_kv_heads(self): return self.args.num_hidden_layers def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): self.eval() From 5bf5a4f775e1093d51c27e43646da4c2e129720f Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 24 Jul 2024 14:37:10 +0200 Subject: [PATCH 17/40] still not working --- llms/mlx_lm/models/mamba.py | 154 ++++++++++++++++++++---------------- 1 file changed, 85 insertions(+), 69 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 0b8004d22..be1a11ebe 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -69,15 +69,6 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, x) -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - def pscan_f(A, X): Aa = A Xa = X @@ -203,78 +194,102 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + # def ssm_old(self, x): + # A = -mx.exp(self.A_log) # (ED, N) + # D = self.D + + # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) + + # if self.args.use_mambapy: + # y = self.selective_scan(x, delta, A, B, C, D) + # else: + # y = self.selective_scan_seq(x, delta, A, B, C, D) + # return y + + # def selective_scan(self, x, delta, A, B, C, D): + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # hs = pscan(deltaA, BX) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + + # def selective_scan_seq(self, x, delta, A, B, C, D): + # _, L, _ = x.shape + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + # hs = [] + # for t in range(0, L): + # h = deltaA[:, t] * h + BX[:, t] + # hs.append(h) + # hs = mx.stack(hs, axis=1) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + def ssm(self, x, h): - A = -mx.exp(self.A_log) + A = -mx.exp(self.A_log) # (ED, N) D = self.D - delta, B, 
C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) - h = deltaA * h + BX - y = (h @ mx.unsqueeze(C, -1)).squeeze(2) + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + y = y + D * x return y, h - def ssm_old(self, x): - # x : (B, L, ED) + def __call__(self, x, cache): + h, inputs = cache - # y : (B, L, ED) + xz = self.in_proj(x) # (B, 2*ED) + x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - A = -mx.exp(self.A_log) # (ED, N) - D = self.D + # x branch + x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + # Ensure inputs has the correct shape + if inputs.ndim == 2: + inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) + print(f"inputs shape: {inputs.shape}") + print(f"x_cache shape: {x_cache.shape}") + + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - if self.args.use_mambapy: - y = self.selective_scan(x, delta, A, B, C, D) - else: - y = self.selective_scan_seq(x, delta, A, B, C, D) + x = nn.silu(x) + y, h = self.ssm(x, h) - return y - - def selective_scan(self, x, delta, A, B, C, D): - deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) - BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) - hs = pscan(deltaA, BX) - y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - y = y + D * x - return y - - def selective_scan_seq(self, x, delta, A, B, C, D): - _, L, _ = x.shape - deltaA = mx.exp(unsqueeze(delta, -1) * A) # (B, L, ED, N) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) # (B, L, ED, N) - BX = deltaB * unsqueeze(x, -1) # (B, L, ED, N) - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) - hs = mx.stack(hs, axis=1) - y = (hs @ unsqueeze(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - y = y + D * x - return y + # z branch + z = nn.silu(z) - def __call__(self, x, cache=None): - h, inputs = cache - xz = self.in_proj(x) - split_dim = xz.shape[-1] // 2 # Correct the axis for splitting, ensuring the dimensions can be split equally - x, z = mx.split(xz, indices_or_sections=[split_dim], axis=1) # [split_dim] instead of 2 - x_cache = unsqueeze(x, 1) - x = 
self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - cache.update(h, inputs) - return self.out_proj(output), cache + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) + cache = (h, inputs) + return output, cache class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -328,6 +343,7 @@ def n_kv_heads(self): return self.args.num_hidden_layers def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] @@ -335,7 +351,7 @@ def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool self.eval() if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) + input_ids = mx.expand_dims(input_ids, 0) caches = self.make_cache() @@ -363,7 +379,7 @@ def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: f self.eval() if input_ids.ndim == 1: - input_ids = unsqueeze(input_ids, 0) + input_ids = mx.expand_dims(input_ids, 0) caches = self.make_cache() From 6d65fcb1a3600fba7b8175742938b4a0b04fc670 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 25 Jul 2024 14:12:27 +0200 Subject: [PATCH 18/40] save --- llms/mlx_lm/models/mamba-save.py | 407 +++++++++++++++++++++++++++++++ llms/mlx_lm/models/mamba.py | 124 +--------- 2 files changed, 415 insertions(+), 116 deletions(-) diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py index f9174e187..9858158f3 100644 --- a/llms/mlx_lm/models/mamba-save.py +++ b/llms/mlx_lm/models/mamba-save.py @@ -423,3 +423,410 @@ def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool # def n_kv_heads(self): # return self.args.num_hidden_layers + + +from dataclasses import dataclass +from typing import Optional + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int # d_model + intermediate_size: int # d_inner + state_size: int # d_state + num_hidden_layers: int # n_layer, n_layer + layer_norm_epsilon: float + expand: int + conv_kernel: int # d_conv + use_bias: bool # bias + use_conv_bias: bool # conv_bias + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False # pscan + dt_rank: str = "auto" + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 
'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def pscan_f(A, X): + Aa = A + Xa = X + + B, D, L, _ = A.shape + + num_steps = int(math.log2(L)) + + # up sweep + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + # down sweep + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + +def pscan(A_in, X_in): + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + pscan_f(A, X) + return X.transpose(0, 2, 1, 3) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = int(channels) + self.kernel_size = int(kernel_size) + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) + scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((self.channels,)) + else: + self.bias = None + + def __call__(self, x): + B, D, L = x.shape + assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." 
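# Illustrative sketch, not part of the patch: nn.Conv1d is a Module class, so the
# call below with an array as the first argument appears to construct a fresh
# layer rather than convolve x.  If the installed MLX version exposes a functional
# mx.conv1d with grouped-convolution support, a depthwise causal step could be
# written as below; the (C_out, K, C_in/groups) weight layout and the `groups`
# argument are assumptions to verify against the MLX docs, and
# `depthwise_causal_conv1d` is a hypothetical helper name.
import mlx.core as mx

def depthwise_causal_conv1d(x, weight, bias=None):
    # x: (B, L, C); weight stored above as (C, 1, K)
    C, _, K = weight.shape
    x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])   # explicit left padding keeps the op causal
    w = mx.transpose(weight, (0, 2, 1))           # -> (C_out=C, K, C_in/groups=1)
    y = mx.conv1d(x, w, padding=0, groups=C)
    return y if bias is None else y + bias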
+ print("FORWARD PASS THROUGH CONV") + print(self.kernel_size) + print(self.weight) + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + # def ssm_old(self, x): + # A = -mx.exp(self.A_log) # (ED, N) + # D = self.D + + # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) + + # if self.args.use_mambapy: + # y = self.selective_scan(x, delta, A, B, C, D) + # else: + # y = self.selective_scan_seq(x, delta, A, B, C, D) + # return y + + # def selective_scan(self, x, delta, A, B, C, D): + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # hs = pscan(deltaA, BX) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + + # def selective_scan_seq(self, x, delta, A, B, C, D): + # _, L, _ = x.shape + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + # hs = [] + # for t in range(0, L): + # h = deltaA[:, t] * h + BX[:, t] + # hs.append(h) + # hs = mx.stack(hs, axis=1) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + + def ssm(self, x, h): + A = -mx.exp(self.A_log) 
# (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + + y = y + D * x + return y, h + + def __call__(self, x, cache): + h, inputs = cache + + xz = self.in_proj(x) # (B, 2*ED) + x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) + + # x branch + x_cache = mx.expand_dims(x, 1) # (B, 1, ED) + + # Ensure inputs has the correct shape + if inputs.ndim == 2: + inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing + + print(f"inputs shape: {inputs.shape}") + print(f"x_cache shape: {x_cache.shape}") + + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + x = self.conv1d(conv_input)[:, -1, :] # (B, ED) + + x = nn.silu(x) + y, h = self.ssm(x, h) + + # z branch + z = nn.silu(z) + + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) + cache = (h, inputs) + + return output, cache + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + tokens = self.embeddings(inputs) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, 
temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids + + def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + # Generate the next token logits + next_token_logits, caches = self(input_ids, caches) + + # Apply top_k filtering if specified + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest + mask = next_token_logits < (values[:, -1, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now + + # Sample the next token or take the argmax based on the temperature + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + # Concatenate the next token to the input_ids + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids, caches \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index be1a11ebe..3ed07dd87 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -13,15 +13,15 @@ class ModelArgs(BaseModelArgs): model_type: str vocab_size: int - hidden_size: int # d_model - intermediate_size: int # d_inner - state_size: int # d_state - num_hidden_layers: int # n_layer, n_layer + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int layer_norm_epsilon: float expand: int - conv_kernel: int # d_conv - use_bias: bool # bias - use_conv_bias: bool # conv_bias + conv_kernel: int + use_bias: bool + use_conv_bias: bool initializer_range: float time_step_rank: int time_step_scale: float @@ -31,7 +31,7 @@ class ModelArgs(BaseModelArgs): time_step_floor: float rescale_prenorm_residual: bool use_cache: bool - use_mambapy: bool = False # pscan + use_mambapy: bool = False dt_rank: str = "auto" def __post_init__(self): @@ -56,75 +56,6 @@ def __post_init__(self): if self.dt_rank == "auto": self.dt_rank = math.ceil(self.hidden_size / 16) - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def pscan_f(A, X): - Aa = A - Xa = X - - B, D, L, _ = A.shape - - num_steps = int(math.log2(L)) - - # up sweep - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, :, 1] += 
Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - # down sweep - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias, padding): super().__init__() @@ -194,45 +125,6 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - # def ssm_old(self, x): - # A = -mx.exp(self.A_log) # (ED, N) - # D = self.D - - # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) - - # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) - - # if self.args.use_mambapy: - # y = self.selective_scan(x, delta, A, B, C, D) - # else: - # y = self.selective_scan_seq(x, delta, A, B, C, D) - # return y - - # def selective_scan(self, x, delta, A, B, C, D): - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # hs = pscan(deltaA, BX) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - - # def selective_scan_seq(self, x, delta, A, B, C, D): - # _, L, _ = x.shape - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - # hs = [] - # for t in range(0, L): - # h = deltaA[:, t] * h + BX[:, t] - # hs.append(h) - # hs = mx.stack(hs, axis=1) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - def ssm(self, x, h): A = -mx.exp(self.A_log) # (ED, N) D = self.D From 380f8960b83ad93564a4c9803a92bcdb7472caeb Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 25 Jul 2024 19:02:43 +0200 Subject: [PATCH 19/40] still not working --- llms/mlx_lm/models/mamba.py | 41 +++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 3ed07dd87..340cce831 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -79,7 
+79,19 @@ def __call__(self, x): print(self.weight) out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) return out - + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -151,13 +163,9 @@ def ssm(self, x, h): def __call__(self, x, cache): h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - + x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) # (B, ED), (B, ED) # x branch x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - # Ensure inputs has the correct shape if inputs.ndim == 2: inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing @@ -165,23 +173,13 @@ def __call__(self, x, cache): print(f"inputs shape: {inputs.shape}") print(f"x_cache shape: {x_cache.shape}") - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) <---------- Here is the problem ValueError: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4. ||| inputs shape: (1, 3, 1536) x_cache shape: (1, 1, 5, 1536) x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - + y, h = self.ssm(nn.silu(x), h) + output = y * nn.silu(z) # * z branch # prepare cache for next call inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - cache = (h, inputs) - - return output, cache + return self.out_proj(output), (h, inputs) # (B, D), cache class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -236,8 +234,7 @@ def n_kv_heads(self): def make_cache(self): # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): self.eval() From 73768cc69dea2ce6788fd4cca5d6ec083aeb7513 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 16 Aug 2024 14:38:42 +0200 Subject: [PATCH 20/40] initial commit --- llms/mlx_lm/models/mamba-old.py | 399 +++++++++++++ llms/mlx_lm/models/mamba-save.py | 832 +++++++++++++++++++++++++++ llms/mlx_lm/models/mamba-tiny-pld.py | 154 +++++ llms/mlx_lm/models/mamba-torch.py | 145 +++++ llms/mlx_lm/models/mamba.py | 294 ++++++++++ 5 files changed, 1824 insertions(+) create mode 100644 llms/mlx_lm/models/mamba-old.py create mode 100644 llms/mlx_lm/models/mamba-save.py create mode 100644 llms/mlx_lm/models/mamba-tiny-pld.py create mode 100644 
llms/mlx_lm/models/mamba-torch.py create mode 100644 llms/mlx_lm/models/mamba.py diff --git a/llms/mlx_lm/models/mamba-old.py b/llms/mlx_lm/models/mamba-old.py new file mode 100644 index 000000000..844d4fb7d --- /dev/null +++ b/llms/mlx_lm/models/mamba-old.py @@ -0,0 +1,399 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import torch + +# import tokenizer + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "mamba" + dt_rank: Union[int, str] # time_step_rank + d_model: int + d_inner: int + vocab_size: int + n_layer: int + use_bias: bool + use_conv_bias: bool + rms_norm: bool + conv_kernel: int + state_size: int + expand: int + time_step_init_scheme: str + time_step_max: float + time_step_min: float + time_step_floor: float + pscan: bool = False + tie_word_embeddings: bool = False + num_hidden_layers: int = None + hidden_size: int = None + # time_step_scale + + def __post_init__(self): + self.d_inner = self.expand * self.d_model + if self.n_layer is None: + self.n_layer = self.num_hidden_layers + if self.d_model is None: + self.d_model = self.hidden_size + if self.dt_rank == 'auto': + self.dt_rank = math.ceil(self.d_model / 16) + + +def pscan_main(A, X): + Aa = A + Xa = X + B, D, L, _ = A.shape + num_steps = int(math.log2(L)) + + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + + +def pscan(A_in, X_in): + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + pscan_main(A, X) + return X.transpose(0, 2, 1, 3) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, 
kernel_size=kernel_size, bias=True, padding=padding) + indices = mx.arange(channels) + mask = mx.zeros_like(self.conv1d.weight) + mask[indices, :, indices] = 1 + self.conv1d.weight *= mask + + def __call__(self, x): + return self.conv1d(x) + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) + # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.conv_bias, + kernel_size=args.d_conv, + groups=args.d_inner, + padding=args.d_conv - 1, + ) + self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([args.d_inner]) + + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + + self.norm = nn.RMSNorm(args.d_model) + + def ssm_step(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + deltaBC = self.x_proj(self.norm(x)) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def ssm(self, x): # DONE + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) + if self.args.pscan: + y = self.selective_scan(x, delta, A, B, C, D) + else: + y = self.selective_scan_seq(x, delta, A, B, C, D) + return y + + def ssm_new(self, x): + d_in, N = self.A_log.shape + A = -mx.exp(self.A_log.float()) + D = self.D.float() + delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + return self.selective_scan_new(x, delta, A, B, C, D) + + def selective_scan(self, x, delta, A, B, C, D): # DONE + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) + hs = pscan(deltaA, BX) + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x + + def selective_scan_new(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + + # Perform 
selective scan (see scan_SSM() in The Annotated S4 [2]) + x = mx.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = mx.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + def selective_scan_seq(self, x, delta, A, B, C, D): + _, L, _ = x.shape + deltaA = mx.exp(unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) + BX = deltaB * unsqueeze(x, -1) + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) + hs = [] + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) + hs = mx.stack(hs, axis=1) + y = (hs @ unsqueeze(C, -1)).squeeze(3) + return y + D * x + + def step(self, x, cache): # Done + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] + y, h = self.ssm_step(nn.silu(x), h) + output = y * nn.silu(z) + output = self.out_proj(output) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + return output, (h, inputs) + + def ssm_step(self, x, h): # Done + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) + h = deltaA * h + BX + y = (h @ unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def __call__(self, x): # DONE + _, L, _ = x.shape + x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) + x = self.conv1d(x)[:, :L, :] + output = self.ssm(nn.silu(x)) * nn.silu(z) + return self.out_proj(output) + + def new(self, x): + _, L, _ = x.shape + x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) + + x = mx.reshape(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :L] + x = mx.rearrange(x, 'b d_in l -> b l d_in') + out = self.ssm_new(nn.silu(x)) * nn.silu(r) + return self.out_proj(out) + x + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): + output, cache = self.mixer.step(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] + self.norm_f = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array, cache=None): + tokens = self.embedding(inputs) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + + if 
self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + return out, cache + + # def torch_to_mlx_depthwise_weights(self, torch_weights): + # torch_weights = torch_weights.transpose(2, 1) + # channels, kernel_size, _ = torch_weights.shape + + # mlx_weights = torch.zeros(channels, kernel_size, channels) + + # indices = torch.arange(channels) + # if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': + # mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() + # else: + # mlx_weights[indices, :, indices] = torch_weights[:, :, 0] + + # return mlx_weights + + def sanitize(self, torch_state_dict): + new_state_dict = {} + for key, value in torch_state_dict.items(): + if 'conv1d.weight' in key: + value = self.torch_to_mlx_depthwise_weights(value) + + if 'conv1d' in key: + key = key.replace('conv1d', 'conv1d.conv1d') + + if value.type() == 'torch.BFloat16Tensor': + new_state_dict[key] = value.half().numpy() + else: + new_state_dict[key] = value.numpy() + + return new_state_dict + + @property + def layers(self): + return self.model.layers + + def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) + + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + output = [tokenizer.decode(output.tolist()) for output in input_ids][0] + + self.train() + + return output + +# model = Model(ModelArgs()) +# print(model) + +# logits = model.generate() +# print(logits) diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py new file mode 100644 index 000000000..9858158f3 --- /dev/null +++ b/llms/mlx_lm/models/mamba-save.py @@ -0,0 +1,832 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, MambaCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int # d_model + intermediate_size: int # d_inner + state_size: int # d_state + num_hidden_layers: int # n_layer + layer_norm_epsilon: float + expand: int + conv_kernel: int # d_conv + use_bias: bool # bias + use_conv_bias: bool # conv_bias + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False # pscan + dt_rank: str = "auto" + + def __post_init__(self): + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not 
None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def unsqueeze(x, axis): + assert axis <= len(x.shape) + if axis >= 0: + new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] + else: + new_shape = x.shape + tuple([1]) + return x.reshape(new_shape) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) + scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x): + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) + self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm(self, x, h): + A = -mx.exp(self.A_log) + D = self.D + delta, B, C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) + deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) + BX = deltaB * unsqueeze(x, -1) + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) + h = deltaA * h + BX + y = (h @ mx.unsqueeze(C, -1)).squeeze(2) + y = y + D * x + return y, h + + def __call__(self, x, cache: Optional[MambaCache]): + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) + 
x_cache = unsqueeze(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) + y, h = self.ssm(nn.silu(x), h) + output = y * nn.silu(z) + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) + cache.update(h, inputs) + return self.out_proj(output), cache + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache=None): + tokens = self.embeddings(inputs) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids + + + + + + + + +# from dataclasses import dataclass +# from typing import Optional, Union + +# import math +# import einsum + +# import mlx.core as mx +# import mlx.nn as nn + +# from .base import BaseModelArgs, MambaCache + + +# @dataclass +# class ModelArgs(BaseModelArgs): +# model_type: str +# vocab_size: int +# hidden_size: int # d_model +# intermediate_size: int # d_inner +# state_size: int # d_state +# num_hidden_layers: int # n_layer +# layer_norm_epsilon: float +# expand: int +# conv_kernel: int # d_conv +# use_bias: bool # bias +# use_conv_bias: bool # conv_bias +# initializer_range: float +# time_step_rank: int +# time_step_scale: 
float +# time_step_min: float +# time_step_max: float +# time_step_init_scheme: str +# time_step_floor: float +# rescale_prenorm_residual: bool +# use_cache: bool +# use_mambapy: bool = False # pscan +# dt_rank: str = "auto" + +# def __post_init__(self): +# self.intermediate_size = self.expand * self.hidden_size +# if self.dt_rank == "auto": +# self.dt_rank = math.ceil(self.hidden_size / 16) + + +# def clamp(x, min=None, max=None): +# if min is not None: +# mask_lower = x < min +# if max is not None: +# mask_upper = x > max +# if min is not None: +# if max is not None: +# return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) +# return mx.where(mask_lower, min, x) +# return mx.where(mask_upper, max, x) + +# class MambaBlock(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.args = args + +# self.hidden_size = args.hidden_size +# self.ssm_state_size = args.state_size +# self.conv_kernel_size = args.conv_kernel +# self.intermediate_size = args.intermediate_size +# self.time_step_rank = int(args.time_step_rank) +# self.use_conv_bias = args.use_conv_bias + +# self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + +# self.conv1d = nn.Conv1d( +# in_channels=self.intermediate_size, +# out_channels=self.intermediate_size, +# kernel_size=self.conv_kernel_size, +# bias=self.use_conv_bias, +# padding=self.conv_kernel_size-1 +# ) + +# self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) +# self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + +# A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", repeats=self.intermediate_size) +# self.A_log = mx.log(A) +# self.D = mx.ones([self.intermediate_size]) + +# self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + +# def ssm(self, x): +# (d_in, n) = self.A_log.shape + +# A = -mx.exp(self.A_log.float()) # shape (d_in, n) +# D = self.D.float() + +# x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) + +# (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). 
B, C: (b, l, n) +# delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) + +# y = self.selective_scan(x, delta, A, B, C, D) + +# return y + +# def selective_scan(self, u, delta, A, B, C, D): +# (b, l, d_in) = u.shape +# n = A.shape[1] +# deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) +# deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') +# x = mx.zeros((b, d_in, n), device=deltaA.device) +# ys = [] +# for i in range(l): +# x = deltaA[:, :, i] * x + deltaB_u[:, :, i] +# y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') +# ys.append(y) +# y = mx.stack(ys, dim=1) # shape (b, l, d_in) + +# y = y + u * D +# return y + +# def __call__(self, x): +# (b, l, d) = x.shape +# x_copy = x +# x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) + +# x = mx.rearrange(x, 'b l d_in -> b d_in l') +# x = self.conv1d(x)[:, :, :l] +# x = mx.rearrange(x, 'b d_in l -> b l d_in') + +# x = nn.silu(x) + +# y = self.ssm(x) + +# y = y * nn.silu(res) +# return self.out_proj(y) + x_copy + + +# class ResidualBlock(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.mixer = MambaBlock(args) +# self.norm = nn.RMSNorm(args.hidden_size) + +# def __call__(self, inputs: mx.array): +# output = self.mixer(self.norm(inputs)) +# output = output + inputs +# return output + + +# class Mamba(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) +# self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] +# self.norm_f = nn.RMSNorm(args.hidden_size) + +# def __call__(self, inputs: mx.array): +# tokens = self.embeddings(inputs) +# for i, layer in enumerate(self.layers): +# h, = layer(tokens) +# return self.norm_f(h) + + +# class Model(nn.Module): +# def __init__(self, args: ModelArgs): +# super().__init__() +# self.args = args +# self.model_type = args.model_type +# self.backbone = Mamba(args) + +# def __call__(self, inputs: mx.array, cache=None): +# out = self.backbone(inputs) +# out = self.backbone.embeddings.as_linear(out) +# return out, cache + +# @property +# def layers(self): +# return self.backbone.layers + +# @property +# def head_dim(self): +# return self.args.hidden_size // self.args.num_hidden_layers + +# @property +# def n_kv_heads(self): +# return self.args.num_hidden_layers + + + +from dataclasses import dataclass +from typing import Optional + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int # d_model + intermediate_size: int # d_inner + state_size: int # d_state + num_hidden_layers: int # n_layer, n_layer + layer_norm_epsilon: float + expand: int + conv_kernel: int # d_conv + use_bias: bool # bias + use_conv_bias: bool # conv_bias + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False # pscan + dt_rank: str = "auto" + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + 
self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +def pscan_f(A, X): + Aa = A + Xa = X + + B, D, L, _ = A.shape + + num_steps = int(math.log2(L)) + + # up sweep + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + # down sweep + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + +def pscan(A_in, X_in): + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + pscan_f(A, X) + return X.transpose(0, 2, 1, 3) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = int(channels) + self.kernel_size = int(kernel_size) + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) + scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((self.channels,)) + else: + self.bias = None + + def __call__(self, x): + B, D, L = x.shape + assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." 
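+        # Note: nn.Conv1d in MLX is a Module class, so calling it like a function
+        # (as done below) will not perform a convolution. A minimal sketch of the
+        # intended depthwise convolution, assuming mx.conv1d supports `groups`,
+        # a channels-last input (B, L, C), and the (C, 1, K) weight reshaped to
+        # the (C_out, K, C_in/groups) layout that mx.conv1d expects:
+        #
+        #   w = self.weight.transpose(0, 2, 1)                       # (C, K, 1)
+        #   y = mx.conv1d(x, w, padding=self.padding, groups=self.channels)
+        #   if self.bias is not None:
+        #       y = y + self.bias
+        #
+        # The shapes asserted above treat x as (B, C, L), so under these
+        # assumptions the input would first need a transpose to (B, L, C).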
+ print("FORWARD PASS THROUGH CONV") + print(self.kernel_size) + print(self.weight) + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + # def ssm_old(self, x): + # A = -mx.exp(self.A_log) # (ED, N) + # D = self.D + + # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) + + # if self.args.use_mambapy: + # y = self.selective_scan(x, delta, A, B, C, D) + # else: + # y = self.selective_scan_seq(x, delta, A, B, C, D) + # return y + + # def selective_scan(self, x, delta, A, B, C, D): + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # hs = pscan(deltaA, BX) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + + # def selective_scan_seq(self, x, delta, A, B, C, D): + # _, L, _ = x.shape + # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + # hs = [] + # for t in range(0, L): + # h = deltaA[:, t] * h + BX[:, t] + # hs.append(h) + # hs = mx.stack(hs, axis=1) + # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + # y = y + D * x + # return y + + def ssm(self, x, h): + A = -mx.exp(self.A_log) 
# (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + + y = y + D * x + return y, h + + def __call__(self, x, cache): + h, inputs = cache + + xz = self.in_proj(x) # (B, 2*ED) + x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) + + # x branch + x_cache = mx.expand_dims(x, 1) # (B, 1, ED) + + # Ensure inputs has the correct shape + if inputs.ndim == 2: + inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing + + print(f"inputs shape: {inputs.shape}") + print(f"x_cache shape: {x_cache.shape}") + + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + x = self.conv1d(conv_input)[:, -1, :] # (B, ED) + + x = nn.silu(x) + y, h = self.ssm(x, h) + + # z branch + z = nn.silu(z) + + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) + cache = (h, inputs) + + return output, cache + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + tokens = self.embeddings(inputs) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, 
temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) + mask = next_token_logits < (values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids + + def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + # Generate the next token logits + next_token_logits, caches = self(input_ids, caches) + + # Apply top_k filtering if specified + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest + mask = next_token_logits < (values[:, -1, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now + + # Sample the next token or take the argmax based on the temperature + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + # Concatenate the next token to the input_ids + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids, caches \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-tiny-pld.py b/llms/mlx_lm/models/mamba-tiny-pld.py new file mode 100644 index 000000000..8713978d5 --- /dev/null +++ b/llms/mlx_lm/models/mamba-tiny-pld.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import math + +import torch + +# import tokenizer + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + n_layer: int + use_conv_bias: bool + expand: int + pad_vocab_size_multiple: int + conv_kernel: int + d_model: int + state_size: int + d_inner: int + initializer_range: float + use_bias: bool + time_step_init_scheme: str + time_step_max: float + time_step_min: float + time_step_floor: float + dt_rank: Union[int, str] = "auto" + + def __post_init__(self): + self.d_inner = self.expand * self.d_model + if self.n_layer is None: + self.n_layer = self.num_hidden_layers + if self.d_model is None: + self.d_model = self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.d_model / 16) + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) + # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) + self.conv1d = nn.Conv1d( + in_channels=args.d_inner, + out_channels=args.d_inner, + bias=args.use_conv_bias, + kernel_size=args.conv_kernel, + # groups=args.d_inner, + padding=args.conv_kernel - 1, + ) + self.x_proj = nn.Linear(args.d_inner, 
args.dt_rank + 2 * args.state_size, bias=False) + self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) + + A = mx.repeat(mx.arange(1, args.state_size + 1).reshape([1, 16]), repeats=args.d_inner) + + + self.A_log = mx.log(A) + self.D = mx.ones([args.d_inner]) + + self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) + + self.norm = nn.RMSNorm(args.d_model) + + def ssm(self, x): + d_in, N = self.A_log.shape + A = -mx.exp(self.A_log.float()) + D = self.D.float() + delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) + delta = nn.softplus(self.dt_proj(delta)) + return self.selective_scan(x, delta, A, B, C, D) + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + + # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) + x = mx.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = mx.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + def __call__(self, x): + _, L, _ = x.shape + x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) + + x = mx.reshape(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :L] + x = mx.rearrange(x, 'b d_in l -> b l d_in') + out = self.ssm(nn.silu(x)) * nn.silu(r) + return self.out_proj(out) + x + +class MambaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embedding = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [MambaBlock(args) for _ in range(args.n_layer)] + self.norm_f = nn.RMSNorm(args.d_model) + + def __call__(self, inputs: mx.array_equal): + tokens = self.embedding(inputs) + for i, layer in enumerate(self.layers): + h = layer(tokens) + h = self.norm_f(h) + return h + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = self.backbone = MambaModel(args) + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + self.lm_head.weight = self.model.embedding.weight + + def __call__(self, inputs: mx.array): + h = self.backbone(inputs) + return self.lm_head(h) + + @property + def layers(self): + return self.backbone.layers + + # def sanitize(self, weights): + # exclude_patterns = [ + # 'backbone.layers.mixer.A_log', + # 'backbone.layers.mixer.conv1d.weight', + # 'backbone.layers.mixer.dt_proj.weight', + # 'backbone.layers.mixer.in_proj.weight', + # 'backbone.layers.mixer.dt_proj.bias', + # 'backbone.layers.mixer.conv1d.bias', + # 'backbone.layers.mixer.D' + # ] + # return { + # k: v for k, v in weights.items() + # if not any(pattern in k for pattern in exclude_patterns) + # } \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-torch.py b/llms/mlx_lm/models/mamba-torch.py new file mode 100644 index 000000000..84deb4d3f --- /dev/null +++ b/llms/mlx_lm/models/mamba-torch.py @@ -0,0 +1,145 @@ +import torch.nn as nn +import torch +from configuration_mamba import MambaConfig +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +import math +import json 
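+# The reference MambaBlock below discretizes the SSM per input step and runs the
+# recurrence sequentially; for each batch b, channel d and step t it computes
+# (shapes follow the einsum comments further down):
+#
+#   A_bar[t]   = exp(delta[t] * A)          # per-step discretized A, (d_in, n)
+#   B_bar_u[t] = delta[t] * B[t] * u[t]     # input-dependent state update
+#   h          = A_bar[t] * h + B_bar_u[t]  # hidden state, (b, d_in, n)
+#   y[t]       = C[t] . h + D * u[t]        # readout plus skip connection
+#
+# This is the per-step view of selective_scan(); the MLX ports in mamba.py keep
+# the same recurrence but carry h in a cache so generation can run one token at
+# a time.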
+import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass +from einops import rearrange, repeat, einsum +from typing import Optional , Union ,Tuple +l +# Dear contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, special thanks to Albert Gu and Tri Dao for their articles. (https://arxiv.org/abs/2312.00752) + + +class MambaRMSNorm(nn.Module): + def __init__(self, + d_model: int, + eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d_model)) + def forward(self, x): + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + return output + + +class MambaBlock(nn.Module): + def __init__(self, config: MambaConfig): + """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" + super().__init__() + self.config = config + + self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) + + self.conv1d = nn.Conv1d( + in_channels=config.d_inner, + out_channels=config.d_inner, + bias=config.conv_bias, + kernel_size=config.d_conv, + groups=config.d_inner, + padding=config.d_conv - 1, + ) + + # x_proj takes in `x` and outputs the input-specific Δ, B, C + self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False) + + # dt_proj projects Δ from dt_rank to d_in + self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) + + A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner) + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(config.d_inner)) + self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) + # self.norm = MambaRMSNorm(config.d_model) + + def forward(self, x): + (b, l, d) = x.shape + x_copy = x # There was a separate class for residual, I deleted that part and added it here. + x, res = self.in_proj(self.norm(x)).split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) + + x = rearrange(x, 'b l d_in -> b d_in l') + x = self.conv1d(x)[:, :, :l] + x = rearrange(x, 'b d_in l -> b l d_in') + + x = F.silu(x) + + y = self.ssm(x) + + y = y * F.silu(res) + + output = self.out_proj(y) + x_copy + + return output + + + def ssm(self, x): + (d_in, n) = self.A_log.shape + + A = -torch.exp(self.A_log.float()) # shape (d_in, n) + D = self.D.float() + + x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) + + (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). 
B, C: (b, l, n) + delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) + + y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] + + return y + + + def selective_scan(self, u, delta, A, B, C, D): + (b, l, d_in) = u.shape + n = A.shape[1] + deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) + deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') + x = torch.zeros((b, d_in, n), device=deltaA.device) + ys = [] + for i in range(l): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = torch.stack(ys, dim=1) # shape (b, l, d_in) + + y = y + u * D + + return y + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config: MambaConfig): + super().__init__(config) + self.config = config + + self.embedding = nn.Embedding(config.vocab_size, config.d_model) + self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)]) + self.norm_f = MambaRMSNorm(config.d_model) + + def forward(self, input_ids: torch.LongTensor = None): + x = self.embedding(input_ids) + all_hidden_states = list() + for layer in self.layers: + x = layer(x) + all_hidden_states.append(x) + return self.norm_f(x) + + +class MambaForCausalLM(MambaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MambaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.lm_head.weight = self.model.embedding.weight + + + def forward(self, input_ids: torch.LongTensor = None): + hidden_states = self.model(input_ids=input_ids) + return self.lm_head(hidden_states) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py new file mode 100644 index 000000000..340cce831 --- /dev/null +++ b/llms/mlx_lm/models/mamba.py @@ -0,0 +1,294 @@ +from dataclasses import dataclass +from typing import Optional + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False + dt_rank: str = "auto" + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + 
self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = int(channels) + self.kernel_size = int(kernel_size) + self.bias = bias + self.padding = padding + self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) + scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) + self.weight *= scale # Ensure scaling is applied correctly + if bias: + self.bias = mx.zeros((self.channels,)) + else: + self.bias = None + + def __call__(self, x): + B, D, L = x.shape + assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." + print("FORWARD PASS THROUGH CONV") + print(self.kernel_size) + print(self.weight) + out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) + return out + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm(self, x, h): + A = -mx.exp(self.A_log) # (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = 
deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + + y = y + D * x + return y, h + + def __call__(self, x, cache): + h, inputs = cache + x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) # (B, ED), (B, ED) + # x branch + x_cache = mx.expand_dims(x, 1) # (B, 1, ED) + # Ensure inputs has the correct shape + if inputs.ndim == 2: + inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing + + print(f"inputs shape: {inputs.shape}") + print(f"x_cache shape: {x_cache.shape}") + + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) <---------- Here is the problem ValueError: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4. ||| inputs shape: (1, 3, 1536) x_cache shape: (1, 1, 5, 1536) + x = self.conv1d(conv_input)[:, -1, :] # (B, ED) + y, h = self.ssm(nn.silu(x), h) + output = y * nn.silu(z) # * z branch + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) + return self.out_proj(output), (h, inputs) # (B, D), cache + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + tokens = self.embeddings(inputs) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + + def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): + next_token_logits, caches = self(input_ids[:, i], caches) + + if i+1 >= input_ids.shape[1]: + + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) + mask = next_token_logits < 
(values[:, 0, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) + + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids + + def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): + self.eval() + + if input_ids.ndim == 1: + input_ids = mx.expand_dims(input_ids, 0) + + caches = self.make_cache() + + # Generate the next token logits + next_token_logits, caches = self(input_ids, caches) + + # Apply top_k filtering if specified + if top_k is not None: + values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest + mask = next_token_logits < (values[:, -1, None]) + next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now + + # Sample the next token or take the argmax based on the temperature + if sample and temperature > 0: + next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) + else: + next_token = mx.argmax(next_token_logits, axis=-1)[:, None] + + # Concatenate the next token to the input_ids + input_ids = mx.concatenate([input_ids, next_token], axis=1) + + self.train() + return input_ids, caches \ No newline at end of file From 3d3dfa39f16aad018b125e2dca91bfc8747eddeb Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sat, 17 Aug 2024 14:32:19 +0200 Subject: [PATCH 21/40] utils.py logits = logits[:, -1, :] TypeError: tuple indices must be integers or slices, not tuple --- llms/mlx_lm/models/mamba-old.py | 399 ------------- llms/mlx_lm/models/mamba-save.py | 832 --------------------------- llms/mlx_lm/models/mamba-tiny-pld.py | 154 ----- llms/mlx_lm/models/mamba-torch.py | 145 ----- llms/mlx_lm/models/mamba.py | 175 +++--- 5 files changed, 70 insertions(+), 1635 deletions(-) delete mode 100644 llms/mlx_lm/models/mamba-old.py delete mode 100644 llms/mlx_lm/models/mamba-save.py delete mode 100644 llms/mlx_lm/models/mamba-tiny-pld.py delete mode 100644 llms/mlx_lm/models/mamba-torch.py diff --git a/llms/mlx_lm/models/mamba-old.py b/llms/mlx_lm/models/mamba-old.py deleted file mode 100644 index 844d4fb7d..000000000 --- a/llms/mlx_lm/models/mamba-old.py +++ /dev/null @@ -1,399 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import torch - -# import tokenizer - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "mamba" - dt_rank: Union[int, str] # time_step_rank - d_model: int - d_inner: int - vocab_size: int - n_layer: int - use_bias: bool - use_conv_bias: bool - rms_norm: bool - conv_kernel: int - state_size: int - expand: int - time_step_init_scheme: str - time_step_max: float - time_step_min: float - time_step_floor: float - pscan: bool = False - tie_word_embeddings: bool = False - num_hidden_layers: int = None - hidden_size: int = None - # time_step_scale - - def __post_init__(self): - self.d_inner = self.expand * self.d_model - if self.n_layer is None: - self.n_layer = self.num_hidden_layers - if self.d_model is None: - self.d_model = self.hidden_size - if self.dt_rank == 'auto': - self.dt_rank = math.ceil(self.d_model / 16) - - -def pscan_main(A, X): - Aa = A - Xa = X - B, D, L, _ = A.shape - num_steps = int(math.log2(L)) - - for k in 
range(num_steps): - T = 2 * (Xa.shape[2] // 2) - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_main(A, X) - return X.transpose(0, 2, 1, 3) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) - indices = mx.arange(channels) - mask = mx.zeros_like(self.conv1d.weight) - mask[indices, :, indices] = 1 - self.conv1d.weight *= mask - - def __call__(self, x): - return self.conv1d(x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) - # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) - self.conv1d = nn.Conv1d( - in_channels=args.d_inner, - out_channels=args.d_inner, - bias=args.conv_bias, - kernel_size=args.d_conv, - groups=args.d_inner, - padding=args.d_conv - 1, - ) - self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) - self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + 
math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([args.d_inner]) - - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) - - self.norm = nn.RMSNorm(args.d_model) - - def ssm_step(self, x, h): - A = -mx.exp(self.A_log) - D = self.D - deltaBC = self.x_proj(self.norm(x)) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def ssm(self, x): # DONE - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - if self.args.pscan: - y = self.selective_scan(x, delta, A, B, C, D) - else: - y = self.selective_scan_seq(x, delta, A, B, C, D) - return y - - def ssm_new(self, x): - d_in, N = self.A_log.shape - A = -mx.exp(self.A_log.float()) - D = self.D.float() - delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - return self.selective_scan_new(x, delta, A, B, C, D) - - def selective_scan(self, x, delta, A, B, C, D): # DONE - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - hs = pscan(deltaA, BX) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - return y + D * x - - def selective_scan_new(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - - # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = mx.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = mx.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - def selective_scan_seq(self, x, delta, A, B, C, D): - _, L, _ = x.shape - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) - hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) - hs = mx.stack(hs, axis=1) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - return y + D * x - - def step(self, x, cache): # Done - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] - y, h = self.ssm_step(nn.silu(x), h) - output = y * nn.silu(z) - output = self.out_proj(output) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - return output, (h, inputs) - - def ssm_step(self, x, h): # Done - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, 
self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x): # DONE - _, L, _ = x.shape - x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) - x = self.conv1d(x)[:, :L, :] - output = self.ssm(nn.silu(x)) * nn.silu(z) - return self.out_proj(output) - - def new(self, x): - _, L, _ = x.shape - x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) - - x = mx.reshape(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :L] - x = mx.rearrange(x, 'b d_in l -> b l d_in') - out = self.ssm_new(nn.silu(x)) * nn.silu(r) - return self.out_proj(out) + x - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): - output, cache = self.mixer.step(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embedding = nn.Embedding(args.vocab_size, args.d_model) - self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] - self.norm_f = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array, cache=None): - tokens = self.embedding(inputs) - if cache is None: - cache = [None] * len(self.layers) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - return out, cache - - # def torch_to_mlx_depthwise_weights(self, torch_weights): - # torch_weights = torch_weights.transpose(2, 1) - # channels, kernel_size, _ = torch_weights.shape - - # mlx_weights = torch.zeros(channels, kernel_size, channels) - - # indices = torch.arange(channels) - # if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': - # mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() - # else: - # mlx_weights[indices, :, indices] = torch_weights[:, :, 0] - - # return mlx_weights - - def sanitize(self, torch_state_dict): - new_state_dict = {} - for key, value in torch_state_dict.items(): - if 'conv1d.weight' in key: - value = self.torch_to_mlx_depthwise_weights(value) - - if 'conv1d' in key: - key = key.replace('conv1d', 'conv1d.conv1d') - - if value.type() == 'torch.BFloat16Tensor': - new_state_dict[key] = value.half().numpy() - else: - new_state_dict[key] = value.numpy() - - return new_state_dict - - @property - def layers(self): - return self.model.layers - - def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) - - caches = [(None, mx.zeros([1, 
self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - output = [tokenizer.decode(output.tolist()) for output in input_ids][0] - - self.train() - - return output - -# model = Model(ModelArgs()) -# print(model) - -# logits = model.generate() -# print(logits) diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py deleted file mode 100644 index 9858158f3..000000000 --- a/llms/mlx_lm/models/mamba-save.py +++ /dev/null @@ -1,832 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, MambaCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int # d_model - intermediate_size: int # d_inner - state_size: int # d_state - num_hidden_layers: int # n_layer - layer_norm_epsilon: float - expand: int - conv_kernel: int # d_conv - use_bias: bool # bias - use_conv_bias: bool # conv_bias - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False # pscan - dt_rank: str = "auto" - - def __post_init__(self): - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None - - def __call__(self, x): - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = 
args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) - h = deltaA * h + BX - y = (h @ mx.unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x, cache: Optional[MambaCache]): - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - cache.update(h, inputs) - return self.out_proj(output), cache - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache=None): - tokens = self.embeddings(inputs) - if cache is None: - cache = [None] * len(self.layers) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = 
Mamba(args) - # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - - - - - - - -# from dataclasses import dataclass -# from typing import Optional, Union - -# import math -# import einsum - -# import mlx.core as mx -# import mlx.nn as nn - -# from .base import BaseModelArgs, MambaCache - - -# @dataclass -# class ModelArgs(BaseModelArgs): -# model_type: str -# vocab_size: int -# hidden_size: int # d_model -# intermediate_size: int # d_inner -# state_size: int # d_state -# num_hidden_layers: int # n_layer -# layer_norm_epsilon: float -# expand: int -# conv_kernel: int # d_conv -# use_bias: bool # bias -# use_conv_bias: bool # conv_bias -# initializer_range: float -# time_step_rank: int -# time_step_scale: float -# time_step_min: float -# time_step_max: float -# time_step_init_scheme: str -# time_step_floor: float -# rescale_prenorm_residual: bool -# use_cache: bool -# use_mambapy: bool = False # pscan -# dt_rank: str = "auto" - -# def __post_init__(self): -# self.intermediate_size = self.expand * self.hidden_size -# if self.dt_rank == "auto": -# self.dt_rank = math.ceil(self.hidden_size / 16) - - -# def clamp(x, min=None, max=None): -# if min is not None: -# mask_lower = x < min -# if max is not None: -# mask_upper = x > max -# if min is not None: -# if max is not None: -# return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) -# return mx.where(mask_lower, min, x) -# return mx.where(mask_upper, max, x) - -# class MambaBlock(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.args = args - -# self.hidden_size = args.hidden_size -# self.ssm_state_size = args.state_size -# self.conv_kernel_size = args.conv_kernel -# self.intermediate_size = args.intermediate_size -# self.time_step_rank = int(args.time_step_rank) -# self.use_conv_bias = args.use_conv_bias - -# self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - -# self.conv1d = nn.Conv1d( -# in_channels=self.intermediate_size, -# out_channels=self.intermediate_size, -# 
kernel_size=self.conv_kernel_size, -# bias=self.use_conv_bias, -# padding=self.conv_kernel_size-1 -# ) - -# self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) -# self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - -# A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", repeats=self.intermediate_size) -# self.A_log = mx.log(A) -# self.D = mx.ones([self.intermediate_size]) - -# self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - -# def ssm(self, x): -# (d_in, n) = self.A_log.shape - -# A = -mx.exp(self.A_log.float()) # shape (d_in, n) -# D = self.D.float() - -# x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) - -# (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). B, C: (b, l, n) -# delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) - -# y = self.selective_scan(x, delta, A, B, C, D) - -# return y - -# def selective_scan(self, u, delta, A, B, C, D): -# (b, l, d_in) = u.shape -# n = A.shape[1] -# deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) -# deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') -# x = mx.zeros((b, d_in, n), device=deltaA.device) -# ys = [] -# for i in range(l): -# x = deltaA[:, :, i] * x + deltaB_u[:, :, i] -# y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') -# ys.append(y) -# y = mx.stack(ys, dim=1) # shape (b, l, d_in) - -# y = y + u * D -# return y - -# def __call__(self, x): -# (b, l, d) = x.shape -# x_copy = x -# x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) - -# x = mx.rearrange(x, 'b l d_in -> b d_in l') -# x = self.conv1d(x)[:, :, :l] -# x = mx.rearrange(x, 'b d_in l -> b l d_in') - -# x = nn.silu(x) - -# y = self.ssm(x) - -# y = y * nn.silu(res) -# return self.out_proj(y) + x_copy - - -# class ResidualBlock(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.mixer = MambaBlock(args) -# self.norm = nn.RMSNorm(args.hidden_size) - -# def __call__(self, inputs: mx.array): -# output = self.mixer(self.norm(inputs)) -# output = output + inputs -# return output - - -# class Mamba(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) -# self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] -# self.norm_f = nn.RMSNorm(args.hidden_size) - -# def __call__(self, inputs: mx.array): -# tokens = self.embeddings(inputs) -# for i, layer in enumerate(self.layers): -# h, = layer(tokens) -# return self.norm_f(h) - - -# class Model(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.args = args -# self.model_type = args.model_type -# self.backbone = Mamba(args) - -# def __call__(self, inputs: mx.array, cache=None): -# out = self.backbone(inputs) -# out = self.backbone.embeddings.as_linear(out) -# return out, cache - -# @property -# def layers(self): -# return self.backbone.layers - -# @property -# def head_dim(self): -# return self.args.hidden_size // self.args.num_hidden_layers - -# @property -# def n_kv_heads(self): -# return self.args.num_hidden_layers - - - -from dataclasses import dataclass -from typing import Optional - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - 
vocab_size: int - hidden_size: int # d_model - intermediate_size: int # d_inner - state_size: int # d_state - num_hidden_layers: int # n_layer, n_layer - layer_norm_epsilon: float - expand: int - conv_kernel: int # d_conv - use_bias: bool # bias - use_conv_bias: bool # conv_bias - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False # pscan - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def pscan_f(A, X): - Aa = A - Xa = X - - B, D, L, _ = A.shape - - num_steps = int(math.log2(L)) - - # up sweep - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - # down sweep - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = int(channels) - self.kernel_size = 
int(kernel_size) - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) - scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((self.channels,)) - else: - self.bias = None - - def __call__(self, x): - B, D, L = x.shape - assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." - print("FORWARD PASS THROUGH CONV") - print(self.kernel_size) - print(self.weight) - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - # def ssm_old(self, x): - # A = -mx.exp(self.A_log) # (ED, N) - # D = self.D - - # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) - - # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) - - # if self.args.use_mambapy: - # y = self.selective_scan(x, delta, A, B, C, D) - # else: - # y = self.selective_scan_seq(x, delta, A, B, C, D) - # return y - - # def selective_scan(self, x, delta, A, B, C, D): - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # hs = pscan(deltaA, BX) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - - # def selective_scan_seq(self, x, delta, A, B, C, D): - # _, L, _ = x.shape - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = 
mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - # hs = [] - # for t in range(0, L): - # h = deltaA[:, t] * h + BX[:, t] - # hs.append(h) - # hs = mx.stack(hs, axis=1) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - - def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - return y, h - - def __call__(self, x, cache): - h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - - # x branch - x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - - # Ensure inputs has the correct shape - if inputs.ndim == 2: - inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing - - print(f"inputs shape: {inputs.shape}") - print(f"x_cache shape: {x_cache.shape}") - - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) - x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - cache = (h, inputs) - - return output, cache - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, 
mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - # Generate the next token logits - next_token_logits, caches = self(input_ids, caches) - - # Apply top_k filtering if specified - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest - mask = next_token_logits < (values[:, -1, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now - - # Sample the next token or take the argmax based on the temperature - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - # Concatenate the next token to the input_ids - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids, caches \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-tiny-pld.py b/llms/mlx_lm/models/mamba-tiny-pld.py deleted file mode 100644 index 8713978d5..000000000 --- a/llms/mlx_lm/models/mamba-tiny-pld.py +++ /dev/null @@ -1,154 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import torch - -# import tokenizer - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - n_layer: int - use_conv_bias: bool - expand: int - pad_vocab_size_multiple: int - conv_kernel: int - d_model: int - state_size: int - d_inner: int - initializer_range: float - use_bias: bool - time_step_init_scheme: str - time_step_max: float - time_step_min: float - time_step_floor: float - dt_rank: Union[int, str] = "auto" - - def __post_init__(self): - self.d_inner = self.expand * self.d_model - if self.n_layer is None: - self.n_layer = self.num_hidden_layers - if self.d_model is None: - self.d_model = self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.d_model / 16) - -class MambaBlock(nn.Module): - def __init__(self, args: 
ModelArgs): - super().__init__() - self.args = args - self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) - # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) - self.conv1d = nn.Conv1d( - in_channels=args.d_inner, - out_channels=args.d_inner, - bias=args.use_conv_bias, - kernel_size=args.conv_kernel, - # groups=args.d_inner, - padding=args.conv_kernel - 1, - ) - self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) - self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) - - A = mx.repeat(mx.arange(1, args.state_size + 1).reshape([1, 16]), repeats=args.d_inner) - - - self.A_log = mx.log(A) - self.D = mx.ones([args.d_inner]) - - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) - - self.norm = nn.RMSNorm(args.d_model) - - def ssm(self, x): - d_in, N = self.A_log.shape - A = -mx.exp(self.A_log.float()) - D = self.D.float() - delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - return self.selective_scan(x, delta, A, B, C, D) - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - - # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = mx.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = mx.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - def __call__(self, x): - _, L, _ = x.shape - x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) - - x = mx.reshape(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :L] - x = mx.rearrange(x, 'b d_in l -> b l d_in') - out = self.ssm(nn.silu(x)) * nn.silu(r) - return self.out_proj(out) + x - -class MambaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embedding = nn.Embedding(args.vocab_size, args.d_model) - self.layers = [MambaBlock(args) for _ in range(args.n_layer)] - self.norm_f = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array_equal): - tokens = self.embedding(inputs) - for i, layer in enumerate(self.layers): - h = layer(tokens) - h = self.norm_f(h) - return h - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = self.backbone = MambaModel(args) - self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) - self.lm_head.weight = self.model.embedding.weight - - def __call__(self, inputs: mx.array): - h = self.backbone(inputs) - return self.lm_head(h) - - @property - def layers(self): - return self.backbone.layers - - # def sanitize(self, weights): - # exclude_patterns = [ - # 'backbone.layers.mixer.A_log', - # 'backbone.layers.mixer.conv1d.weight', - # 'backbone.layers.mixer.dt_proj.weight', - # 'backbone.layers.mixer.in_proj.weight', - # 'backbone.layers.mixer.dt_proj.bias', - # 'backbone.layers.mixer.conv1d.bias', - # 'backbone.layers.mixer.D' - # ] - # return { - # k: v for k, v in weights.items() - # if not any(pattern in k for pattern in exclude_patterns) - # } \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-torch.py 
b/llms/mlx_lm/models/mamba-torch.py deleted file mode 100644 index 84deb4d3f..000000000 --- a/llms/mlx_lm/models/mamba-torch.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch.nn as nn -import torch -from configuration_mamba import MambaConfig -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_utils import PreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -import math -import json -import torch -import torch.nn as nn -import torch.nn.functional as F -from dataclasses import dataclass -from einops import rearrange, repeat, einsum -from typing import Optional , Union ,Tuple -l -# Dear contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, special thanks to Albert Gu and Tri Dao for their articles. (https://arxiv.org/abs/2312.00752) - - -class MambaRMSNorm(nn.Module): - def __init__(self, - d_model: int, - eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d_model)) - def forward(self, x): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight - return output - - -class MambaBlock(nn.Module): - def __init__(self, config: MambaConfig): - """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" - super().__init__() - self.config = config - - self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) - - self.conv1d = nn.Conv1d( - in_channels=config.d_inner, - out_channels=config.d_inner, - bias=config.conv_bias, - kernel_size=config.d_conv, - groups=config.d_inner, - padding=config.d_conv - 1, - ) - - # x_proj takes in `x` and outputs the input-specific Δ, B, C - self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False) - - # dt_proj projects Δ from dt_rank to d_in - self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) - - A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner) - self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(config.d_inner)) - self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) - # self.norm = MambaRMSNorm(config.d_model) - - def forward(self, x): - (b, l, d) = x.shape - x_copy = x # There was a separate class for residual, I deleted that part and added it here. - x, res = self.in_proj(self.norm(x)).split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) - - x = rearrange(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :l] - x = rearrange(x, 'b d_in l -> b l d_in') - - x = F.silu(x) - - y = self.ssm(x) - - y = y * F.silu(res) - - output = self.out_proj(y) + x_copy - - return output - - - def ssm(self, x): - (d_in, n) = self.A_log.shape - - A = -torch.exp(self.A_log.float()) # shape (d_in, n) - D = self.D.float() - - x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) - - (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). 
B, C: (b, l, n) - delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) - - y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] - - return y - - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - x = torch.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = torch.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - -class MambaModel(MambaPreTrainedModel): - def __init__(self, config: MambaConfig): - super().__init__(config) - self.config = config - - self.embedding = nn.Embedding(config.vocab_size, config.d_model) - self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)]) - self.norm_f = MambaRMSNorm(config.d_model) - - def forward(self, input_ids: torch.LongTensor = None): - x = self.embedding(input_ids) - all_hidden_states = list() - for layer in self.layers: - x = layer(x) - all_hidden_states.append(x) - return self.norm_f(x) - - -class MambaForCausalLM(MambaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = MambaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - self.lm_head.weight = self.model.embedding.weight - - - def forward(self, input_ids: torch.LongTensor = None): - hidden_states = self.model(input_ids=input_ids) - return self.lm_head(hidden_states) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 340cce831..85a6fdb00 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import math @@ -56,30 +55,6 @@ def __post_init__(self): if self.dt_rank == "auto": self.dt_rank = math.ceil(self.hidden_size / 16) -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = int(channels) - self.kernel_size = int(kernel_size) - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) - scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((self.channels,)) - else: - self.bias = None - - def __call__(self, x): - B, D, L = x.shape - assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." 
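# The depthwise convolution being reworked in this hunk reduces, for a single
# decoding step, to a per-channel dot product over the last `kernel_size` inputs.
# Minimal sketch with illustrative names and assumed shapes (not the patch's API):
# `window` is (B, K, C) holding the cached K-1 inputs plus the current token's x,
# `weight` is (K, C) with one filter per channel.
import mlx.core as mx

def depthwise_conv_step(window: mx.array, weight: mx.array, bias=None) -> mx.array:
    # broadcast the per-channel kernel over the batch, then reduce over the K taps
    y = mx.sum(window * weight, axis=1)  # (B, C)
    return y + bias if bias is not None else y

# e.g. window = mx.concatenate([cached_inputs, x[:, None, :]], axis=1) during
# incremental decoding, which mirrors the (h, inputs) cache used in this file.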
- print("FORWARD PASS THROUGH CONV") - print(self.kernel_size) - print(self.weight) - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - def clamp(x, min=None, max=None): if min is not None: @@ -93,6 +68,46 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, x) +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) + scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight *= scale + if bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x): + # x shape is (B, C, L) + B, C, L = x.shape + + # Pad the input + if self.padding > 0: + padding = [(0, 0), (0, 0), (self.padding, self.padding)] + x_padded = mx.pad(x, padding) + else: + x_padded = x + + # Perform depthwise convolution manually + out = [] + for i in range(L): + slice = x_padded[:, :, i:i+self.kernel_size] + out.append(mx.sum(slice * self.weight, axis=2)) + + out = mx.stack(out, axis=2) + + # Apply bias if present + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1) + + return out + + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -108,10 +123,10 @@ def __init__(self, args: ModelArgs): self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, + channels=int(self.intermediate_size), + kernel_size=int(self.conv_kernel_size), bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 + padding=int(self.conv_kernel_size - 1) ) self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) @@ -156,30 +171,36 @@ def ssm(self, x, h): h = deltaA * h + BX # (B, ED, N) - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED) y = y + D * x return y, h def __call__(self, x, cache): h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - # x branch - x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - # Ensure inputs has the correct shape - if inputs.ndim == 2: - inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing - - print(f"inputs shape: {inputs.shape}") - print(f"x_cache shape: {x_cache.shape}") - - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) <---------- Here is the problem ValueError: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4. 
||| inputs shape: (1, 3, 1536) x_cache shape: (1, 1, 5, 1536) - x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) # * z branch - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - return self.out_proj(output), (h, inputs) # (B, D), cache + + x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) + + # x is now (B, L, C), we need (B, C, L) for conv1d + x_cache = x.transpose(0, 2, 1) + + if inputs is None: + inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1)) + else: + inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence) + + conv_input = mx.concatenate([inputs, x_cache], axis=2) + + x = self.conv1d(conv_input) + x = x[:, :, -1] # Take the last element of the sequence + + y, h = self.ssm(x, h) + output = y * nn.silu(z[:, -1, :]) + + # Update inputs for the next iteration + inputs = conv_input[:, :, 1:] + + return self.out_proj(output), (h, inputs.transpose(0, 2, 1)) class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -188,8 +209,9 @@ def __init__(self, args: ModelArgs): self.norm = nn.RMSNorm(args.hidden_size) def __call__(self, inputs: mx.array, cache): + residual = inputs output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs + output = output + residual[:, -1, :] # Add residual only for the last time step return output, cache @@ -234,61 +256,4 @@ def n_kv_heads(self): def make_cache(self): # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - # Generate the next token logits - next_token_logits, caches = self(input_ids, caches) - - # Apply top_k filtering if specified - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest - mask = next_token_logits < (values[:, -1, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now - - # Sample the next token or take the argmax based on the temperature - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) - 
else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - # Concatenate the next token to the input_ids - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids, caches \ No newline at end of file + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] \ No newline at end of file From c02d462a940dab4c2abd43e5341a5504e41a989a Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 19 Aug 2024 17:59:25 +0200 Subject: [PATCH 22/40] update --- llms/mlx_lm/models/mamba.py | 243 ++++++++++++++++++----------- llms/mlx_lm/models/mamba1.py | 293 +++++++++++++++++++++++++++++++++++ llms/mlx_lm/models/mamba2.py | 258 ++++++++++++++++++++++++++++++ 3 files changed, 704 insertions(+), 90 deletions(-) create mode 100644 llms/mlx_lm/models/mamba1.py create mode 100644 llms/mlx_lm/models/mamba2.py diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 85a6fdb00..172aab685 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -40,6 +40,12 @@ def __post_init__(self): self.intermediate_size = self.d_inner if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): self.state_size = self.d_state + if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): + self.time_step_min = self.dt_min + if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): + self.time_step_min = self.dt_max + if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): + self.time_step_min = self.dt_init_floor if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): self.num_hidden_layers = self.n_layer if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): @@ -66,46 +72,54 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) return mx.where(mask_lower, min, x) return mx.where(mask_upper, max, x) + - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): +class Conv1d(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + padding: int = 0 + ): super().__init__() self.channels = channels self.kernel_size = kernel_size + self.use_bias = bias self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) - self.weight *= scale - if bias: + + # Change the weight initialization to match the expected shape + self.weight = mx.zeros((kernel_size, 1, channels)) + if self.use_bias: self.bias = mx.zeros((channels,)) else: self.bias = None - def __call__(self, x): - # x shape is (B, C, L) - B, C, L = x.shape - - # Pad the input - if self.padding > 0: - padding = [(0, 0), (0, 0), (self.padding, self.padding)] - x_padded = mx.pad(x, padding) + def __call__(self, x, cache=None): + # Use the weight directly without transposing + w = self.weight + if cache is not None: + l = [] + # Pad the cache if needed + if cache.shape[1] < self.kernel_size - 1: + l.append( + mx.zeros( + (x.shape[0], self.kernel_size - 1 - cache.shape[1], self.channels), dtype=x.dtype + ) + ) + l.extend([cache, x]) + x = mx.concatenate(l, axis=1) + y = mx.conv_general(x, w, padding=([0], [0]), groups=self.channels) else: - x_padded = x - - # Perform depthwise convolution manually - out = [] - for i in range(L): - slice = x_padded[:, :, i:i+self.kernel_size] - out.append(mx.sum(slice * 
self.weight, axis=2)) - - out = mx.stack(out, axis=2) - - # Apply bias if present - if self.bias is not None: - out = out + self.bias.reshape(1, -1, 1) + y = mx.conv_general(x, w, padding=([self.padding], [0]), groups=self.channels) + + # The cache is always kernel_size - 1 + cache = x[:, max(x.shape[1] - self.kernel_size + 1, 0) :, :] - return out + if self.use_bias: + y = y + self.bias + + return y, cache class MambaBlock(nn.Module): @@ -113,94 +127,115 @@ def __init__(self, args: ModelArgs): super().__init__() self.args = args - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + # projects block input from D to 2*ED (two branches) + self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) - self.conv1d = DepthWiseConv1d( - channels=int(self.intermediate_size), - kernel_size=int(self.conv_kernel_size), - bias=self.use_conv_bias, - padding=int(self.conv_kernel_size - 1) + # short 1d conv over time + self.conv1d = Conv1d( + channels=args.intermediate_size, + kernel_size=args.conv_kernel, + bias=args.use_conv_bias, + padding=args.conv_kernel-1 ) - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + # projects x to input-dependent Δ, B, C + self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) + # projects Δ from dt_rank to intermediate_size + self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True) + + # dt initialization + # dt weights dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) elif args.time_step_init_scheme == "random": self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) else: raise NotImplementedError - + + # dt bias dt = clamp(mx.exp( mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) ), min=args.time_step_floor) inv_dt = dt + mx.log1p(-mx.exp(-dt)) self.dt_proj.bias = inv_dt - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) + # S4D real initialization + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) + self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? 
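# On the question asked in the comment above: parameterizing A_log and recovering
# A = -exp(A_log) guarantees A < 0 for any real-valued parameter, so the discretized
# recurrence
#     h_t = exp(delta * A) * h_{t-1} + delta * B * x_t
# always decays instead of blowing up, wherever the optimizer moves A_log.
# Tiny illustrative check with toy sizes (ED=4, N=16):
import mlx.core as mx
A = mx.repeat(mx.arange(1.0, 17.0).reshape(1, 16), repeats=4, axis=0)  # S4D-real init
A_log = mx.log(A)
print(mx.all(-mx.exp(A_log) < 0))  # prints a True scalar: negativity holds by construction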
+ self.D = mx.ones([args.intermediate_size]) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + # projects block output from ED back to D + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) + # x : (B, ED) + # h : (B, ED, N) + + # y : (B, ED) + # h : (B, ED, N) + + A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent) D = self.D - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) - h = deltaA * h + BX # (B, ED, N) + h = deltaA * h + BX # (B, ED, N) - y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED) + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) y = y + D * x + return y, h def __call__(self, x, cache): - h, inputs = cache - - x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) + # x : (B, D) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs : (B, conv_kernel-1, ED) - # x is now (B, L, C), we need (B, C, L) for conv1d - x_cache = x.transpose(0, 2, 1) - - if inputs is None: - inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1)) - else: - inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence) + # y : (B, D) + # cache : (h, inputs) - conv_input = mx.concatenate([inputs, x_cache], axis=2) - - x = self.conv1d(conv_input) - x = x[:, :, -1] # Take the last element of the sequence - - y, h = self.ssm(x, h) - output = y * nn.silu(z[:, -1, :]) + h, inputs = cache - # Update inputs for the next iteration - inputs = conv_input[:, :, 1:] + print("Input shape:", x.shape) + xz = self.in_proj(x) # (B, 2*ED) + xz = xz.reshape(x.shape[0], -1) # Ensure shape is (B, 2*ED) + print("After in_proj shape:", xz.shape) + x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + + # x branch + x_cache = mx.expand_dims(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) + + x = nn.silu(x) + y, h = self.ssm_step(x, h) + + # z branch + z = nn.silu(z) + + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) + cache = (h, inputs) - return self.out_proj(output), (h, inputs.transpose(0, 2, 1)) + return output, cache class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -209,12 +244,18 @@ def __init__(self, args: ModelArgs): self.norm = 
nn.RMSNorm(args.hidden_size) def __call__(self, inputs: mx.array, cache): - residual = inputs + # x : (B, D) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs: (B, conv_kernel-1, ED) + + # output : (B, D) + # cache : (h, inputs) + output, cache = self.mixer(self.norm(inputs), cache) - output = output + residual[:, -1, :] # Add residual only for the last time step + output = output + inputs return output, cache - class Mamba(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -222,12 +263,23 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) + def __call__(self, tokens: mx.array, caches): + # tokens : (B, L) + + # logits : (B, L, vocab_size) + + x = self.embeddings(tokens) + + # x : (B, L, D) + # caches : [cache(layer) for all layers], cache : (h, inputs) + + # y : (B, L, D) + # caches : [cache(layer) for all layers], cache : (h, inputs) + for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache + x, caches[i] = layer(x, caches[i]) + + return x, caches class Model(nn.Module): @@ -239,7 +291,7 @@ def __init__(self, args: ModelArgs): def __call__(self, inputs: mx.array, cache=None): out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) + # out = self.backbone.embeddings.as_linear(out) return out, cache @property @@ -255,5 +307,16 @@ def n_kv_heads(self): return self.args.num_hidden_layers def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] \ No newline at end of file + return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + + def sanitize(self, weights): + for key, value in weights.items(): + if "mixer.conv1d.weight" in key: + # Ensure the weight is in the shape (kernel_size, 1, channels) + if value.shape != (self.args.conv_kernel, 1, self.args.intermediate_size): + weights[key] = value.reshape(self.args.conv_kernel, 1, self.args.intermediate_size) + elif key == "backbone.embeddings.weight": + # Ensure the embedding weight is in the shape (vocab_size, hidden_size) + if value.shape != (self.args.vocab_size, self.args.hidden_size): + weights[key] = value.T + return weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba1.py b/llms/mlx_lm/models/mamba1.py new file mode 100644 index 000000000..0b64f967e --- /dev/null +++ b/llms/mlx_lm/models/mamba1.py @@ -0,0 +1,293 @@ +from dataclasses import dataclass + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False + dt_rank: str = 
"auto" + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): + self.time_step_min = self.dt_min + if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): + self.time_step_min = self.dt_max + if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): + self.time_step_min = self.dt_init_floor + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.bias = bias + self.padding = padding + + self.conv1d = nn.Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=kernel_size, + bias=True, + padding=padding + ) + indices = mx.arange(channels) + mask = mx.zeros_like(self.conv1d.weight) + mask[indices, :, indices] = 1 + self.conv1d.weight *= mask + + def __call__(self, x): + return self.conv1d(x) + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + # projects block input from D to 2*ED (two branches) + self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) + + # short 1d conv over time + self.conv1d = DepthWiseConv1d( + channels=args.intermediate_size, + kernel_size=args.conv_kernel, + bias=args.use_conv_bias, + padding=args.conv_kernel-1 + ) + + # projects x to input-dependent Δ, B, C + self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) + + # projects Δ from dt_rank to intermediate_size + self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True) + + # dt initialization + # dt weights + dt_init_std = args.dt_rank**-0.5 * args.state_size + + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + # dt bias + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt 
+ mx.log1p(-mx.exp(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + self.dt_proj.bias = inv_dt + + # S4D real initialization + A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) + self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? + self.D = mx.ones([args.intermediate_size]) + + # projects block output from ED back to D + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + + def ssm(self, x, h): + # x : (B, ED) + # h : (B, ED, N) + + # y : (B, ED) + # h : (B, ED, N) + + A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent) + D = self.D + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if h is None: + h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) + + y = y + D * x + + return y, h + + def __call__(self, x, cache): + # x : (B, D) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs : (B, conv_kernel-1, ED) + + # y : (B, D) + # cache : (h, inputs) + + h, inputs = cache + + xz = self.in_proj(x) # (B, 2*ED) + x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + + # x branch + x_cache = mx.expand_dims(x, 1) + x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) + + x = nn.silu(x) + y, h = self.ssm_step(x, h) + + # z branch + z = nn.silu(z) + + output = y * z + output = self.out_proj(output) # (B, D) + + # prepare cache for next call + inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) + cache = (h, inputs) + + return output, cache + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + # x : (B, D) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs: (B, conv_kernel-1, ED) + + # output : (B, D) + # cache : (h, inputs) + + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs + return output, cache + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, tokens: mx.array, caches): + # tokens : (B, L) + + # logits : (B, L, vocab_size) + + x = self.embeddings(tokens) + + # x : (B, L, D) + # caches : [cache(layer) for all layers], cache : (h, inputs) + + # y : (B, L, D) + # caches : [cache(layer) for all layers], cache : (h, inputs) + + for i, layer in enumerate(self.layers): + x, caches[i] = layer(x, caches[i]) + + return x, caches + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + def __call__(self, 
inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + # out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] + + def sanitize(self, weights): + new_weights = {} + for key, value in weights.items(): + if "mixer.conv1d.weight" in key: + weights[key] = value.T + new_key = key.replace('mixer.conv1d', 'mixer.conv1d.conv1d') + new_weights[new_key] = value + return new_weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py new file mode 100644 index 000000000..04f67d050 --- /dev/null +++ b/llms/mlx_lm/models/mamba2.py @@ -0,0 +1,258 @@ +from dataclasses import dataclass + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False + dt_rank: str = "auto" + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + self.intermediate_size = self.expand * self.hidden_size + if self.dt_rank == "auto": + self.dt_rank = math.ceil(self.hidden_size / 16) + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias, padding): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) + scale = math.sqrt(1.0 / (channels * kernel_size)) + self.weight *= scale + if 
bias: + self.bias = mx.zeros((channels,)) + else: + self.bias = None + + def __call__(self, x): + # x shape is (B, C, L) + B, C, L = x.shape + + # Pad the input + if self.padding > 0: + padding = [(0, 0), (0, 0), (self.padding, self.padding)] + x_padded = mx.pad(x, padding) + else: + x_padded = x + + # Perform depthwise convolution manually + out = [] + for i in range(L): + slice = x_padded[:, :, i:i+self.kernel_size] + out.append(mx.sum(slice * self.weight, axis=2)) + + out = mx.stack(out, axis=2) + + # Apply bias if present + if self.bias is not None: + out = out + self.bias.reshape(1, -1, 1) + + return out + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=int(self.intermediate_size), + kernel_size=int(self.conv_kernel_size), + bias=self.use_conv_bias, + padding=int(self.conv_kernel_size - 1) + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.dt_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm(self, x, h): + A = -mx.exp(self.A_log) # (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if h is None: + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + + h = deltaA * h + BX # (B, ED, N) + + y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED) + + y = y + D * x + return y, h + + def __call__(self, x, cache): + h, inputs = cache + + x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) + + # x is now (B, L, C), we need (B, C, L) for conv1d + x_cache = x.transpose(0, 2, 1) + + if inputs is None: + inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1)) + else: + inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence) + + 
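+        # `inputs` caches previously seen projected steps in (B, C, L) layout (conv_kernel - 1
+        # of them when decoding one token at a time); `x_cache` is the current chunk in the
+        # same layout. The two are concatenated along the length axis before the depthwise
+        # conv, only the last conv output position feeds the SSM step, and the cache for the
+        # next call drops the oldest position of that concatenation.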
conv_input = mx.concatenate([inputs, x_cache], axis=2) + + x = self.conv1d(conv_input) + x = x[:, :, -1] # Take the last element of the sequence + + y, h = self.ssm(x, h) + output = y * nn.silu(z[:, -1, :]) + + # Update inputs for the next iteration + inputs = conv_input[:, :, 1:] + + return self.out_proj(output), (h, inputs.transpose(0, 2, 1)) + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + output, cache = self.mixer(self.norm(inputs), cache) + output = output + inputs[:, -1, :] # Add residual only for the last time step + return output, cache + + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, inputs: mx.array, cache): + tokens = self.embeddings(inputs) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(tokens, cache[i]) + h = self.norm_f(h) + return h, cache + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + def __call__(self, inputs: mx.array, cache=None): + out, cache = self.backbone(inputs, cache) + out = self.backbone.embeddings.as_linear(out) + return out, cache + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers + + def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] + return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] \ No newline at end of file From f3733cf5f89441bd04be3235837a309285c6808d Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Sun, 1 Sep 2024 13:39:03 +0200 Subject: [PATCH 23/40] update --- llms/mlx_lm/models/mamba-old.py | 399 ------------- llms/mlx_lm/models/mamba-save.py | 832 --------------------------- llms/mlx_lm/models/mamba-tiny-pld.py | 154 ----- llms/mlx_lm/models/mamba-torch.py | 145 ----- llms/mlx_lm/models/mamba.py | 294 ---------- 5 files changed, 1824 deletions(-) delete mode 100644 llms/mlx_lm/models/mamba-old.py delete mode 100644 llms/mlx_lm/models/mamba-save.py delete mode 100644 llms/mlx_lm/models/mamba-tiny-pld.py delete mode 100644 llms/mlx_lm/models/mamba-torch.py delete mode 100644 llms/mlx_lm/models/mamba.py diff --git a/llms/mlx_lm/models/mamba-old.py b/llms/mlx_lm/models/mamba-old.py deleted file mode 100644 index 844d4fb7d..000000000 --- a/llms/mlx_lm/models/mamba-old.py +++ /dev/null @@ -1,399 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import torch - -# import tokenizer - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str = "mamba" - dt_rank: Union[int, str] # time_step_rank - d_model: int - d_inner: int - vocab_size: int - n_layer: int - use_bias: bool - use_conv_bias: bool - rms_norm: bool - conv_kernel: int - state_size: int - 
expand: int - time_step_init_scheme: str - time_step_max: float - time_step_min: float - time_step_floor: float - pscan: bool = False - tie_word_embeddings: bool = False - num_hidden_layers: int = None - hidden_size: int = None - # time_step_scale - - def __post_init__(self): - self.d_inner = self.expand * self.d_model - if self.n_layer is None: - self.n_layer = self.num_hidden_layers - if self.d_model is None: - self.d_model = self.hidden_size - if self.dt_rank == 'auto': - self.dt_rank = math.ceil(self.d_model / 16) - - -def pscan_main(A, X): - Aa = A - Xa = X - B, D, L, _ = A.shape - num_steps = int(math.log2(L)) - - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_main(A, X) - return X.transpose(0, 2, 1, 3) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - self.conv1d = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, bias=True, padding=padding) - indices = mx.arange(channels) - mask = mx.zeros_like(self.conv1d.weight) - mask[indices, :, indices] = 1 - self.conv1d.weight *= mask - - def __call__(self, x): - return self.conv1d(x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) - # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) - self.conv1d = nn.Conv1d( - in_channels=args.d_inner, - out_channels=args.d_inner, - bias=args.conv_bias, - kernel_size=args.d_conv, - groups=args.d_inner, - padding=args.d_conv - 1, - ) - 
self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) - self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp(mx.random.uniform(shape=[args.d_inner]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.d_inner, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([args.d_inner]) - - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) - - self.norm = nn.RMSNorm(args.d_model) - - def ssm_step(self, x, h): - A = -mx.exp(self.A_log) - D = self.D - deltaBC = self.x_proj(self.norm(x)) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def ssm(self, x): # DONE - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - if self.args.pscan: - y = self.selective_scan(x, delta, A, B, C, D) - else: - y = self.selective_scan_seq(x, delta, A, B, C, D) - return y - - def ssm_new(self, x): - d_in, N = self.A_log.shape - A = -mx.exp(self.A_log.float()) - D = self.D.float() - delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - return self.selective_scan_new(x, delta, A, B, C, D) - - def selective_scan(self, x, delta, A, B, C, D): # DONE - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - hs = pscan(deltaA, BX) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - return y + D * x - - def selective_scan_new(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - - # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = mx.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = mx.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - def selective_scan_seq(self, x, delta, A, B, C, D): - _, L, _ = x.shape - deltaA = mx.exp(unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 2) - BX = deltaB * unsqueeze(x, -1) - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.state_size]) - hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) - hs = mx.stack(hs, axis=1) - y = (hs @ unsqueeze(C, -1)).squeeze(3) - 
return y + D * x - - def step(self, x, cache): # Done - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] - y, h = self.ssm_step(nn.silu(x), h) - output = y * nn.silu(z) - output = self.out_proj(output) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - return output, (h, inputs) - - def ssm_step(self, x, h): # Done - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.args.d_inner, self.args.d_state]) - h = deltaA * h + BX - y = (h @ unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x): # DONE - _, L, _ = x.shape - x, z = self.in_proj(x).split(indices_or_sections=2, axis=2) - x = self.conv1d(x)[:, :L, :] - output = self.ssm(nn.silu(x)) * nn.silu(z) - return self.out_proj(output) - - def new(self, x): - _, L, _ = x.shape - x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) - - x = mx.reshape(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :L] - x = mx.rearrange(x, 'b d_in l -> b l d_in') - out = self.ssm_new(nn.silu(x)) * nn.silu(r) - return self.out_proj(out) + x - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array, cache: Optional[mx.array] = None): - output, cache = self.mixer.step(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embedding = nn.Embedding(args.vocab_size, args.d_model) - self.layers = [ResidualBlock(args) for _ in range(args.n_layer)] - self.norm_f = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array, cache=None): - tokens = self.embedding(inputs) - if cache is None: - cache = [None] * len(self.layers) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - return out, cache - - # def torch_to_mlx_depthwise_weights(self, torch_weights): - # torch_weights = torch_weights.transpose(2, 1) - # channels, kernel_size, _ = torch_weights.shape - - # mlx_weights = torch.zeros(channels, kernel_size, channels) - - # indices = torch.arange(channels) - # if torch_weights[:, :, 0].type() == 'torch.BFloat16Tensor': - # mlx_weights[indices, :, indices] = torch_weights[:, :, 0].float() - # else: - # mlx_weights[indices, :, indices] = torch_weights[:, :, 0] - - # return mlx_weights - - def sanitize(self, torch_state_dict): - new_state_dict = {} - for key, value in torch_state_dict.items(): - if 'conv1d.weight' in key: - value = 
self.torch_to_mlx_depthwise_weights(value) - - if 'conv1d' in key: - key = key.replace('conv1d', 'conv1d.conv1d') - - if value.type() == 'torch.BFloat16Tensor': - new_state_dict[key] = value.half().numpy() - else: - new_state_dict[key] = value.numpy() - - return new_state_dict - - @property - def layers(self): - return self.model.layers - - def generate(self, tokenizer=None, prompt: str="Hello", n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - input_ids = mx.array(tokenizer(prompt, return_tensors='np').input_ids) - - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.d_inner])) for _ in range(self.args.n_layer)] - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - output = [tokenizer.decode(output.tolist()) for output in input_ids][0] - - self.train() - - return output - -# model = Model(ModelArgs()) -# print(model) - -# logits = model.generate() -# print(logits) diff --git a/llms/mlx_lm/models/mamba-save.py b/llms/mlx_lm/models/mamba-save.py deleted file mode 100644 index 9858158f3..000000000 --- a/llms/mlx_lm/models/mamba-save.py +++ /dev/null @@ -1,832 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, MambaCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int # d_model - intermediate_size: int # d_inner - state_size: int # d_state - num_hidden_layers: int # n_layer - layer_norm_epsilon: float - expand: int - conv_kernel: int # d_conv - use_bias: bool # bias - use_conv_bias: bool # conv_bias - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False # pscan - dt_rank: str = "auto" - - def __post_init__(self): - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def unsqueeze(x, axis): - assert axis <= len(x.shape) - if axis >= 0: - new_shape = x.shape[:axis] + tuple([1]) + x.shape[axis:] - else: - new_shape = x.shape + tuple([1]) - return x.reshape(new_shape) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, 
kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None - - def __call__(self, x): - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp(mx.random.uniform(shape=[self.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min)), min=args.time_step_floor) - self.dt_proj.bias = dt + mx.log1p(-mx.exp(-dt)) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1, self.ssm_state_size + 1).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - A = -mx.exp(self.A_log) - D = self.D - delta, B, C = self.x_proj(x).split(split_size=[self.intermediate_size, self.intermediate_size], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(mx.unsqueeze(delta, -1) * A) - deltaB = unsqueeze(delta, -1) * unsqueeze(B, 1) - BX = deltaB * unsqueeze(x, -1) - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) - h = deltaA * h + BX - y = (h @ mx.unsqueeze(C, -1)).squeeze(2) - y = y + D * x - return y, h - - def __call__(self, x, cache: Optional[MambaCache]): - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=1) - x_cache = unsqueeze(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :] # (B, ED) - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) - cache.update(h, inputs) - return self.out_proj(output), cache - - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings 
= nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache=None): - tokens = self.embeddings(inputs) - if cache is None: - cache = [None] * len(self.layers) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - # self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - - caches = [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to biggest - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # TODO -mx.inf is problematic for now - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - - - - - - - -# from dataclasses import dataclass -# from typing import Optional, Union - -# import math -# import einsum - -# import mlx.core as mx -# import mlx.nn as nn - -# from .base import BaseModelArgs, MambaCache - - -# @dataclass -# class ModelArgs(BaseModelArgs): -# model_type: str -# vocab_size: int -# hidden_size: int # d_model -# intermediate_size: int # d_inner -# state_size: int # d_state -# num_hidden_layers: int # n_layer -# layer_norm_epsilon: float -# expand: int -# conv_kernel: int # d_conv -# use_bias: bool # bias -# use_conv_bias: bool # conv_bias -# initializer_range: float -# time_step_rank: int -# time_step_scale: float -# time_step_min: float -# time_step_max: float -# time_step_init_scheme: str -# time_step_floor: float -# rescale_prenorm_residual: bool -# use_cache: bool -# use_mambapy: bool = False # pscan -# dt_rank: str = "auto" - -# def __post_init__(self): -# self.intermediate_size = self.expand * self.hidden_size -# if self.dt_rank == "auto": -# self.dt_rank = math.ceil(self.hidden_size / 16) - - -# def clamp(x, min=None, max=None): -# if min is not None: -# mask_lower = x < min -# if max is not None: -# mask_upper = x > max -# if min is not None: -# if max is not None: -# return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) -# return mx.where(mask_lower, min, x) -# return mx.where(mask_upper, max, x) - -# class 
MambaBlock(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.args = args - -# self.hidden_size = args.hidden_size -# self.ssm_state_size = args.state_size -# self.conv_kernel_size = args.conv_kernel -# self.intermediate_size = args.intermediate_size -# self.time_step_rank = int(args.time_step_rank) -# self.use_conv_bias = args.use_conv_bias - -# self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - -# self.conv1d = nn.Conv1d( -# in_channels=self.intermediate_size, -# out_channels=self.intermediate_size, -# kernel_size=self.conv_kernel_size, -# bias=self.use_conv_bias, -# padding=self.conv_kernel_size-1 -# ) - -# self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) -# self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - -# A = mx.repeat(mx.arange(1., self.ssm_state_size + 1), "n -> d n", repeats=self.intermediate_size) -# self.A_log = mx.log(A) -# self.D = mx.ones([self.intermediate_size]) - -# self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - -# def ssm(self, x): -# (d_in, n) = self.A_log.shape - -# A = -mx.exp(self.A_log.float()) # shape (d_in, n) -# D = self.D.float() - -# x_dbl = self.x_proj(x) # (b, l, time_step_rank + 2*n) - -# (delta, B, C) = x_dbl.split(indices_or_sections=[self.time_step_rank, n, n], axis=-1) # delta: (b, l, time_step_rank). B, C: (b, l, n) -# delta = nn.softplus(self.dt_proj(delta)) # (b, l, d_in) - -# y = self.selective_scan(x, delta, A, B, C, D) - -# return y - -# def selective_scan(self, u, delta, A, B, C, D): -# (b, l, d_in) = u.shape -# n = A.shape[1] -# deltaA = mx.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) -# deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') -# x = mx.zeros((b, d_in, n), device=deltaA.device) -# ys = [] -# for i in range(l): -# x = deltaA[:, :, i] * x + deltaB_u[:, :, i] -# y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') -# ys.append(y) -# y = mx.stack(ys, dim=1) # shape (b, l, d_in) - -# y = y + u * D -# return y - -# def __call__(self, x): -# (b, l, d) = x.shape -# x_copy = x -# x, res = self.in_proj(self.norm(x)).split(indices_or_sections=[self.intermediate_size, self.intermediate_size], axis=-1) - -# x = mx.rearrange(x, 'b l d_in -> b d_in l') -# x = self.conv1d(x)[:, :, :l] -# x = mx.rearrange(x, 'b d_in l -> b l d_in') - -# x = nn.silu(x) - -# y = self.ssm(x) - -# y = y * nn.silu(res) -# return self.out_proj(y) + x_copy - - -# class ResidualBlock(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.mixer = MambaBlock(args) -# self.norm = nn.RMSNorm(args.hidden_size) - -# def __call__(self, inputs: mx.array): -# output = self.mixer(self.norm(inputs)) -# output = output + inputs -# return output - - -# class Mamba(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) -# self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] -# self.norm_f = nn.RMSNorm(args.hidden_size) - -# def __call__(self, inputs: mx.array): -# tokens = self.embeddings(inputs) -# for i, layer in enumerate(self.layers): -# h, = layer(tokens) -# return self.norm_f(h) - - -# class Model(nn.Module): -# def __init__(self, args: ModelArgs): -# super().__init__() -# self.args = args -# self.model_type = args.model_type -# self.backbone = Mamba(args) - -# def __call__(self, inputs: mx.array, cache=None): -# 
out = self.backbone(inputs) -# out = self.backbone.embeddings.as_linear(out) -# return out, cache - -# @property -# def layers(self): -# return self.backbone.layers - -# @property -# def head_dim(self): -# return self.args.hidden_size // self.args.num_hidden_layers - -# @property -# def n_kv_heads(self): -# return self.args.num_hidden_layers - - - -from dataclasses import dataclass -from typing import Optional - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int # d_model - intermediate_size: int # d_inner - state_size: int # d_state - num_hidden_layers: int # n_layer, n_layer - layer_norm_epsilon: float - expand: int - conv_kernel: int # d_conv - use_bias: bool # bias - use_conv_bias: bool # conv_bias - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False # pscan - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def pscan_f(A, X): - Aa = A - Xa = X - - B, D, L, _ = A.shape - - num_steps = int(math.log2(L)) - - # up sweep - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - # down sweep - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] 
= Xa[:, :, :, 0] - else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) - -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = int(channels) - self.kernel_size = int(kernel_size) - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) - scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((self.channels,)) - else: - self.bias = None - - def __call__(self, x): - B, D, L = x.shape - assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." - print("FORWARD PASS THROUGH CONV") - print(self.kernel_size) - print(self.weight) - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - # def ssm_old(self, x): - # A = -mx.exp(self.A_log) # (ED, N) - # D = self.D - - # deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) - - # delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - # delta = mx.softplus(self.dt_proj(delta)) # (B, L, ED) - - # if self.args.use_mambapy: - # y = self.selective_scan(x, delta, A, B, C, D) - # else: - # y = self.selective_scan_seq(x, delta, A, B, C, D) - # return y - - 
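For orientation, the pscan_f/pscan helpers above are intended to parallelize a first-order linear recurrence along the sequence axis. A minimal sequential sketch of the same computation (the helper name sequential_scan is illustrative; A and X follow the (B, L, ED, N) shapes that selective_scan passes in as deltaA and BX):

import mlx.core as mx

def sequential_scan(A, X):
    # H[:, t] = A[:, t] * H[:, t-1] + X[:, t] for t = 0 .. L-1,
    # the recurrence that the Blelloch-style up/down sweep in pscan_f is meant to compute in parallel.
    h = mx.zeros_like(X[:, 0])
    out = []
    for t in range(X.shape[1]):
        h = A[:, t] * h + X[:, t]
        out.append(h)
    return mx.stack(out, axis=1)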
# def selective_scan(self, x, delta, A, B, C, D): - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # hs = pscan(deltaA, BX) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - - # def selective_scan_seq(self, x, delta, A, B, C, D): - # _, L, _ = x.shape - # deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - # deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - # BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - # h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - # hs = [] - # for t in range(0, L): - # h = deltaA[:, t] * h + BX[:, t] - # hs.append(h) - # hs = mx.stack(hs, axis=1) - # y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - # y = y + D * x - # return y - - def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - return y, h - - def __call__(self, x, cache): - h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = mx.split(xz, indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - - # x branch - x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - - # Ensure inputs has the correct shape - if inputs.ndim == 2: - inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing - - print(f"inputs shape: {inputs.shape}") - print(f"x_cache shape: {x_cache.shape}") - - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) - x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - cache = (h, inputs) - - return output, cache - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: 
ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - # return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - # Generate the next token logits - next_token_logits, caches = self(input_ids, caches) - - # Apply top_k filtering if specified - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest - mask = next_token_logits < (values[:, -1, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now - - # Sample the next token or take the argmax based on the temperature - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - # Concatenate the next token to the input_ids - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids, caches \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-tiny-pld.py b/llms/mlx_lm/models/mamba-tiny-pld.py deleted file mode 100644 index 8713978d5..000000000 --- a/llms/mlx_lm/models/mamba-tiny-pld.py +++ /dev/null @@ -1,154 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -import math - -import torch - -# import tokenizer - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - n_layer: int - use_conv_bias: bool - expand: int - pad_vocab_size_multiple: int - conv_kernel: 
int - d_model: int - state_size: int - d_inner: int - initializer_range: float - use_bias: bool - time_step_init_scheme: str - time_step_max: float - time_step_min: float - time_step_floor: float - dt_rank: Union[int, str] = "auto" - - def __post_init__(self): - self.d_inner = self.expand * self.d_model - if self.n_layer is None: - self.n_layer = self.num_hidden_layers - if self.d_model is None: - self.d_model = self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.d_model / 16) - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.in_proj = nn.Linear(args.d_model, 2 * args.d_inner, bias=args.use_bias) - # self.conv1d = DepthWiseConv1d(channels=args.d_inner, kernel_size=args.conv_kernel, bias=args.use_conv_bias, padding=args.conv_kernel-1) - self.conv1d = nn.Conv1d( - in_channels=args.d_inner, - out_channels=args.d_inner, - bias=args.use_conv_bias, - kernel_size=args.conv_kernel, - # groups=args.d_inner, - padding=args.conv_kernel - 1, - ) - self.x_proj = nn.Linear(args.d_inner, args.dt_rank + 2 * args.state_size, bias=False) - self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True) - - A = mx.repeat(mx.arange(1, args.state_size + 1).reshape([1, 16]), repeats=args.d_inner) - - - self.A_log = mx.log(A) - self.D = mx.ones([args.d_inner]) - - self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.use_bias) - - self.norm = nn.RMSNorm(args.d_model) - - def ssm(self, x): - d_in, N = self.A_log.shape - A = -mx.exp(self.A_log.float()) - D = self.D.float() - delta, B, C = self.x_proj(x).split(split_size=[self.config.dt_rank, N, N], dim=-1) - delta = nn.softplus(self.dt_proj(delta)) - return self.selective_scan(x, delta, A, B, C, D) - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = mx.exp(mx.einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = mx.einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - - # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) - x = mx.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = mx.einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = mx.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - def __call__(self, x): - _, L, _ = x.shape - x, r = self.in_proj(x).split([self.args.d_inner, self.args.d_inner], axis=-1) - - x = mx.reshape(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :L] - x = mx.rearrange(x, 'b d_in l -> b l d_in') - out = self.ssm(nn.silu(x)) * nn.silu(r) - return self.out_proj(out) + x - -class MambaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embedding = nn.Embedding(args.vocab_size, args.d_model) - self.layers = [MambaBlock(args) for _ in range(args.n_layer)] - self.norm_f = nn.RMSNorm(args.d_model) - - def __call__(self, inputs: mx.array_equal): - tokens = self.embedding(inputs) - for i, layer in enumerate(self.layers): - h = layer(tokens) - h = self.norm_f(h) - return h - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.model = self.backbone = MambaModel(args) - self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) - self.lm_head.weight = self.model.embedding.weight - - def __call__(self, inputs: mx.array): - h = self.backbone(inputs) - return self.lm_head(h) - - @property - def layers(self): 
- return self.backbone.layers - - # def sanitize(self, weights): - # exclude_patterns = [ - # 'backbone.layers.mixer.A_log', - # 'backbone.layers.mixer.conv1d.weight', - # 'backbone.layers.mixer.dt_proj.weight', - # 'backbone.layers.mixer.in_proj.weight', - # 'backbone.layers.mixer.dt_proj.bias', - # 'backbone.layers.mixer.conv1d.bias', - # 'backbone.layers.mixer.D' - # ] - # return { - # k: v for k, v in weights.items() - # if not any(pattern in k for pattern in exclude_patterns) - # } \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba-torch.py b/llms/mlx_lm/models/mamba-torch.py deleted file mode 100644 index 84deb4d3f..000000000 --- a/llms/mlx_lm/models/mamba-torch.py +++ /dev/null @@ -1,145 +0,0 @@ -import torch.nn as nn -import torch -from configuration_mamba import MambaConfig -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_utils import PreTrainedModel -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -import math -import json -import torch -import torch.nn as nn -import torch.nn.functional as F -from dataclasses import dataclass -from einops import rearrange, repeat, einsum -from typing import Optional , Union ,Tuple -l -# Dear contributors of the https://github.com/johnma2006/mamba-minimal/tree/master repository, special thanks to Albert Gu and Tri Dao for their articles. (https://arxiv.org/abs/2312.00752) - - -class MambaRMSNorm(nn.Module): - def __init__(self, - d_model: int, - eps: float = 1e-5): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(d_model)) - def forward(self, x): - output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight - return output - - -class MambaBlock(nn.Module): - def __init__(self, config: MambaConfig): - """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" - super().__init__() - self.config = config - - self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) - - self.conv1d = nn.Conv1d( - in_channels=config.d_inner, - out_channels=config.d_inner, - bias=config.conv_bias, - kernel_size=config.d_conv, - groups=config.d_inner, - padding=config.d_conv - 1, - ) - - # x_proj takes in `x` and outputs the input-specific Δ, B, C - self.x_proj = nn.Linear(config.d_inner, config.dt_rank + config.d_state * 2, bias=False) - - # dt_proj projects Δ from dt_rank to d_in - self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) - - A = repeat(torch.arange(1, config.d_state + 1), 'n -> d n', d=config.d_inner) - self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(config.d_inner)) - self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) - # self.norm = MambaRMSNorm(config.d_model) - - def forward(self, x): - (b, l, d) = x.shape - x_copy = x # There was a separate class for residual, I deleted that part and added it here. 
- x, res = self.in_proj(self.norm(x)).split(split_size=[self.config.d_inner, self.config.d_inner], dim=-1) - - x = rearrange(x, 'b l d_in -> b d_in l') - x = self.conv1d(x)[:, :, :l] - x = rearrange(x, 'b d_in l -> b l d_in') - - x = F.silu(x) - - y = self.ssm(x) - - y = y * F.silu(res) - - output = self.out_proj(y) + x_copy - - return output - - - def ssm(self, x): - (d_in, n) = self.A_log.shape - - A = -torch.exp(self.A_log.float()) # shape (d_in, n) - D = self.D.float() - - x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n) - - (delta, B, C) = x_dbl.split(split_size=[self.config.dt_rank, n, n], dim=-1) # delta: (b, l, dt_rank). B, C: (b, l, n) - delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in) - - y = self.selective_scan(x, delta, A, B, C, D) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2] - - return y - - - def selective_scan(self, u, delta, A, B, C, D): - (b, l, d_in) = u.shape - n = A.shape[1] - deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b d_in l n')) - deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b d_in l n') - x = torch.zeros((b, d_in, n), device=deltaA.device) - ys = [] - for i in range(l): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = torch.stack(ys, dim=1) # shape (b, l, d_in) - - y = y + u * D - - return y - - -class MambaModel(MambaPreTrainedModel): - def __init__(self, config: MambaConfig): - super().__init__(config) - self.config = config - - self.embedding = nn.Embedding(config.vocab_size, config.d_model) - self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)]) - self.norm_f = MambaRMSNorm(config.d_model) - - def forward(self, input_ids: torch.LongTensor = None): - x = self.embedding(input_ids) - all_hidden_states = list() - for layer in self.layers: - x = layer(x) - all_hidden_states.append(x) - return self.norm_f(x) - - -class MambaForCausalLM(MambaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = MambaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - self.lm_head.weight = self.model.embedding.weight - - - def forward(self, input_ids: torch.LongTensor = None): - hidden_states = self.model(input_ids=input_ids) - return self.lm_head(hidden_states) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py deleted file mode 100644 index 340cce831..000000000 --- a/llms/mlx_lm/models/mamba.py +++ /dev/null @@ -1,294 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = 
self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = int(channels) - self.kernel_size = int(kernel_size) - self.bias = bias - self.padding = padding - self.weight = mx.random.normal(shape=(self.channels, 1, self.kernel_size)) - scale = math.sqrt(1.0 / (self.channels * self.kernel_size)) - self.weight *= scale # Ensure scaling is applied correctly - if bias: - self.bias = mx.zeros((self.channels,)) - else: - self.bias = None - - def __call__(self, x): - B, D, L = x.shape - assert D == self.channels, f"Input channels ({D}) must match the initialized channels ({self.channels})." - print("FORWARD PASS THROUGH CONV") - print(self.kernel_size) - print(self.weight) - out = nn.Conv1d(x, self.weight, kernel_size=self.kernel_size, bias=self.bias, padding=self.padding) - return out - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) 
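# S4D-style real initialization: every channel starts from the same row
# [1, 2, ..., N], so A has shape (intermediate_size, state_size); it is stored
# as A_log = log(A) below and the scan later uses A = -exp(A_log). A tiny
# plain-Python sketch of the same construction with hypothetical small sizes
# (illustrative only, not part of the diff):
state_size_example = 4          # hypothetical N
intermediate_size_example = 3   # hypothetical ED
A_example = [[float(n) for n in range(1, state_size_example + 1)]
             for _ in range(intermediate_size_example)]
assert A_example == [[1.0, 2.0, 3.0, 4.0]] * 3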
- self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - return y, h - - def __call__(self, x, cache): - h, inputs = cache - x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) # (B, ED), (B, ED) - # x branch - x_cache = mx.expand_dims(x, 1) # (B, 1, ED) - # Ensure inputs has the correct shape - if inputs.ndim == 2: - inputs = mx.expand_dims(inputs, 1) # Add a dimension if it's missing - - print(f"inputs shape: {inputs.shape}") - print(f"x_cache shape: {x_cache.shape}") - - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) <---------- Here is the problem ValueError: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4. ||| inputs shape: (1, 3, 1536) x_cache shape: (1, 1, 5, 1536) - x = self.conv1d(conv_input)[:, -1, :] # (B, ED) - y, h = self.ssm(nn.silu(x), h) - output = y * nn.silu(z) # * z branch - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, d_conv-1, ED) - return self.out_proj(output), (h, inputs) # (B, D), cache - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, 
self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def generate(self, input_ids: mx.array, n_tokens_to_gen: int = 50, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - for i in range(input_ids.shape[1] + n_tokens_to_gen - 1): - next_token_logits, caches = self(input_ids[:, i], caches) - - if i+1 >= input_ids.shape[1]: - - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) - mask = next_token_logits < (values[:, 0, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) - - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1/temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids - - def generate_step(self, input_ids: mx.array, sample: bool = True, temperature: float = 1.0, top_k: int = None): - self.eval() - - if input_ids.ndim == 1: - input_ids = mx.expand_dims(input_ids, 0) - - caches = self.make_cache() - - # Generate the next token logits - next_token_logits, caches = self(input_ids, caches) - - # Apply top_k filtering if specified - if top_k is not None: - values = mx.topk(next_token_logits, k=top_k) # (1, k) ordered from lowest to highest - mask = next_token_logits < (values[:, -1, None]) - next_token_logits = mx.where(mask, -5000, next_token_logits) # -mx.inf is problematic for now - - # Sample the next token or take the argmax based on the temperature - if sample and temperature > 0: - next_token = mx.random.categorical(next_token_logits * (1 / temperature), num_samples=1) - else: - next_token = mx.argmax(next_token_logits, axis=-1)[:, None] - - # Concatenate the next token to the input_ids - input_ids = mx.concatenate([input_ids, next_token], axis=1) - - self.train() - return input_ids, caches \ No newline at end of file From 236acb16a87d2afe5622f9786968e80624eecd68 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 22:08:32 +0200 Subject: [PATCH 24/40] Fixing the Batching Depfwise Comnvolution and multi token input --- llms/mlx_lm/models/mamba.py | 272 ++++++++++++++++---------------- llms/mlx_lm/models/mamba1.py | 293 ----------------------------------- llms/mlx_lm/models/mamba2.py | 258 ------------------------------ 3 files changed, 135 insertions(+), 688 deletions(-) delete mode 100644 llms/mlx_lm/models/mamba1.py delete mode 100644 llms/mlx_lm/models/mamba2.py diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 172aab685..86e4977f1 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import math @@ -32,6 +33,8 @@ class ModelArgs(BaseModelArgs): use_cache: bool use_mambapy: bool = False dt_rank: str = "auto" + tie_word_embeddings: bool = True + def __post_init__(self): if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): @@ -40,12 +43,6 @@ def __post_init__(self): self.intermediate_size = self.d_inner if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): self.state_size = self.d_state - if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): - self.time_step_min = self.dt_min - if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): - 
self.time_step_min = self.dt_max - if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): - self.time_step_min = self.dt_init_floor if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): self.num_hidden_layers = self.n_layer if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): @@ -61,20 +58,7 @@ def __post_init__(self): if self.dt_rank == "auto": self.dt_rank = math.ceil(self.hidden_size / 16) - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class Conv1d(nn.Module): +class DepthWiseConv1d(nn.Module): def __init__( self, channels: int, @@ -85,92 +69,103 @@ def __init__( super().__init__() self.channels = channels self.kernel_size = kernel_size - self.use_bias = bias self.padding = padding - - # Change the weight initialization to match the expected shape - self.weight = mx.zeros((kernel_size, 1, channels)) - if self.use_bias: + self.weight = mx.random.normal((channels, 1, kernel_size)) + if bias: self.bias = mx.zeros((channels,)) else: self.bias = None def __call__(self, x, cache=None): - # Use the weight directly without transposing - w = self.weight + B, L, C = x.shape + assert C == self.channels, f"Input channels ({C}) must match the initialized channels ({self.channels})." + + w = self.weight # Shape: (C, 1, K) + K = self.kernel_size + total_padding = self.padding + K - 1 + if cache is not None: l = [] - # Pad the cache if needed - if cache.shape[1] < self.kernel_size - 1: - l.append( - mx.zeros( - (x.shape[0], self.kernel_size - 1 - cache.shape[1], self.channels), dtype=x.dtype - ) - ) + if cache.shape[1] < total_padding: + l.append(mx.zeros((B, total_padding - cache.shape[1], C), dtype=x.dtype)) l.extend([cache, x]) x = mx.concatenate(l, axis=1) - y = mx.conv_general(x, w, padding=([0], [0]), groups=self.channels) else: - y = mx.conv_general(x, w, padding=([self.padding], [0]), groups=self.channels) + x = mx.pad(x, [(0, 0), (total_padding, 0), (0, 0)]) - # The cache is always kernel_size - 1 - cache = x[:, max(x.shape[1] - self.kernel_size + 1, 0) :, :] + # Manual depthwise convolution + output = [] + for i in range(K): + slice = x[:, i:i+L, :] + output.append(slice * w[:, 0, i]) + y = mx.sum(mx.stack(output), axis=0) + + # The cache is always total_padding + cache = x[:, max(x.shape[1] - total_padding, 0):, :] - if self.use_bias: - y = y + self.bias + if self.bias is not None: + y = y + self.bias.reshape(1, 1, -1) return y, cache +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args - # projects block input from D to 2*ED (two branches) - self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias - # short 1d conv over 
time - self.conv1d = Conv1d( - channels=args.intermediate_size, - kernel_size=args.conv_kernel, - bias=args.use_conv_bias, - padding=args.conv_kernel-1 - ) + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - # projects x to input-dependent Δ, B, C - self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) - # projects Δ from dt_rank to intermediate_size - self.dt_proj = nn.Linear(args.dt_rank, args.intermediate_size, bias=True) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - # dt initialization - # dt weights - dt_init_std = args.dt_rank**-0.5 * args.state_size - + dt_init_std = args.time_step_rank**-0.5 * args.state_size if args.time_step_init_scheme == "constant": self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) elif args.time_step_init_scheme == "random": self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) else: raise NotImplementedError - - # dt bias + dt = clamp(mx.exp( mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) ), min=args.time_step_floor) inv_dt = dt + mx.log1p(-mx.exp(-dt)) self.dt_proj.bias = inv_dt - # S4D real initialization - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) - self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? 
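# Answer to the question above: storing A in log form means the scan always
# sees A = -exp(A_log) < 0, whatever value the optimizer drives A_log to, so
# the discretized transition exp(delta * A) stays in (0, 1) and the recurrence
# cannot blow up. A minimal scalar sketch of that discretization (illustrative
# only, not part of this patch):
import math

def discrete_transition(a_log: float, delta: float) -> float:
    a = -math.exp(a_log)        # a < 0 for any real a_log
    return math.exp(delta * a)  # in (0, 1) for any delta > 0

assert 0.0 < discrete_transition(a_log=0.5, delta=0.01) < 1.0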
- self.D = mx.ones([args.intermediate_size]) + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) - # projects block output from ED back to D - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm(self, x, h): + def ssm_step(self, x, h): # x : (B, ED) # h : (B, ED, N) @@ -182,7 +177,7 @@ def ssm(self, x, h): deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) delta = nn.softplus(self.dt_proj(delta)) # (B, ED) deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) @@ -191,51 +186,55 @@ def ssm(self, x, h): BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) if h is None: - h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) + h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) h = deltaA * h + BX # (B, ED, N) y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) y = y + D * x - + return y, h def __call__(self, x, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs : (B, conv_kernel-1, ED) - - # y : (B, D) - # cache : (h, inputs) - + # x : (B, T, D) where T is the number of tokens (5 in this case) + # cache : (h, inputs) + # h : (B, ED, N) + # inputs : (B, d_conv-1, ED) + h, inputs = cache - - print("Input shape:", x.shape) - xz = self.in_proj(x) # (B, 2*ED) - xz = xz.reshape(x.shape[0], -1) # Ensure shape is (B, 2*ED) - print("After in_proj shape:", xz.shape) - x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + B, T, D = x.shape + + outputs = [] + for t in range(T): + xt = x[:, t, :] # (B, D) + xz = self.in_proj(xt) # (B, 2*ED) + x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + + # x branch + x_cache = mx.expand_dims(x_t, 1) # (B, 1, ED) + conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) + conv_out, new_inputs = self.conv1d(conv_input) # (B, d_conv, ED), (B, d_conv-1, ED) + x_t = conv_out[:, -1, :] # (B, ED) - # x branch - x_cache = mx.expand_dims(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) + x_t = nn.silu(x_t) + y_t, h = self.ssm_step(x_t, h) - x = nn.silu(x) - y, h = self.ssm_step(x, h) + # z branch + z_t = nn.silu(z_t) - # z branch - z = nn.silu(z) + output_t = y_t * z_t + output_t = self.out_proj(output_t) # (B, D) + outputs.append(output_t) - output = y * z - output = self.out_proj(output) # (B, D) + # Update inputs for next token + inputs = new_inputs - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) + output = mx.stack(outputs, axis=1) # (B, T, D) cache = (h, inputs) - + return output, cache + class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): @@ -243,19 +242,12 @@ def __init__(self, args: ModelArgs): self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, inputs: mx.array, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - 
# inputs: (B, conv_kernel-1, ED) - - # output : (B, D) - # cache : (h, inputs) - - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs + def __call__(self, x: mx.array, cache): + output, cache = self.mixer(self.norm(x), cache) + output = output + x return output, cache + class Mamba(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -263,22 +255,11 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, tokens: mx.array, caches): - # tokens : (B, L) - - # logits : (B, L, vocab_size) - - x = self.embeddings(tokens) - - # x : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - - # y : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - + def __call__(self, x: mx.array, caches): + x = self.embeddings(x) + print(x.shape) for i, layer in enumerate(self.layers): x, caches[i] = layer(x, caches[i]) - return x, caches @@ -289,10 +270,39 @@ def __init__(self, args: ModelArgs): self.model_type = args.model_type self.backbone = Mamba(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - # out = self.backbone.embeddings.as_linear(out) - return out, cache + # inputs : (B, T) where T is the number of tokens + # caches : [cache(layer) for all layers], cache : (h, inputs) + + if inputs.ndim == 1: + inputs = mx.expand_dims(inputs, 0) # Add batch dimension if not present + + B, T = inputs.shape + x = self.backbone.embeddings(inputs) # (B, T, D) + + for i, layer in enumerate(self.backbone.layers): + x, cache[i] = layer(x, cache[i]) + + x = self.backbone.norm_f(x) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + print(f"Logits shape: {logits.shape}") + # logits : (B, T, vocab_size) + print(logits) + + return logits, cache + + def make_cache(self): + B = 1 # Assuming batch size of 1 for simplicity + return [(None, mx.zeros((B, self.args.conv_kernel-1, self.args.intermediate_size))) + for _ in range(self.args.num_hidden_layers)] @property def layers(self): @@ -306,17 +316,5 @@ def head_dim(self): def n_kv_heads(self): return self.args.num_hidden_layers - def make_cache(self): - return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - - def sanitize(self, weights): - for key, value in weights.items(): - if "mixer.conv1d.weight" in key: - # Ensure the weight is in the shape (kernel_size, 1, channels) - if value.shape != (self.args.conv_kernel, 1, self.args.intermediate_size): - weights[key] = value.reshape(self.args.conv_kernel, 1, self.args.intermediate_size) - elif key == "backbone.embeddings.weight": - # Ensure the embedding weight is in the shape (vocab_size, hidden_size) - if value.shape != (self.args.vocab_size, self.args.hidden_size): - weights[key] = value.T - return weights \ No newline at end of file + # def make_cache(self): + # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba1.py b/llms/mlx_lm/models/mamba1.py deleted file mode 100644 index 0b64f967e..000000000 --- a/llms/mlx_lm/models/mamba1.py +++ /dev/null @@ -1,293 +0,0 @@ -from 
dataclasses import dataclass - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'time_step_min') and hasattr(self, 'dt_min'): - self.time_step_min = self.dt_min - if not hasattr(self, 'time_step_max') and hasattr(self, 'dt_max'): - self.time_step_min = self.dt_max - if not hasattr(self, 'time_step_floor') and hasattr(self, 'dt_init_floor'): - self.time_step_min = self.dt_init_floor - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.bias = bias - self.padding = padding - - self.conv1d = nn.Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=kernel_size, - bias=True, - padding=padding - ) - indices = mx.arange(channels) - mask = mx.zeros_like(self.conv1d.weight) - mask[indices, :, indices] = 1 - self.conv1d.weight *= mask - - def __call__(self, x): - return self.conv1d(x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - # projects block input from D to 2*ED (two branches) - self.in_proj = nn.Linear(args.hidden_size, 2 * args.intermediate_size, bias=args.use_bias) - - # short 1d conv over time - self.conv1d = DepthWiseConv1d( - channels=args.intermediate_size, - kernel_size=args.conv_kernel, - bias=args.use_conv_bias, - padding=args.conv_kernel-1 - ) - - # projects x to input-dependent Δ, B, C - self.x_proj = nn.Linear(args.intermediate_size, args.dt_rank + 2 * args.state_size, bias=False) - - # projects Δ from dt_rank to intermediate_size - self.dt_proj = 
nn.Linear(args.dt_rank, args.intermediate_size, bias=True) - - # dt initialization - # dt weights - dt_init_std = args.dt_rank**-0.5 * args.state_size - - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - # dt bias - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - self.dt_proj.bias = inv_dt - - # S4D real initialization - A = mx.repeat(mx.arange(1., 16 + 1.).reshape([1, 16]), repeats=args.intermediate_size, axis=0) - self.A_log = mx.log(A) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ? - self.D = mx.ones([args.intermediate_size]) - - # projects block output from ED back to D - self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - # x : (B, ED) - # h : (B, ED, N) - - # y : (B, ED) - # h : (B, ED, N) - - A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.args.dt_rank, self.args.dt_rank+self.args.state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.args.hidden_size, self.args.state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - - return y, h - - def __call__(self, x, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs : (B, conv_kernel-1, ED) - - # y : (B, D) - # cache : (h, inputs) - - h, inputs = cache - - xz = self.in_proj(x) # (B, 2*ED) - x, z = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - - # x branch - x_cache = mx.expand_dims(x, 1) - x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.args.conv_kernel-1, :] # (B, ED) - - x = nn.silu(x) - y, h = self.ssm_step(x, h) - - # z branch - z = nn.silu(z) - - output = y * z - output = self.out_proj(output) # (B, D) - - # prepare cache for next call - inputs = mx.concatenate([inputs[:, 1:, :], x_cache], axis=1) # (B, conv_kernel-1, ED) - cache = (h, inputs) - - return output, cache - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - # x : (B, D) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs: (B, conv_kernel-1, ED) - - # output : (B, D) - # cache : (h, inputs) - - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs - return output, cache - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in 
range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, tokens: mx.array, caches): - # tokens : (B, L) - - # logits : (B, L, vocab_size) - - x = self.embeddings(tokens) - - # x : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - - # y : (B, L, D) - # caches : [cache(layer) for all layers], cache : (h, inputs) - - for i, layer in enumerate(self.layers): - x, caches[i] = layer(x, caches[i]) - - return x, caches - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - # out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] - - def sanitize(self, weights): - new_weights = {} - for key, value in weights.items(): - if "mixer.conv1d.weight" in key: - weights[key] = value.T - new_key = key.replace('mixer.conv1d', 'mixer.conv1d.conv1d') - new_weights[new_key] = value - return new_weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py deleted file mode 100644 index 04f67d050..000000000 --- a/llms/mlx_lm/models/mamba2.py +++ /dev/null @@ -1,258 +0,0 @@ -from dataclasses import dataclass - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - dt_rank: str = "auto" - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = 
math.ceil(self.hidden_size / 16) - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias, padding): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.padding = padding - self.weight = mx.random.normal(shape=(channels, 1, kernel_size)) - scale = math.sqrt(1.0 / (channels * kernel_size)) - self.weight *= scale - if bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None - - def __call__(self, x): - # x shape is (B, C, L) - B, C, L = x.shape - - # Pad the input - if self.padding > 0: - padding = [(0, 0), (0, 0), (self.padding, self.padding)] - x_padded = mx.pad(x, padding) - else: - x_padded = x - - # Perform depthwise convolution manually - out = [] - for i in range(L): - slice = x_padded[:, :, i:i+self.kernel_size] - out.append(mx.sum(slice * self.weight, axis=2)) - - out = mx.stack(out, axis=2) - - # Apply bias if present - if self.bias is not None: - out = out + self.bias.reshape(1, -1, 1) - - return out - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=int(self.intermediate_size), - kernel_size=int(self.conv_kernel_size), - bias=self.use_conv_bias, - padding=int(self.conv_kernel_size - 1) - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.dt_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm(self, x, h): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D - - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, 
ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) - - h = deltaA * h + BX # (B, ED, N) - - y = mx.sum(h * mx.expand_dims(C, 1), axis=-1) # (B, ED) - - y = y + D * x - return y, h - - def __call__(self, x, cache): - h, inputs = cache - - x, z = self.in_proj(x).split(indices_or_sections=2, axis=-1) - - # x is now (B, L, C), we need (B, C, L) for conv1d - x_cache = x.transpose(0, 2, 1) - - if inputs is None: - inputs = mx.zeros((x.shape[0], self.intermediate_size, self.conv_kernel_size - 1)) - else: - inputs = inputs.transpose(0, 2, 1) # Change to (batch, channels, sequence) - - conv_input = mx.concatenate([inputs, x_cache], axis=2) - - x = self.conv1d(conv_input) - x = x[:, :, -1] # Take the last element of the sequence - - y, h = self.ssm(x, h) - output = y * nn.silu(z[:, -1, :]) - - # Update inputs for the next iteration - inputs = conv_input[:, :, 1:] - - return self.out_proj(output), (h, inputs.transpose(0, 2, 1)) - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - output, cache = self.mixer(self.norm(inputs), cache) - output = output + inputs[:, -1, :] # Add residual only for the last time step - return output, cache - - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, inputs: mx.array, cache): - tokens = self.embeddings(inputs) - for i, layer in enumerate(self.layers): - h, cache[i] = layer(tokens, cache[i]) - h = self.norm_f(h) - return h, cache - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - def __call__(self, inputs: mx.array, cache=None): - out, cache = self.backbone(inputs, cache) - out = self.backbone.embeddings.as_linear(out) - return out, cache - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers - - def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] - return [(None, mx.zeros([1, self.backbone.layers[0].mixer.conv_kernel_size-1, self.backbone.layers[0].mixer.intermediate_size])) for _ in range(len(self.backbone.layers))] \ No newline at end of file From de1fdc7fdf522809fff65e04fdca42d5af7185fe Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 22:46:45 +0200 Subject: [PATCH 25/40] fixing generate and logits outputs --- llms/mlx_lm/models/base.py | 20 +++++++++++--------- llms/mlx_lm/models/mamba.py | 6 ------ llms/mlx_lm/tuner/utils.py | 1 - llms/mlx_lm/utils.py | 33 +++++++++++++++++---------------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ca237014d..db568b03a 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -9,17 +9,19 @@ class MambaCache: - def __init__(self, batch_size, intermediate_size, ssm_state_size, 
conv_kernel_size): - self.h = mx.zeros((batch_size, intermediate_size, ssm_state_size)) - self.conv_states = mx.zeros((batch_size, conv_kernel_size - 1, intermediate_size)) + def __init__(self, num_layers, conv_state_size, ssm_state_size): + self.conv_states = [None for _ in range(num_layers)] + self.ssm_states = [None for _ in range(num_layers)] + self.offset = 0 - def update(self, new_h, new_conv_state): - self.h = new_h - self.conv_states = mx.concatenate([self.conv_states[:, 1:, :], new_conv_state], axis=1) + def update(self, layer_idx, conv_state, ssm_state): + self.conv_states[layer_idx] = conv_state + self.ssm_states[layer_idx] = ssm_state + self.offset += 1 - @classmethod - def init_cache(cls, batch_size, intermediate_size, ssm_state_size, conv_kernel_size): - return cls(batch_size, intermediate_size, ssm_state_size, conv_kernel_size) + @property + def state(self): + return self.conv_states, self.ssm_states class KVCache: diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 86e4977f1..30fa5a4cf 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import math @@ -257,7 +256,6 @@ def __init__(self, args: ModelArgs): def __call__(self, x: mx.array, caches): x = self.embeddings(x) - print(x.shape) for i, layer in enumerate(self.layers): x, caches[i] = layer(x, caches[i]) return x, caches @@ -292,10 +290,6 @@ def __call__(self, inputs: mx.array, cache=None): logits = self.backbone.embeddings.as_linear(x) else: logits = self.lm_head(x) - - print(f"Logits shape: {logits.shape}") - # logits : (B, T, vocab_size) - print(logits) return logits, cache diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 15db25b50..235b54b1f 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -104,7 +104,6 @@ def to_lora(layer): "cohere", "minicpm", "deepseek", - "mamba" ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index cfbd9971e..6506a1c98 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -19,7 +19,7 @@ from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache, RotatingKVCache, MambaCache +from .models.base import KVCache, RotatingKVCache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import apply_lora_layers @@ -165,7 +165,7 @@ def generate_step( Args: prompt (mx.array): The input prompt. - model: The model to use for generation. + model (nn.Module): The model to use for generation. temp (float): The temperature for sampling, if 0 the argmax is used. Default: ``0``. 
repetition_penalty (float, optional): The penalty factor for repeating @@ -236,35 +236,36 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: def _step(y): nonlocal repetition_context - logits = model(y[None], cache=cache) + if model.model_type == "mamba": + logits, _ = model(y[None], cache=cache) + else: + logits = model(y[None], cache=cache) logits = logits[:, -1, :] if repetition_penalty: logits = apply_repetition_penalty( logits, repetition_context, repetition_penalty ) - next_token, logprobs = sample(logits) - repetition_context.append(next_token.item()) + y, logprobs = sample(logits) + repetition_context.append(y.item()) else: - next_token, logprobs = sample(logits) + y, logprobs = sample(logits) if repetition_context_size: if len(repetition_context) > repetition_context_size: repetition_context = repetition_context[-repetition_context_size:] - - return next_token, logprobs.squeeze(0) + return y, logprobs.squeeze(0) - if hasattr(model, 'generate_step'): - y, logprobs = model.generate_step(prompt) - else: - y, logprobs = _step(y) + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=cache) + mx.eval([c.state for c in cache]) + y = y[prefill_step_size:] + + y, logprobs = _step(y) mx.async_eval(y) while True: - if hasattr(model, 'generate_step'): - next_y, next_logprobs = model.generate_step(y) - else: - next_y, next_logprobs = _step(y) + next_y, next_logprobs = _step(y) mx.async_eval(next_y) yield y.item(), logprobs y, logprobs = next_y, next_logprobs From 107575133e6824b9102f50ce6696fc10eab6d137 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 22:58:55 +0200 Subject: [PATCH 26/40] Done! --- llms/mlx_lm/models/base.py | 1 + llms/mlx_lm/models/mamba.py | 88 ++++++++++++++----------------------- 2 files changed, 34 insertions(+), 55 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index db568b03a..92d14c5a1 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -23,6 +23,7 @@ def update(self, layer_idx, conv_state, ssm_state): def state(self): return self.conv_states, self.ssm_states + class KVCache: def __init__(self, head_dim, n_kv_heads): diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 30fa5a4cf..47ee4a817 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, MambaCache @dataclass @@ -58,54 +58,35 @@ def __post_init__(self): self.dt_rank = math.ceil(self.hidden_size / 16) class DepthWiseConv1d(nn.Module): - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - padding: int = 0 - ): + def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() self.channels = channels self.kernel_size = kernel_size self.padding = padding self.weight = mx.random.normal((channels, 1, kernel_size)) - if bias: - self.bias = mx.zeros((channels,)) - else: - self.bias = None + self.bias = mx.zeros((channels,)) if bias else None - def __call__(self, x, cache=None): + def __call__(self, x, conv_state=None): B, L, C = x.shape - assert C == self.channels, f"Input channels ({C}) must match the initialized channels ({self.channels})." 
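# The rewrite of __call__ in this hunk computes a causal depthwise convolution
# by hand: the input is left-padded with the cached state (K - 1 steps), each
# of the K kernel taps multiplies a shifted length-L slice, and the taps are
# summed per channel. A single-channel plain-Python sketch of the same
# tap-summing idea (illustrative only, not part of this patch):
def depthwise_causal_1d(x, w):
    # x: L samples of one channel; w: K taps (weight index 0 hits the oldest sample)
    K = len(w)
    padded = [0.0] * (K - 1) + list(x)  # causal left padding
    return [sum(w[i] * padded[t + i] for i in range(K)) for t in range(len(x))]

assert depthwise_causal_1d([1.0, 2.0, 3.0], [0.5, 0.5]) == [0.5, 1.5, 2.5]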
- - w = self.weight # Shape: (C, 1, K) K = self.kernel_size - total_padding = self.padding + K - 1 - - if cache is not None: - l = [] - if cache.shape[1] < total_padding: - l.append(mx.zeros((B, total_padding - cache.shape[1], C), dtype=x.dtype)) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - else: - x = mx.pad(x, [(0, 0), (total_padding, 0), (0, 0)]) - - # Manual depthwise convolution + + if conv_state is None: + conv_state = mx.zeros((B, K - 1, C)) + + x = mx.concatenate([conv_state, x], axis=1) + output = [] for i in range(K): slice = x[:, i:i+L, :] - output.append(slice * w[:, 0, i]) + output.append(slice * self.weight[:, 0, i]) y = mx.sum(mx.stack(output), axis=0) - - # The cache is always total_padding - cache = x[:, max(x.shape[1] - total_padding, 0):, :] if self.bias is not None: y = y + self.bias.reshape(1, 1, -1) - - return y, cache + + new_conv_state = x[:, -K+1:, :] + + return y, new_conv_state def clamp(x, min=None, max=None): @@ -164,36 +145,33 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm_step(self, x, h): - # x : (B, ED) - # h : (B, ED, N) + def ssm_step(self, x, ssm_state): + # x : (B, ED) + # ssm_state : (B, ED, N) - # y : (B, ED) - # h : (B, ED, N) + A = -mx.exp(self.A_log) # (ED, N) + D = self.D # (ED,) - A = -mx.exp(self.A_log) # (ED, N) # todo : move out of step (timestep independent) - D = self.D + deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, dt_rank), (B, N), (B, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + if ssm_state is None: + ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) # (B, ED, N) - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) - if h is None: - h = mx.zeros([x.shape[0], self.intermediate_size, self.ssm_state_size]) # (B, ED, N) + y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED) - h = deltaA * h + BX # (B, ED, N) + y = y + D * x # (B, ED) - y = (h @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1) - - y = y + D * x - - return y, h + return y, new_ssm_state + def __call__(self, x, cache): # x : (B, T, D) where T is the number of tokens (5 in this case) From fd3bd6d9aac8b12c2801d7ad8f8a76e6222d66d4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 23:00:25 +0200 Subject: [PATCH 27/40] Fixing the cache handling, generating works now trying training --- llms/mlx_lm/models/base.py | 16 ---------------- llms/mlx_lm/models/mamba.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 92d14c5a1..1a5cd42ba 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -6,22 +6,6 @@ import mlx.core as mx import mlx.nn as nn - - -class MambaCache: - 
def __init__(self, num_layers, conv_state_size, ssm_state_size): - self.conv_states = [None for _ in range(num_layers)] - self.ssm_states = [None for _ in range(num_layers)] - self.offset = 0 - - def update(self, layer_idx, conv_state, ssm_state): - self.conv_states[layer_idx] = conv_state - self.ssm_states[layer_idx] = ssm_state - self.offset += 1 - - @property - def state(self): - return self.conv_states, self.ssm_states class KVCache: diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 47ee4a817..49c0ea11a 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, MambaCache +from .base import BaseModelArgs @dataclass From 290c1a4dda3c4a64809de3ca2c05ebfc492e55b0 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 23:03:19 +0200 Subject: [PATCH 28/40] update ACKNOWLEDGEMENTS --- ACKNOWLEDGMENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 2b98bc95a..2037a0764 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,3 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. +- Gökdeniz Gülmez: Added support for `MiniCPM` and `Mamba`. From 9d14ea57e3f952a7537bef76ec9f33d958632762 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 4 Sep 2024 23:23:59 +0200 Subject: [PATCH 29/40] removing the model_type if stuff in the _step loop in generate_step and adding MambaCache in base.py for training easier generations and removing mamba in tuner/utils. --- llms/mlx_lm/models/base.py | 16 ++++++++ llms/mlx_lm/models/mamba.py | 77 +++++++++++++++---------------------- llms/mlx_lm/tuner/utils.py | 2 +- llms/mlx_lm/utils.py | 5 +-- 4 files changed, 49 insertions(+), 51 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 1a5cd42ba..73b3a1f14 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -7,7 +7,23 @@ import mlx.core as mx import mlx.nn as nn + +class MambaCache: + def __init__(self, num_layers, batch_size, conv_state_size, ssm_state_size): + self.conv_states = [mx.zeros((batch_size, *conv_state_size)) for _ in range(num_layers)] + self.ssm_states = [mx.zeros((batch_size, *ssm_state_size)) for _ in range(num_layers)] + self.offset = 0 + + def update(self, layer_idx, conv_state, ssm_state): + self.conv_states[layer_idx] = conv_state + self.ssm_states[layer_idx] = ssm_state + self.offset += 1 + + @property + def state(self): + return self.conv_states, self.ssm_states + class KVCache: def __init__(self, head_dim, n_kv_heads): diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 49c0ea11a..e5cb8f63c 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -5,7 +5,7 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, MambaCache @dataclass @@ -173,57 +173,47 @@ def ssm_step(self, x, ssm_state): return y, new_ssm_state - def __call__(self, x, cache): - # x : (B, T, D) where T is the number of tokens (5 in this case) - # cache : (h, inputs) - # h : (B, ED, N) - # inputs : (B, d_conv-1, ED) - - h, inputs = cache + def __call__(self, x, cache: MambaCache, layer_idx: int): B, T, D = x.shape + conv_state, ssm_state = cache.state[0][layer_idx], 
cache.state[1][layer_idx] + outputs = [] for t in range(T): xt = x[:, t, :] # (B, D) xz = self.in_proj(xt) # (B, 2*ED) x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - # x branch - x_cache = mx.expand_dims(x_t, 1) # (B, 1, ED) - conv_input = mx.concatenate([inputs, x_cache], axis=1) # (B, d_conv, ED) - conv_out, new_inputs = self.conv1d(conv_input) # (B, d_conv, ED), (B, d_conv-1, ED) - x_t = conv_out[:, -1, :] # (B, ED) + conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) + x_t = conv_out.squeeze(1) # (B, ED) x_t = nn.silu(x_t) - y_t, h = self.ssm_step(x_t, h) + y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) - # z branch z_t = nn.silu(z_t) output_t = y_t * z_t output_t = self.out_proj(output_t) # (B, D) outputs.append(output_t) - # Update inputs for next token - inputs = new_inputs + conv_state = new_conv_state + ssm_state = new_ssm_state output = mx.stack(outputs, axis=1) # (B, T, D) - cache = (h, inputs) + cache.update(layer_idx, conv_state, ssm_state) - return output, cache + return output - class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache): - output, cache = self.mixer(self.norm(x), cache) + def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): + output = self.mixer(self.norm(x), cache, layer_idx) output = output + x - return output, cache - + return output class Mamba(nn.Module): def __init__(self, args: ModelArgs): @@ -232,12 +222,11 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, caches): + def __call__(self, x: mx.array, cache: MambaCache): x = self.embeddings(x) for i, layer in enumerate(self.layers): - x, caches[i] = layer(x, caches[i]) - return x, caches - + x = layer(x, cache, i) + return self.norm_f(x) class Model(nn.Module): def __init__(self, args: ModelArgs): @@ -249,32 +238,31 @@ def __init__(self, args: ModelArgs): if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache=None): - # inputs : (B, T) where T is the number of tokens - # caches : [cache(layer) for all layers], cache : (h, inputs) - + def __call__(self, inputs: mx.array, cache: MambaCache = None): if inputs.ndim == 1: - inputs = mx.expand_dims(inputs, 0) # Add batch dimension if not present + inputs = mx.expand_dims(inputs, 0) B, T = inputs.shape - x = self.backbone.embeddings(inputs) # (B, T, D) - for i, layer in enumerate(self.backbone.layers): - x, cache[i] = layer(x, cache[i]) + if cache is None: + cache = self.make_cache(batch_size=B) - x = self.backbone.norm_f(x) + x = self.backbone(inputs, cache) if self.args.tie_word_embeddings: logits = self.backbone.embeddings.as_linear(x) else: logits = self.lm_head(x) - return logits, cache + return logits - def make_cache(self): - B = 1 # Assuming batch size of 1 for simplicity - return [(None, mx.zeros((B, self.args.conv_kernel-1, self.args.intermediate_size))) - for _ in range(self.args.num_hidden_layers)] + def make_cache(self, batch_size: int = 1): + return MambaCache( + num_layers=self.args.num_hidden_layers, + batch_size=batch_size, + conv_state_size=(self.args.conv_kernel - 1, self.args.intermediate_size), + ssm_state_size=(self.args.intermediate_size, self.args.state_size) + ) @property def layers(self): @@ 
-286,7 +274,4 @@ def head_dim(self): @property def n_kv_heads(self): - return self.args.num_hidden_layers - - # def make_cache(self): - # return [(None, mx.zeros([1, self.args.conv_kernel-1, self.args.intermediate_size])) for _ in range(self.args.num_hidden_layers)] \ No newline at end of file + return self.args.num_hidden_layers \ No newline at end of file diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 235b54b1f..4c853d8bc 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -103,7 +103,7 @@ def to_lora(layer): "starcoder2", "cohere", "minicpm", - "deepseek", + "deepseek" ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 6506a1c98..7c2add585 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -236,10 +236,7 @@ def sample(logits: mx.array) -> Tuple[mx.array, float]: def _step(y): nonlocal repetition_context - if model.model_type == "mamba": - logits, _ = model(y[None], cache=cache) - else: - logits = model(y[None], cache=cache) + logits = model(y[None], cache=cache) logits = logits[:, -1, :] if repetition_penalty: From e8f5a6b213213e1b003520fb25a4181b4e179c9b Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 5 Sep 2024 09:17:10 +0200 Subject: [PATCH 30/40] quick clean up --- llms/mlx_lm/models/mamba.py | 3 +-- llms/tests/test_models.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index e5cb8f63c..616b011f3 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -31,7 +31,6 @@ class ModelArgs(BaseModelArgs): rescale_prenorm_residual: bool use_cache: bool use_mambapy: bool = False - dt_rank: str = "auto" tie_word_embeddings: bool = True @@ -152,7 +151,7 @@ def ssm_step(self, x, ssm_state): A = -mx.exp(self.A_log) # (ED, N) D = self.D # (ED,) - deltaBC = self.x_proj(x) # (B, dt_rank+2*N) + deltaBC = self.x_proj(x) # (B, time_step_rank+2*N) delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) delta = nn.softplus(self.dt_proj(delta)) # (B, ED) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index fcf1dc331..83f15cda0 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -397,6 +397,40 @@ def test_minicpm(self): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_mamba(self): + from mlx_lm.models import mamba + + args = mamba.ModelArgs( + conv_kernel=4, + d_inner=1536, + d_model=768, + expand=2, + hidden_size=768, + initializer_range=0.1, + intermediate_size=1536, + layer_norm_epsilon=1e-05, + model_type="mamba", + n_layer=24, + num_hidden_layers=24, + state_size=16, + rms_norm=True, + rescale_prenorm_residual=False, + time_step_floor= 0.0001, + time_step_init_scheme="random", + time_step_max=0.1, + time_step_min=0.001, + time_step_rank=48, + time_step_scale=1.0, + vocab_size=10000, + use_bias=False, + use_conv_bias=True, + use_cache=True, + ) + model = mamba.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gpt2(self): from mlx_lm.models import gpt2 From 511cdf89b1af1feab7ac9634cd444c54787b78cc Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 18 Sep 2024 11:04:01 +0200 Subject: [PATCH 31/40] update trainer/utils for right initialisation of the layers for LoRA, 
but not working. --- llms/mlx_lm/models/mamba-infer.py | 276 ++++++++++++++++++++++++++++++ llms/mlx_lm/models/mamba.py | 217 ++++++++++++++++++++--- llms/mlx_lm/tuner/utils.py | 9 +- 3 files changed, 474 insertions(+), 28 deletions(-) create mode 100644 llms/mlx_lm/models/mamba-infer.py diff --git a/llms/mlx_lm/models/mamba-infer.py b/llms/mlx_lm/models/mamba-infer.py new file mode 100644 index 000000000..2ee8b47aa --- /dev/null +++ b/llms/mlx_lm/models/mamba-infer.py @@ -0,0 +1,276 @@ +from dataclasses import dataclass + +import math + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, MambaCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + intermediate_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + time_step_rank: int + time_step_scale: float + time_step_min: float + time_step_max: float + time_step_init_scheme: str + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + use_mambapy: bool = False + tie_word_embeddings: bool = True + + + def __post_init__(self): + if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + self.hidden_size = self.d_model + if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + self.intermediate_size = self.d_inner + if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + self.state_size = self.d_state + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + self.num_hidden_layers = self.n_layer + if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + self.num_hidden_layers = self.n_layers + if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + self.conv_kernel = self.d_conv + if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + self.use_bias = self.bias + if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + self.use_conv_bias = self.conv_bias + + self.intermediate_size = self.expand * self.hidden_size + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias=True, padding=0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal((channels, 1, kernel_size)) + self.bias = mx.zeros((channels,)) if bias else None + + def __call__(self, x, conv_state=None): + B, L, C = x.shape + K = self.kernel_size + + if conv_state is None: + conv_state = mx.zeros((B, K - 1, C)) + + x = mx.concatenate([conv_state, x], axis=1) + + output = [] + for i in range(K): + slice = x[:, i:i+L, :] + output.append(slice * self.weight[:, 0, i]) + y = mx.sum(mx.stack(output), axis=0) + + if self.bias is not None: + y = y + self.bias.reshape(1, 1, -1) + + new_conv_state = x[:, -K+1:, :] + + return y, new_conv_state + + +def clamp(x, min=None, max=None): + if min is not None: + mask_lower = x < min + if max is not None: + mask_upper = x > max + if min is not None: + if max is not None: + return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) + return mx.where(mask_lower, min, x) + return mx.where(mask_upper, max, x) + + +class MambaBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.hidden_size = args.hidden_size + self.ssm_state_size = args.state_size + 
self.conv_kernel_size = args.conv_kernel + self.intermediate_size = args.intermediate_size + self.time_step_rank = int(args.time_step_rank) + self.use_conv_bias = args.use_conv_bias + + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + + self.conv1d = DepthWiseConv1d( + channels=self.intermediate_size, + kernel_size=self.conv_kernel_size, + bias=self.use_conv_bias, + padding=self.conv_kernel_size-1 + ) + + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + dt_init_std = args.time_step_rank**-0.5 * args.state_size + if args.time_step_init_scheme == "constant": + self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) + elif args.time_step_init_scheme == "random": + self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) + else: + raise NotImplementedError + + dt = clamp(mx.exp( + mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) + ), min=args.time_step_floor) + inv_dt = dt + mx.log1p(-mx.exp(-dt)) + self.dt_proj.bias = inv_dt + + A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + self.A_log = mx.log(A) + self.D = mx.ones([self.intermediate_size]) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm_step(self, x, ssm_state): + # x : (B, ED) + # ssm_state : (B, ED, N) + + A = -mx.exp(self.A_log) # (ED, N) + D = self.D # (ED,) + + deltaBC = self.x_proj(x) # (B, time_step_rank+2*N) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) # (B, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) + + if ssm_state is None: + ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) # (B, ED, N) + + new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) + + y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED) + + y = y + D * x # (B, ED) + + return y, new_ssm_state + + + def __call__(self, x, cache: MambaCache, layer_idx: int): + B, T, D = x.shape + + conv_state, ssm_state = cache.state[0][layer_idx], cache.state[1][layer_idx] + + outputs = [] + for t in range(T): + xt = x[:, t, :] # (B, D) + xz = self.in_proj(xt) # (B, 2*ED) + x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) + + conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) + x_t = conv_out.squeeze(1) # (B, ED) + + x_t = nn.silu(x_t) + y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) + + z_t = nn.silu(z_t) + + output_t = y_t * z_t + output_t = self.out_proj(output_t) # (B, D) + outputs.append(output_t) + + conv_state = new_conv_state + ssm_state = new_ssm_state + + output = mx.stack(outputs, axis=1) # (B, T, D) + cache.update(layer_idx, conv_state, ssm_state) + + return output + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = MambaBlock(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): + output = self.mixer(self.norm(x), cache, 
layer_idx) + output = output + x + return output + +class Mamba(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache: MambaCache): + x = self.embeddings(x) + for i, layer in enumerate(self.layers): + x = layer(x, cache, i) + return self.norm_f(x) + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache = None): + if inputs.ndim == 1: + inputs = mx.expand_dims(inputs, 0) + + B, T = inputs.shape + + if cache is None: + cache = self.make_cache(batch_size=B) + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size: int = 1): + return MambaCache( + num_layers=self.args.num_hidden_layers, + batch_size=batch_size, + conv_state_size=(self.args.conv_kernel - 1, self.args.intermediate_size), + ssm_state_size=(self.args.intermediate_size, self.args.state_size) + ) + + @property + def layers(self): + return self.backbone.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_hidden_layers + + @property + def n_kv_heads(self): + return self.args.num_hidden_layers \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 616b011f3..574445c7a 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -30,7 +30,8 @@ class ModelArgs(BaseModelArgs): time_step_floor: float rescale_prenorm_residual: bool use_cache: bool - use_mambapy: bool = False + pscan: bool = False # use parallel scan mode or sequential mode when training + use_mambapy: bool = False tie_word_embeddings: bool = True @@ -53,8 +54,8 @@ def __post_init__(self): self.use_conv_bias = self.conv_bias self.intermediate_size = self.expand * self.hidden_size - if self.dt_rank == "auto": - self.dt_rank = math.ceil(self.hidden_size / 16) + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): @@ -100,6 +101,84 @@ def clamp(x, min=None, max=None): return mx.where(mask_upper, max, x) +def pscan_f(A, X): + # A : (B, D, L, N) + # X : (B, D, L, N) + + # modifies X in place by doing a parallel scan. 
+ # more formally, X will be populated by these values : + # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 + # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) + + Aa = A + Xa = X + + B, D, L, _ = A.shape + + num_steps = int(math.log2(L)) + + # up sweep + for k in range(num_steps): + T = 2 * (Xa.shape[2] // 2) + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] + Aa[:, :, :, 1] *= Aa[:, :, :, 0] + + A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] + X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] + + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + # down sweep + for k in range(num_steps-1, -1, -1): + Aa = A[:, :, 2**k-1::2**k] + Xa = X[:, :, 2**k-1::2**k] + + step_len = Xa.shape[2] + T = 2 * (step_len // 2) + + if T < step_len: + last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] + last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] + + Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) + Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) + + Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] + Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] + + if T == step_len: + A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] + X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + else: + A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) + X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + +# main function, used in the Mamba model (mamba_mlx.py) +def pscan(A_in, X_in): + """ + Applies the parallel scan operation, as defined above. Returns a new tensor. + + Args: + A_in : (B, L, ED, N) + X_in : (B, L, ED, N) + + Returns: + H : (B, L, ED, N) + """ + + A = A_in[:].transpose(0, 2, 1, 3) + X = X_in[:].transpose(0, 2, 1, 3) + + pscan_f(A, X) + + return X.transpose(0, 2, 1, 3) + + class MambaBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -144,10 +223,8 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - def ssm_step(self, x, ssm_state): - # x : (B, ED) - # ssm_state : (B, ED, N) - + def ssm_step(self, x, ssm_state=None): + # Modify this method to work without state during training A = -mx.exp(self.A_log) # (ED, N) D = self.D # (ED,) @@ -160,34 +237,125 @@ def ssm_step(self, x, ssm_state): BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - if ssm_state is None: - ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) # (B, ED, N) - - new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) + if self.training: + # During training, we don't use or update the state + new_ssm_state = BX + else: + if ssm_state is None: + ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) # (B, ED, N) + new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED) - y = y + D * x # (B, ED) - return y, new_ssm_state + if self.training: + return y + else: + return y, new_ssm_state + + + def ssm(self, x): + # x : (B, L, ED) + + # y : (B, L, ED) + + A = -mx.exp(self.A_log) # (ED, N) + D = self.D + + deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) + + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) + delta = nn.softplus(self.dt_proj(delta)) # (B, L, ED) + + if self.args.pscan: + y = 
self.selective_scan(x, delta, A, B, C, D) + else: + y = self.selective_scan_seq(x, delta, A, B, C, D) + + return y + + + def selective_scan(self, x, delta, A, B, C, D): + # x : (B, L, ED) + # Δ : (B, L, ED) + # A : (ED, N) + # B : (B, L, N) + # C : (B, L, N) + # D : (ED) + + # y : (B, L, ED) + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + + hs = pscan(deltaA, BX) + + y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + + y = y + D * x + + return y + + def selective_scan_seq(self, x, delta, A, B, C, D): + # x : (B, L, ED) + # Δ : (B, L, ED) + # A : (ED, N) + # B : (B, L, N) + # C : (B, L, N) + # D : (ED) + + # y : (B, L, ED) + + _, L, _ = x.shape + + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) + + BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) + + h = mx.zeros([x.shape[0], self.config.d_inner, self.config.d_state]) # (B, ED, N) + hs = [] + + for t in range(0, L): + h = deltaA[:, t] * h + BX[:, t] + hs.append(h) + + hs = mx.stack(hs, axis=1) + + y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) + + y = y + D * x + + return y def __call__(self, x, cache: MambaCache, layer_idx: int): B, T, D = x.shape - conv_state, ssm_state = cache.state[0][layer_idx], cache.state[1][layer_idx] - outputs = [] for t in range(T): xt = x[:, t, :] # (B, D) xz = self.in_proj(xt) # (B, 2*ED) x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) - x_t = conv_out.squeeze(1) # (B, ED) + if self.training: + conv_out, _ = self.conv1d(mx.expand_dims(x_t, 1)) + else: + conv_state = cache.state[0][layer_idx] + conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) + cache.state[0][layer_idx] = new_conv_state + x_t = conv_out.squeeze(1) # (B, ED) x_t = nn.silu(x_t) - y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) + + if self.training: + y_t = self.ssm_step(x_t) + else: + ssm_state = cache.state[1][layer_idx] + y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) + cache.state[1][layer_idx] = new_ssm_state z_t = nn.silu(z_t) @@ -195,12 +363,7 @@ def __call__(self, x, cache: MambaCache, layer_idx: int): output_t = self.out_proj(output_t) # (B, D) outputs.append(output_t) - conv_state = new_conv_state - ssm_state = new_ssm_state - output = mx.stack(outputs, axis=1) # (B, T, D) - cache.update(layer_idx, conv_state, ssm_state) - return output class ResidualBlock(nn.Module): @@ -210,6 +373,10 @@ def __init__(self, args: ModelArgs): self.norm = nn.RMSNorm(args.hidden_size) def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): + # Ensure x is 3D before passing to mixer + if x.ndim == 2: + x = mx.expand_dims(x, 1) # Make it (B, 1, D) + output = self.mixer(self.norm(x), cache, layer_idx) output = output + x return output @@ -237,13 +404,13 @@ def __init__(self, args: ModelArgs): if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache: MambaCache = None): + def __call__(self, inputs: mx.array, cache = None): if inputs.ndim == 1: inputs = mx.expand_dims(inputs, 0) B, T = inputs.shape - if cache is None: + if not self.training and cache is None: cache = self.make_cache(batch_size=B) x = 
self.backbone(inputs, cache) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 4c853d8bc..ab0e470e0 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -52,8 +52,10 @@ def linear_to_lora_layers( use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False`` """ - - num_layers = len(model.layers) + if hasattr(model, "backbone"): + num_layers = len(model.backbone.layers) + else: + num_layers = len(model.layers) if num_lora_layers < 0: num_lora_layers = num_layers @@ -103,7 +105,8 @@ def to_lora(layer): "starcoder2", "cohere", "minicpm", - "deepseek" + "deepseek", + "mamba" ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: From 602c9f18bd86aac9857c8689ec544727575fe7a4 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 18 Sep 2024 11:13:22 +0200 Subject: [PATCH 32/40] clean up --- llms/mlx_lm/models/mamba.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 574445c7a..3771a25a0 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -106,8 +106,7 @@ def pscan_f(A, X): # X : (B, D, L, N) # modifies X in place by doing a parallel scan. - # more formally, X will be populated by these values : - # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 + # more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) Aa = A @@ -164,11 +163,11 @@ def pscan(A_in, X_in): Applies the parallel scan operation, as defined above. Returns a new tensor. Args: - A_in : (B, L, ED, N) - X_in : (B, L, ED, N) + A_in: mx.array =-> Shape(B, L, ED, N) + X_in: mx.array -> Shape (B, L, ED, N) Returns: - H : (B, L, ED, N) + H: mx.array -> Shape (B, L, ED, N) """ A = A_in[:].transpose(0, 2, 1, 3) @@ -277,14 +276,11 @@ def ssm(self, x): def selective_scan(self, x, delta, A, B, C, D): # x : (B, L, ED) - # Δ : (B, L, ED) # A : (ED, N) # B : (B, L, N) # C : (B, L, N) # D : (ED) - # y : (B, L, ED) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) @@ -296,18 +292,15 @@ def selective_scan(self, x, delta, A, B, C, D): y = y + D * x - return y + return y # (B, L, ED) def selective_scan_seq(self, x, delta, A, B, C, D): # x : (B, L, ED) - # Δ : (B, L, ED) # A : (ED, N) # B : (B, L, N) # C : (B, L, N) # D : (ED) - # y : (B, L, ED) - _, L, _ = x.shape deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) @@ -328,7 +321,7 @@ def selective_scan_seq(self, x, delta, A, B, C, D): y = y + D * x - return y + return y # (B, L, ED) def __call__(self, x, cache: MambaCache, layer_idx: int): From 40f9e83306a45dd2e35aa1ecffab082d9931d2dd Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 18 Sep 2024 13:24:39 +0200 Subject: [PATCH 33/40] Forther update to trainer/utils for correct layer selection. 
Successfull training --- llms/mlx_lm/tuner/utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index ab0e470e0..0a8a0913f 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -53,9 +53,13 @@ def linear_to_lora_layers( Default: ``False`` """ if hasattr(model, "backbone"): - num_layers = len(model.backbone.layers) + layers = model.backbone.layers + elif hasattr(model, "layers"): + layers = model.layers else: - num_layers = len(model.layers) + raise ValueError("Unsupported model structure") + + num_layers = len(layers) if num_lora_layers < 0: num_lora_layers = num_layers @@ -143,9 +147,18 @@ def to_lora(layer): "self_attn.kv_b_proj", ] ) + if model.model_type == "mamba": + keys = set([ + "mixer.in_proj", + "mixer.x_proj", + "mixer.dt_proj", + "mixer.out_proj", + ]) else: raise ValueError(f"Lora does not support {model.model_type}") + # Modified the layer selection to handle both regular and backbone structures: + layers = model.backbone.layers if hasattr(model, "backbone") else model.layers for l in model.layers[num_layers - num_lora_layers :]: lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] if lora_layers: From 399de78f51bc3665cbdafc3ddf827af004902943 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 18 Sep 2024 13:37:43 +0200 Subject: [PATCH 34/40] removing extra mamba-infer.py file --- llms/mlx_lm/models/mamba-infer.py | 276 ------------------------------ 1 file changed, 276 deletions(-) delete mode 100644 llms/mlx_lm/models/mamba-infer.py diff --git a/llms/mlx_lm/models/mamba-infer.py b/llms/mlx_lm/models/mamba-infer.py deleted file mode 100644 index 2ee8b47aa..000000000 --- a/llms/mlx_lm/models/mamba-infer.py +++ /dev/null @@ -1,276 +0,0 @@ -from dataclasses import dataclass - -import math - -import mlx.core as mx -import mlx.nn as nn - -from .base import BaseModelArgs, MambaCache - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - vocab_size: int - hidden_size: int - intermediate_size: int - state_size: int - num_hidden_layers: int - layer_norm_epsilon: float - expand: int - conv_kernel: int - use_bias: bool - use_conv_bias: bool - initializer_range: float - time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - use_mambapy: bool = False - tie_word_embeddings: bool = True - - - def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): - self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): - self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): - self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): - self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): - self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): - self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): - self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): - self.use_conv_bias = self.conv_bias - - self.intermediate_size = self.expand * self.hidden_size - if self.time_step_rank == "auto": - self.time_step_rank = 
math.ceil(self.hidden_size / 16) - -class DepthWiseConv1d(nn.Module): - def __init__(self, channels, kernel_size, bias=True, padding=0): - super().__init__() - self.channels = channels - self.kernel_size = kernel_size - self.padding = padding - self.weight = mx.random.normal((channels, 1, kernel_size)) - self.bias = mx.zeros((channels,)) if bias else None - - def __call__(self, x, conv_state=None): - B, L, C = x.shape - K = self.kernel_size - - if conv_state is None: - conv_state = mx.zeros((B, K - 1, C)) - - x = mx.concatenate([conv_state, x], axis=1) - - output = [] - for i in range(K): - slice = x[:, i:i+L, :] - output.append(slice * self.weight[:, 0, i]) - y = mx.sum(mx.stack(output), axis=0) - - if self.bias is not None: - y = y + self.bias.reshape(1, 1, -1) - - new_conv_state = x[:, -K+1:, :] - - return y, new_conv_state - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -class MambaBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.hidden_size = args.hidden_size - self.ssm_state_size = args.state_size - self.conv_kernel_size = args.conv_kernel - self.intermediate_size = args.intermediate_size - self.time_step_rank = int(args.time_step_rank) - self.use_conv_bias = args.use_conv_bias - - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) - - self.conv1d = DepthWiseConv1d( - channels=self.intermediate_size, - kernel_size=self.conv_kernel_size, - bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 - ) - - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) - self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - - dt_init_std = args.time_step_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * (math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) - self.A_log = mx.log(A) - self.D = mx.ones([self.intermediate_size]) - - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) - - def ssm_step(self, x, ssm_state): - # x : (B, ED) - # ssm_state : (B, ED, N) - - A = -mx.exp(self.A_log) # (ED, N) - D = self.D # (ED,) - - deltaBC = self.x_proj(x) # (B, time_step_rank+2*N) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - - if ssm_state is None: - ssm_state = mx.zeros((x.shape[0], 
self.intermediate_size, self.ssm_state_size)) # (B, ED, N) - - new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) - - y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED) - - y = y + D * x # (B, ED) - - return y, new_ssm_state - - - def __call__(self, x, cache: MambaCache, layer_idx: int): - B, T, D = x.shape - - conv_state, ssm_state = cache.state[0][layer_idx], cache.state[1][layer_idx] - - outputs = [] - for t in range(T): - xt = x[:, t, :] # (B, D) - xz = self.in_proj(xt) # (B, 2*ED) - x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - - conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) - x_t = conv_out.squeeze(1) # (B, ED) - - x_t = nn.silu(x_t) - y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) - - z_t = nn.silu(z_t) - - output_t = y_t * z_t - output_t = self.out_proj(output_t) # (B, D) - outputs.append(output_t) - - conv_state = new_conv_state - ssm_state = new_ssm_state - - output = mx.stack(outputs, axis=1) # (B, T, D) - cache.update(layer_idx, conv_state, ssm_state) - - return output - -class ResidualBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.mixer = MambaBlock(args) - self.norm = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): - output = self.mixer(self.norm(x), cache, layer_idx) - output = output + x - return output - -class Mamba(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] - self.norm_f = nn.RMSNorm(args.hidden_size) - - def __call__(self, x: mx.array, cache: MambaCache): - x = self.embeddings(x) - for i, layer in enumerate(self.layers): - x = layer(x, cache, i) - return self.norm_f(x) - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.model_type = args.model_type - self.backbone = Mamba(args) - - if not args.tie_word_embeddings: - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__(self, inputs: mx.array, cache = None): - if inputs.ndim == 1: - inputs = mx.expand_dims(inputs, 0) - - B, T = inputs.shape - - if cache is None: - cache = self.make_cache(batch_size=B) - - x = self.backbone(inputs, cache) - - if self.args.tie_word_embeddings: - logits = self.backbone.embeddings.as_linear(x) - else: - logits = self.lm_head(x) - - return logits - - def make_cache(self, batch_size: int = 1): - return MambaCache( - num_layers=self.args.num_hidden_layers, - batch_size=batch_size, - conv_state_size=(self.args.conv_kernel - 1, self.args.intermediate_size), - ssm_state_size=(self.args.intermediate_size, self.args.state_size) - ) - - @property - def layers(self): - return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers \ No newline at end of file From 13af75d88a56e9299f79ea4625b6adf1742de0b9 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Wed, 18 Sep 2024 14:44:49 +0200 Subject: [PATCH 35/40] clean up, reformating will come later --- llms/mlx_lm/models/mamba.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 3771a25a0..05d2adfb7 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -30,8 
+30,7 @@ class ModelArgs(BaseModelArgs): time_step_floor: float rescale_prenorm_residual: bool use_cache: bool - pscan: bool = False # use parallel scan mode or sequential mode when training - use_mambapy: bool = False + pscan: bool = False tie_word_embeddings: bool = True @@ -102,13 +101,6 @@ def clamp(x, min=None, max=None): def pscan_f(A, X): - # A : (B, D, L, N) - # X : (B, D, L, N) - - # modifies X in place by doing a parallel scan. - # more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 - # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) - Aa = A Xa = X @@ -157,10 +149,10 @@ def pscan_f(A, X): A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) -# main function, used in the Mamba model (mamba_mlx.py) + def pscan(A_in, X_in): """ - Applies the parallel scan operation, as defined above. Returns a new tensor. + Applies the parallel scan operation, as defined above. Returns a new array. Args: A_in: mx.array =-> Shape(B, L, ED, N) @@ -169,12 +161,9 @@ def pscan(A_in, X_in): Returns: H: mx.array -> Shape (B, L, ED, N) """ - A = A_in[:].transpose(0, 2, 1, 3) X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) @@ -223,7 +212,6 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) def ssm_step(self, x, ssm_state=None): - # Modify this method to work without state during training A = -mx.exp(self.A_log) # (ED, N) D = self.D # (ED,) @@ -237,7 +225,6 @@ def ssm_step(self, x, ssm_state=None): BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) if self.training: - # During training, we don't use or update the state new_ssm_state = BX else: if ssm_state is None: @@ -255,7 +242,6 @@ def ssm_step(self, x, ssm_state=None): def ssm(self, x): # x : (B, L, ED) - # y : (B, L, ED) A = -mx.exp(self.A_log) # (ED, N) @@ -280,7 +266,6 @@ def selective_scan(self, x, delta, A, B, C, D): # B : (B, L, N) # C : (B, L, N) # D : (ED) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) @@ -300,7 +285,6 @@ def selective_scan_seq(self, x, delta, A, B, C, D): # B : (B, L, N) # C : (B, L, N) # D : (ED) - _, L, _ = x.shape deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) @@ -366,7 +350,6 @@ def __init__(self, args: ModelArgs): self.norm = nn.RMSNorm(args.hidden_size) def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): - # Ensure x is 3D before passing to mixer if x.ndim == 2: x = mx.expand_dims(x, 1) # Make it (B, 1, D) From 9457329ed3b0ec4c7bd5d7ca78a86660303443e3 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Thu, 19 Sep 2024 18:13:40 +0200 Subject: [PATCH 36/40] reformat and big clean up, final commit --- llms/mlx_lm/models/mamba.py | 150 +++++++++--------------------------- 1 file changed, 38 insertions(+), 112 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 05d2adfb7..66f609986 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -33,7 +33,6 @@ class ModelArgs(BaseModelArgs): pscan: bool = False tie_word_embeddings: bool = True - def __post_init__(self): if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): self.hidden_size = self.d_model @@ -56,6 
+55,7 @@ def __post_init__(self): if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -68,23 +68,17 @@ def __init__(self, channels, kernel_size, bias=True, padding=0): def __call__(self, x, conv_state=None): B, L, C = x.shape K = self.kernel_size - if conv_state is None: conv_state = mx.zeros((B, K - 1, C)) - x = mx.concatenate([conv_state, x], axis=1) - output = [] for i in range(K): slice = x[:, i:i+L, :] output.append(slice * self.weight[:, 0, i]) y = mx.sum(mx.stack(output), axis=0) - if self.bias is not None: y = y + self.bias.reshape(1, 1, -1) - new_conv_state = x[:, -K+1:, :] - return y, new_conv_state @@ -103,45 +97,32 @@ def clamp(x, min=None, max=None): def pscan_f(A, X): Aa = A Xa = X - B, D, L, _ = A.shape - num_steps = int(math.log2(L)) - # up sweep for k in range(num_steps): T = 2 * (Xa.shape[2] // 2) - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] Aa[:, :, :, 1] *= Aa[:, :, :, 0] - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - Aa = Aa[:, :, :, 1] Xa = Xa[:, :, :, 1] - # down sweep for k in range(num_steps-1, -1, -1): Aa = A[:, :, 2**k-1::2**k] Xa = X[:, :, 2**k-1::2**k] - step_len = Xa.shape[2] T = 2 * (step_len // 2) - if T < step_len: last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - if T == step_len: A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] @@ -151,16 +132,6 @@ def pscan_f(A, X): def pscan(A_in, X_in): - """ - Applies the parallel scan operation, as defined above. Returns a new array. 
- - Args: - A_in: mx.array =-> Shape(B, L, ED, N) - X_in: mx.array -> Shape (B, L, ED, N) - - Returns: - H: mx.array -> Shape (B, L, ED, N) - """ A = A_in[:].transpose(0, 2, 1, 3) X = X_in[:].transpose(0, 2, 1, 3) pscan_f(A, X) @@ -212,137 +183,92 @@ def __init__(self, args: ModelArgs): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) def ssm_step(self, x, ssm_state=None): - A = -mx.exp(self.A_log) # (ED, N) - D = self.D # (ED,) - - deltaBC = self.x_proj(x) # (B, time_step_rank+2*N) + A = -mx.exp(self.A_log) + D = self.D + deltaBC = self.x_proj(x) delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) # (B, ED) - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) # (B, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N) - + delta = nn.softplus(self.dt_proj(delta)) + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) + BX = deltaB * mx.expand_dims(x, -1) if self.training: new_ssm_state = BX else: if ssm_state is None: - ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) # (B, ED, N) - new_ssm_state = deltaA * ssm_state + BX # (B, ED, N) - - y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) # (B, ED) - y = y + D * x # (B, ED) - + ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) + new_ssm_state = deltaA * ssm_state + BX + y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) + y = y + D * x if self.training: return y else: return y, new_ssm_state - def ssm(self, x): - # x : (B, L, ED) - # y : (B, L, ED) - - A = -mx.exp(self.A_log) # (ED, N) + A = -mx.exp(self.A_log) D = self.D - - deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N) - - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) # (B, L, dt_rank), (B, L, N), (B, L, N) - delta = nn.softplus(self.dt_proj(delta)) # (B, L, ED) - + deltaBC = self.x_proj(x) + delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) + delta = nn.softplus(self.dt_proj(delta)) if self.args.pscan: y = self.selective_scan(x, delta, A, B, C, D) else: y = self.selective_scan_seq(x, delta, A, B, C, D) - return y - def selective_scan(self, x, delta, A, B, C, D): - # x : (B, L, ED) - # A : (ED, N) - # B : (B, L, N) - # C : (B, L, N) - # D : (ED) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) + BX = deltaB * mx.expand_dims(x, -1) hs = pscan(deltaA, BX) - - y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - + y = (hs @ mx.expand_dims(C, -1)).squeeze(3) y = y + D * x - - return y # (B, L, ED) + return y def selective_scan_seq(self, x, delta, A, B, C, D): - # x : (B, L, ED) - # A : (ED, N) - # B : (B, L, N) - # C : (B, L, N) - # D : (ED) _, L, _ = x.shape - - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N) - - BX = deltaB * mx.expand_dims(x, -1) # (B, L, ED, N) - - h = mx.zeros([x.shape[0], 
self.config.d_inner, self.config.d_state]) # (B, ED, N) + deltaA = mx.exp(mx.expand_dims(delta, -1) * A) + deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) + BX = deltaB * mx.expand_dims(x, -1) + h = mx.zeros([x.shape[0], self.config.d_inner, self.config.d_state]) hs = [] - for t in range(0, L): h = deltaA[:, t] * h + BX[:, t] hs.append(h) - hs = mx.stack(hs, axis=1) - - y = (hs @ mx.expand_dims(C, -1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1) - + y = (hs @ mx.expand_dims(C, -1)).squeeze(3) y = y + D * x - - return y # (B, L, ED) - + return y def __call__(self, x, cache: MambaCache, layer_idx: int): B, T, D = x.shape - outputs = [] for t in range(T): - xt = x[:, t, :] # (B, D) - xz = self.in_proj(xt) # (B, 2*ED) - x_t, z_t = xz.split(indices_or_sections=2, axis=1) # (B, ED), (B, ED) - + xt = x[:, t, :] + xz = self.in_proj(xt) + x_t, z_t = xz.split(indices_or_sections=2, axis=1) if self.training: conv_out, _ = self.conv1d(mx.expand_dims(x_t, 1)) else: conv_state = cache.state[0][layer_idx] conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) cache.state[0][layer_idx] = new_conv_state - - x_t = conv_out.squeeze(1) # (B, ED) + x_t = conv_out.squeeze(1) x_t = nn.silu(x_t) - if self.training: y_t = self.ssm_step(x_t) else: ssm_state = cache.state[1][layer_idx] y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) cache.state[1][layer_idx] = new_ssm_state - z_t = nn.silu(z_t) - output_t = y_t * z_t - output_t = self.out_proj(output_t) # (B, D) + output_t = self.out_proj(output_t) outputs.append(output_t) - - output = mx.stack(outputs, axis=1) # (B, T, D) + output = mx.stack(outputs, axis=1) return output - + + class ResidualBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -351,12 +277,12 @@ def __init__(self, args: ModelArgs): def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): if x.ndim == 2: - x = mx.expand_dims(x, 1) # Make it (B, 1, D) - + x = mx.expand_dims(x, 1) output = self.mixer(self.norm(x), cache, layer_idx) output = output + x return output + class Mamba(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -370,13 +296,13 @@ def __call__(self, x: mx.array, cache: MambaCache): x = layer(x, cache, i) return self.norm_f(x) + class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.args = args self.model_type = args.model_type self.backbone = Mamba(args) - if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) From ae9cc8c86231bfbeb1affd5c6939ccd4fc41c1f9 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 27 Sep 2024 17:37:34 -0700 Subject: [PATCH 37/40] some speedups and cleanups --- llms/mlx_lm/models/base.py | 16 -- llms/mlx_lm/models/mamba.py | 300 +++++++++++------------------------- 2 files changed, 93 insertions(+), 223 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 73b3a1f14..dc19dd05f 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -6,23 +6,7 @@ import mlx.core as mx import mlx.nn as nn - -class MambaCache: - def __init__(self, num_layers, batch_size, conv_state_size, ssm_state_size): - self.conv_states = [mx.zeros((batch_size, *conv_state_size)) for _ in range(num_layers)] - self.ssm_states = [mx.zeros((batch_size, *ssm_state_size)) for _ in range(num_layers)] - self.offset = 0 - - def update(self, layer_idx, conv_state, ssm_state): - self.conv_states[layer_idx] = conv_state - self.ssm_states[layer_idx] = ssm_state - 
self.offset += 1 - - @property - def state(self): - return self.conv_states, self.ssm_states - class KVCache: diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 66f609986..264084262 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,11 +1,12 @@ -from dataclasses import dataclass +# Copyright © 2024 Apple Inc. import math +from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, MambaCache +from .base import BaseModelArgs @dataclass @@ -16,126 +17,73 @@ class ModelArgs(BaseModelArgs): intermediate_size: int state_size: int num_hidden_layers: int - layer_norm_epsilon: float - expand: int conv_kernel: int use_bias: bool use_conv_bias: bool - initializer_range: float time_step_rank: int - time_step_scale: float - time_step_min: float - time_step_max: float - time_step_init_scheme: str - time_step_floor: float - rescale_prenorm_residual: bool - use_cache: bool - pscan: bool = False tie_word_embeddings: bool = True def __post_init__(self): - if not hasattr(self, 'hidden_size') and hasattr(self, 'd_model'): + if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): self.hidden_size = self.d_model - if not hasattr(self, 'intermediate_size') and hasattr(self, 'd_inner'): + if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"): self.intermediate_size = self.d_inner - if not hasattr(self, 'state_size') and hasattr(self, 'd_state'): + if not hasattr(self, "state_size") and hasattr(self, "d_state"): self.state_size = self.d_state - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layer'): + if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"): self.num_hidden_layers = self.n_layer - if not hasattr(self, 'num_hidden_layers') and hasattr(self, 'n_layers'): + if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"): self.num_hidden_layers = self.n_layers - if not hasattr(self, 'conv_kernel') and hasattr(self, 'd_conv'): + if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"): self.conv_kernel = self.d_conv - if not hasattr(self, 'use_bias') and hasattr(self, 'bias'): + if not hasattr(self, "use_bias") and hasattr(self, "bias"): self.use_bias = self.bias - if not hasattr(self, 'use_conv_bias') and hasattr(self, 'conv_bias'): + if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"): self.use_conv_bias = self.conv_bias - self.intermediate_size = self.expand * self.hidden_size if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) +class MambaCache: + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() self.channels = channels self.kernel_size = kernel_size self.padding = padding - self.weight = mx.random.normal((channels, 1, kernel_size)) + self.weight = mx.random.normal((self.channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) if bias else None - def __call__(self, x, conv_state=None): + def __call__(self, x, cache=None): B, L, C = x.shape - K = self.kernel_size - if conv_state is None: - conv_state = mx.zeros((B, K - 1, C)) - x = mx.concatenate([conv_state, x], axis=1) - output = [] - for i in range(K): - slice = x[:, i:i+L, :] - output.append(slice * self.weight[:, 0, i]) - 
y = mx.sum(mx.stack(output), axis=0) - if self.bias is not None: - y = y + self.bias.reshape(1, 1, -1) - new_conv_state = x[:, -K+1:, :] - return y, new_conv_state - - -def clamp(x, min=None, max=None): - if min is not None: - mask_lower = x < min - if max is not None: - mask_upper = x > max - if min is not None: - if max is not None: - return mx.where(mask_upper, max, mx.where(mask_lower, min, x)) - return mx.where(mask_lower, min, x) - return mx.where(mask_upper, max, x) - - -def pscan_f(A, X): - Aa = A - Xa = X - B, D, L, _ = A.shape - num_steps = int(math.log2(L)) - # up sweep - for k in range(num_steps): - T = 2 * (Xa.shape[2] // 2) - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, :, 1] += Aa[:, :, :, 1] * Xa[:, :, :, 0] - Aa[:, :, :, 1] *= Aa[:, :, :, 0] - A[:, :, 2**(k+1)-1::2**(k+1)] = Aa[:, :, :, 1] - X[:, :, 2**(k+1)-1::2**(k+1)] = Xa[:, :, :, 1] - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - # down sweep - for k in range(num_steps-1, -1, -1): - Aa = A[:, :, 2**k-1::2**k] - Xa = X[:, :, 2**k-1::2**k] - step_len = Xa.shape[2] - T = 2 * (step_len // 2) - if T < step_len: - last_val_aa = Aa[:, :, -1] * Aa[:, :, -2] - last_val_xa = Xa[:, :, -1] + Aa[:, :, -1] * Xa[:, :, -2] - Aa = Aa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa = Xa[:, :, :T].reshape(B, D, T//2, 2, -1) - Xa[:, :, 1:, 0] += Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1] - Aa[:, :, 1:, 0] *= Aa[:, :, :-1, 1] - if T == step_len: - A[:, :, 2**k-1::2**(k+1)] = Aa[:, :, :, 0] - X[:, :, 2**k-1::2**(k+1)] = Xa[:, :, :, 0] + groups, K, _ = self.weight.shape + + if cache is not None: + x = mx.concatenate([cache, x], axis=1) else: - A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2) - X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2) + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + y = mx.conv_general(x, self.weight, groups=groups) -def pscan(A_in, X_in): - A = A_in[:].transpose(0, 2, 1, 3) - X = X_in[:].transpose(0, 2, 1, 3) - pscan_f(A, X) - return X.transpose(0, 2, 1, 3) + if self.bias is not None: + y = y + self.bias + + return y, x[:, -K + 1 :, :] class MambaBlock(nn.Module): @@ -150,117 +98,70 @@ def __init__(self, args: ModelArgs): self.time_step_rank = int(args.time_step_rank) self.use_conv_bias = args.use_conv_bias - self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=args.use_bias) + self.in_proj = nn.Linear( + self.hidden_size, self.intermediate_size * 2, bias=args.use_bias + ) self.conv1d = DepthWiseConv1d( channels=self.intermediate_size, kernel_size=self.conv_kernel_size, bias=self.use_conv_bias, - padding=self.conv_kernel_size-1 + padding=self.conv_kernel_size - 1, ) - self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + 2 * self.ssm_state_size, bias=False) + self.x_proj = nn.Linear( + self.intermediate_size, + self.time_step_rank + 2 * self.ssm_state_size, + bias=False, + ) self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) - dt_init_std = args.time_step_rank**-0.5 * args.state_size - if args.time_step_init_scheme == "constant": - self.dt_proj.weight = dt_init_std * mx.ones_like(self.dt_proj.weight) - elif args.time_step_init_scheme == "random": - self.dt_proj.weight = mx.random.uniform(-dt_init_std, dt_init_std, self.dt_proj.weight.shape) - else: - raise NotImplementedError - - dt = clamp(mx.exp( - mx.random.uniform(shape=[args.intermediate_size]) * 
(math.log(args.time_step_max) - math.log(args.time_step_min)) + math.log(args.time_step_min) - ), min=args.time_step_floor) - inv_dt = dt + mx.log1p(-mx.exp(-dt)) - self.dt_proj.bias = inv_dt - - A = mx.repeat(mx.arange(1., self.ssm_state_size + 1.).reshape([1, self.ssm_state_size]), repeats=self.intermediate_size, axis=0) + A = mx.repeat( + mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]), + repeats=self.intermediate_size, + axis=0, + ) self.A_log = mx.log(A) self.D = mx.ones([self.intermediate_size]) - self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + self.out_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=args.use_bias + ) - def ssm_step(self, x, ssm_state=None): + def ssm_step(self, x, state=None): A = -mx.exp(self.A_log) D = self.D deltaBC = self.x_proj(x) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) - delta = nn.softplus(self.dt_proj(delta)) - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 1) - BX = deltaB * mx.expand_dims(x, -1) - if self.training: - new_ssm_state = BX - else: - if ssm_state is None: - ssm_state = mx.zeros((x.shape[0], self.intermediate_size, self.ssm_state_size)) - new_ssm_state = deltaA * ssm_state + BX - y = (new_ssm_state @ mx.expand_dims(C, -1)).squeeze(2) - y = y + D * x - if self.training: - return y - else: - return y, new_ssm_state - - def ssm(self, x): - A = -mx.exp(self.A_log) - D = self.D - deltaBC = self.x_proj(x) - delta, B, C = mx.split(deltaBC, indices_or_sections=[self.time_step_rank, self.time_step_rank+self.ssm_state_size], axis=-1) + delta, B, C = mx.split( + deltaBC, + indices_or_sections=[ + self.time_step_rank, + self.time_step_rank + self.ssm_state_size, + ], + axis=-1, + ) delta = nn.softplus(self.dt_proj(delta)) - if self.args.pscan: - y = self.selective_scan(x, delta, A, B, C, D) - else: - y = self.selective_scan_seq(x, delta, A, B, C, D) - return y - - def selective_scan(self, x, delta, A, B, C, D): - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) - BX = deltaB * mx.expand_dims(x, -1) - hs = pscan(deltaA, BX) - y = (hs @ mx.expand_dims(C, -1)).squeeze(3) + new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) + if state is not None: + new_state += state * mx.exp(mx.expand_dims(delta, -1) * A) + y = (new_state @ mx.expand_dims(C, -1)).squeeze(2) y = y + D * x - return y - - def selective_scan_seq(self, x, delta, A, B, C, D): - _, L, _ = x.shape - deltaA = mx.exp(mx.expand_dims(delta, -1) * A) - deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) - BX = deltaB * mx.expand_dims(x, -1) - h = mx.zeros([x.shape[0], self.config.d_inner, self.config.d_state]) - hs = [] - for t in range(0, L): - h = deltaA[:, t] * h + BX[:, t] - hs.append(h) - hs = mx.stack(hs, axis=1) - y = (hs @ mx.expand_dims(C, -1)).squeeze(3) - y = y + D * x - return y - - def __call__(self, x, cache: MambaCache, layer_idx: int): + return y, new_state + + def __call__(self, x, cache): B, T, D = x.shape + if cache is None: + cache = [None, None] + outputs = [] for t in range(T): xt = x[:, t, :] xz = self.in_proj(xt) x_t, z_t = xz.split(indices_or_sections=2, axis=1) - if self.training: - conv_out, _ = self.conv1d(mx.expand_dims(x_t, 1)) - else: - conv_state = cache.state[0][layer_idx] - conv_out, new_conv_state = self.conv1d(mx.expand_dims(x_t, 1), conv_state) - 
cache.state[0][layer_idx] = new_conv_state + conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) x_t = conv_out.squeeze(1) x_t = nn.silu(x_t) - if self.training: - y_t = self.ssm_step(x_t) - else: - ssm_state = cache.state[1][layer_idx] - y_t, new_ssm_state = self.ssm_step(x_t, ssm_state) - cache.state[1][layer_idx] = new_ssm_state + y_t, cache[1] = self.ssm_step(x_t, cache[1]) z_t = nn.silu(z_t) output_t = y_t * z_t output_t = self.out_proj(output_t) @@ -275,12 +176,8 @@ def __init__(self, args: ModelArgs): self.mixer = MambaBlock(args) self.norm = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int): - if x.ndim == 2: - x = mx.expand_dims(x, 1) - output = self.mixer(self.norm(x), cache, layer_idx) - output = output + x - return output + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x class Mamba(nn.Module): @@ -290,10 +187,12 @@ def __init__(self, args: ModelArgs): self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] self.norm_f = nn.RMSNorm(args.hidden_size) - def __call__(self, x: mx.array, cache: MambaCache): + def __call__(self, x: mx.array, cache): x = self.embeddings(x) - for i, layer in enumerate(self.layers): - x = layer(x, cache, i) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + x = layer(x, c) return self.norm_f(x) @@ -306,17 +205,11 @@ def __init__(self, args: ModelArgs): if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - def __call__(self, inputs: mx.array, cache = None): - if inputs.ndim == 1: - inputs = mx.expand_dims(inputs, 0) - + def __call__(self, inputs: mx.array, cache=None): B, T = inputs.shape - - if not self.training and cache is None: - cache = self.make_cache(batch_size=B) - + x = self.backbone(inputs, cache) - + if self.args.tie_word_embeddings: logits = self.backbone.embeddings.as_linear(x) else: @@ -324,22 +217,15 @@ def __call__(self, inputs: mx.array, cache = None): return logits + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + def make_cache(self, batch_size: int = 1): - return MambaCache( - num_layers=self.args.num_hidden_layers, - batch_size=batch_size, - conv_state_size=(self.args.conv_kernel - 1, self.args.intermediate_size), - ssm_state_size=(self.args.intermediate_size, self.args.state_size) - ) + return [MambaCache() for _ in range(len(self.layers))] @property def layers(self): return self.backbone.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_hidden_layers - - @property - def n_kv_heads(self): - return self.args.num_hidden_layers \ No newline at end of file From 03d3d19e6a26a73e4f6bfcbe6145ad40f3a1bd32 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 27 Sep 2024 17:49:48 -0700 Subject: [PATCH 38/40] fix test --- llms/tests/test_models.py | 33 +++++++-------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 83f15cda0..cd7e7fd07 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -5,6 +5,7 @@ import mlx.core as mx from mlx.utils import tree_map from mlx_lm.models.base import KVCache, RotatingKVCache +from mlx_lm.utils import make_kv_caches class TestModels(unittest.TestCase): @@ -100,13 +101,7 @@ def model_test_runner(self, model, model_type, vocab_size, 
num_layers): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - cache = [KVCache(model.head_dim, n) for n in kv_heads] - + cache = make_kv_caches(model) outputs = model(inputs, cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -401,30 +396,16 @@ def test_mamba(self): from mlx_lm.models import mamba args = mamba.ModelArgs( + model_type="mamba", + vocab_size=10000, + use_bias=False, + use_conv_bias=True, conv_kernel=4, - d_inner=1536, - d_model=768, - expand=2, hidden_size=768, - initializer_range=0.1, - intermediate_size=1536, - layer_norm_epsilon=1e-05, - model_type="mamba", - n_layer=24, num_hidden_layers=24, state_size=16, - rms_norm=True, - rescale_prenorm_residual=False, - time_step_floor= 0.0001, - time_step_init_scheme="random", - time_step_max=0.1, - time_step_min=0.001, + intermediate_size=1536, time_step_rank=48, - time_step_scale=1.0, - vocab_size=10000, - use_bias=False, - use_conv_bias=True, - use_cache=True, ) model = mamba.Model(args) self.model_test_runner( From a10f20654a959682b0337133d57949f4adb14e20 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 28 Sep 2024 06:35:27 -0700 Subject: [PATCH 39/40] nits --- llms/mlx_lm/tuner/utils.py | 23 +++++++++++------------ llms/mlx_lm/utils.py | 3 +-- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 0a8a0913f..a66383770 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -58,7 +58,7 @@ def linear_to_lora_layers( layers = model.layers else: raise ValueError("Unsupported model structure") - + num_layers = len(layers) if num_lora_layers < 0: @@ -110,7 +110,6 @@ def to_lora(layer): "cohere", "minicpm", "deepseek", - "mamba" ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: @@ -147,18 +146,18 @@ def to_lora(layer): "self_attn.kv_b_proj", ] ) - if model.model_type == "mamba": - keys = set([ - "mixer.in_proj", - "mixer.x_proj", - "mixer.dt_proj", - "mixer.out_proj", - ]) + elif model.model_type == "mamba": + keys = set( + [ + "mixer.in_proj", + "mixer.x_proj", + "mixer.dt_proj", + "mixer.out_proj", + ] + ) else: raise ValueError(f"Lora does not support {model.model_type}") - # Modified the layer selection to handle both regular and backbone structures: - layers = model.backbone.layers if hasattr(model, "backbone") else model.layers for l in model.layers[num_layers - num_lora_layers :]: lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] if lora_layers: @@ -276,4 +275,4 @@ def nparams(m): print( f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " f"({trainable_p:.3f}M/{total_p:.3f}M)" - ) \ No newline at end of file + ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index c27962935..5621609de 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -27,7 +27,7 @@ # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama - "phi-msft": "phixtral" + "phi-msft": "phixtral", } MAX_FILE_SIZE_GB = 5 @@ -341,7 +341,6 @@ def generate( print("Prompt:", prompt) prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer tic = time.perf_counter() From 1738a06c3345f07caab03ca7f02b64380fd7431e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 28 Sep 2024 06:36:47 -0700 Subject: [PATCH 
40/40] nits --- llms/mlx_lm/tuner/utils.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index a66383770..ab9d37aaa 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -52,14 +52,7 @@ def linear_to_lora_layers( use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False`` """ - if hasattr(model, "backbone"): - layers = model.backbone.layers - elif hasattr(model, "layers"): - layers = model.layers - else: - raise ValueError("Unsupported model structure") - - num_layers = len(layers) + num_layers = len(model.layers) if num_lora_layers < 0: num_lora_layers = num_layers
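
Editor's note (illustration only, not part of the patch series): after PATCH 37-40, llms/mlx_lm/models/mamba.py exposes Model, ModelArgs, and make_cache(), with one MambaCache per layer holding a [conv_state, ssm_state] pair that MambaBlock threads through its token-by-token loop. The sketch below exercises that interface with the small configuration from llms/tests/test_models.py (PATCH 38). The greedy argmax loop and the use of untrained, randomly initialized weights are assumptions made purely to show how the per-layer cache is reused across calls; real generation goes through mlx_lm's generate utilities rather than this loop.

    import mlx.core as mx
    from mlx_lm.models import mamba

    # Small configuration matching the one used in llms/tests/test_models.py.
    args = mamba.ModelArgs(
        model_type="mamba",
        vocab_size=10000,
        use_bias=False,
        use_conv_bias=True,
        conv_kernel=4,
        hidden_size=768,
        num_hidden_layers=24,
        state_size=16,
        intermediate_size=1536,
        time_step_rank=48,
    )
    model = mamba.Model(args)

    # One MambaCache per layer; each starts as [None, None] and is filled with
    # the layer's [conv_state, ssm_state] after the first forward pass.
    cache = model.make_cache()

    # Feed one token per call; the recurrent state lives entirely in `cache`.
    # Weights are untrained here, so the sampled ids are meaningless.
    token = mx.array([[0]])
    for _ in range(5):
        logits = model(token, cache)              # shape (1, 1, vocab_size)
        next_id = mx.argmax(logits[:, -1, :], axis=-1)
        token = mx.expand_dims(next_id, 1)        # back to shape (1, 1)
        print(next_id.tolist())

Because tie_word_embeddings defaults to True, the logits above come from backbone.embeddings.as_linear; only the conv_state grows with conv_kernel - 1 cached steps, so memory stays constant in sequence length, which is the point of routing generation through MambaCache instead of a KV cache.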