Skip to content

MLX implementation of xLSTM model by Beck et al. (2024)

License

Notifications You must be signed in to change notification settings

abeleinin/mlx-xLSTM

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

18 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

MLX xLSTM

A pure MLX implementation of xLSTM: Extended Long Short-Term Memory by Beck et al. (2024)

Install

Use the following commands to install the package:

git clone git@github.com:abeleinin/mlx-xLSTM.git
cd mlx-xLSTM/
pip install -r requirement.txt
pip install -e .

Usage

The models are implemented in their own respective python files in the mlx_xlstm/ directory. Here is an overview of what models are currently implemented:

files description
mLSTM.py implements mLSTM and mLSTMBlock shown in (Figure 10)
sLSTM.py implements sLSTMCell, sLSTM, and sLSTMBlock shown in (Figure 9)
xLSTM.py implements xLSTM

If you're interested, I've also created a simple training example in the examples/ directory, which showcases how to use the different models on a simple learning task.

mLSTM Training

Here is a brief example on how to train a mLSTMBlock in mlx:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

from mlx_xlstm import mLSTMBlock

def loss_fn(model, X, states, y):
    return nn.losses.mse_loss(model(X, states)[0], y) # choose loss function

input_size = 1
head_dim = 4
head_num = 8

batch_size = 5

model = mLSTMBlock(input_size, head_dim, head_num)
mx.eval(model.parameters())

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

optimizer = optim.Adam(learning_rate=0.01) # choose optimizer 

data = ... # choose dataset

for t in range(seq_len - 1):
    X = data[:, t, :]
    y = data[:, t+1, :]
    l, grads = loss_and_grad_fn(model, X, states, y_true)

    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)

For more details, please refer to full implementation examples/train_mLSTMBlock.py. I was able to train a simple model which learns a sine function.

mLSTMBlock sine function predition

Unit Tests

Run unit tests:

python -m unittest discover tests

Roadmap

  • Implenent sub-components
    • mLSTM implementation
    • sLSTM implementation
    • mLSTMBlock
    • sLSTMBlock
  • Add full xLSTM implementation
  • Add unit tests
  • Add training examples for each component
  • Add language model example

Citations

@article{beck2024xlstm,
  title={xLSTM: Extended Long Short-Term Memory},
  author={Beck, Maximilian and P{\"o}ppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, G{\"u}nter and Brandstetter, Johannes and Hochreiter, Sepp},
  journal={arXiv preprint arXiv:2405.04517},
  year={2024}
}
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}

References

About

MLX implementation of xLSTM model by Beck et al. (2024)

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages