Skip to content

Commit

Permalink
fix mamba models conversion (#1065)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Oct 22, 2024
1 parent d1d4808 commit 9000e28
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __call__(self, inputs: mx.array, cache=None):

def sanitize(self, weights):
for k, v in weights.items():
- if "conv1d.weight" in k and v.ndim == 3:
+ if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights

Expand Down
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/recurrent_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def layers(self):

def sanitize(self, weights):
for k, v in weights.items():
- if "conv_1d.weight" in k and v.ndim == 3:
+ if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
Expand Down

0 comments on commit 9000e28

Please sign in to comment.