PyxLSTM is a Python library that provides an efficient and extensible implementation of the Extended Long Short-Term Memory (xLSTM) architecture, based on the research paper "xLSTM: Extended Long Short-Term Memory" by Beck et al. (2024). xLSTM extends the traditional LSTM with exponential gating, memory mixing, and a matrix memory structure, with the goal of improving performance and scalability on sequence modeling tasks.
- Features
- Installation
- Development Installation
- Usage
- Code Directory Structure
- Running and Testing the Codebase
- Documentation
- Citation
- Contributing
- License
- Acknowledgements
- Contact
- Star History
- TODO
- Implements the sLSTM (scalar LSTM) and mLSTM (matrix LSTM) variants of xLSTM
- Supports pre and post up-projection block structures for flexible model architectures
- Provides high-level model definition and training utilities for ease of use
- Includes scripts for training, evaluation, and text generation
- Offers data processing utilities and customizable dataset classes
- Lightweight and modular design for seamless integration into existing projects
- Extensively tested and documented for reliability and usability
- Suitable for a wide range of sequence modeling tasks, including language modeling, text generation, and more
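The sLSTM and mLSTM variants listed above can be illustrated with a dependency-free sketch of one recurrent step each, following the stabilized update equations in Beck et al. (2024). This is not the code in slstm.py or mlstm.py. The function names, the gate pre-activations passed in as arguments, the choice of an exponential forget gate, and the dense recurrent matrices are assumptions made for illustration; the paper restricts memory mixing to block-diagonal, per-head matrices. The stabilizer state m keeps the exponential gates from overflowing without changing the hidden outputs.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(mat: Matrix, vec: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, vec)) for row in mat]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def slstm_zero_state(d: int):
    """(h, c, n, m) for d hidden units, all zeros."""
    return [0.0] * d, [0.0] * d, [0.0] * d, [0.0] * d


def slstm_step(x_proj: Dict[str, Vector], R: Dict[str, Matrix], state):
    """One sLSTM step (hypothetical sketch, not PyxLSTM's slstm.py).

    x_proj: input contributions W_g x_t + b_g for the gates "z", "i", "f", "o"
    R:      recurrent matrices R_g; R_g h_{t-1} mixes memory across hidden units
    state:  (h, c, n, m) vectors from the previous step
    """
    h, c, n, m = state
    pre = {g: [a + b for a, b in zip(x_proj[g], matvec(R[g], h))] for g in "zifo"}
    h_new, c_new, n_new, m_new = [], [], [], []
    for j in range(len(h)):
        z = math.tanh(pre["z"][j])
        o = sigmoid(pre["o"][j])
        i_pre, f_pre = pre["i"][j], pre["f"][j]
        # Exponential gates i = exp(i_pre), f = exp(f_pre), tracked in log space:
        # m_t = max(log f + m_{t-1}, log i) keeps both rescaled gates <= 1.
        m_t = max(f_pre + m[j], i_pre)
        i_gate = math.exp(i_pre - m_t)
        f_gate = math.exp(f_pre + m[j] - m_t)
        c_t = f_gate * c[j] + i_gate * z  # cell state
        n_t = f_gate * n[j] + i_gate      # normalizer state
        h_new.append(o * c_t / n_t)       # the common rescaling cancels in c_t / n_t
        c_new.append(c_t)
        n_new.append(n_t)
        m_new.append(m_t)
    return h_new, c_new, n_new, m_new


def mlstm_zero_state(d_k: int, d_v: int):
    """(C, n, m): a d_v x d_k zero matrix memory, zero normalizer, zero stabilizer."""
    return [[0.0] * d_k for _ in range(d_v)], [0.0] * d_k, 0.0


def mlstm_step(q: Vector, k: Vector, v: Vector, i_pre: float, f_pre: float,
               o: Vector, state):
    """One mLSTM step with matrix memory (hypothetical sketch, not PyxLSTM's mlstm.py).

    q, k, v:      query/key/value for this step (k assumed pre-scaled by 1/sqrt(len(k)))
    i_pre, f_pre: scalar input/forget gate pre-activations (exponential gates)
    o:            output gate values, already passed through a sigmoid
    """
    C, n, m = state
    m_t = max(f_pre + m, i_pre)
    i_gate = math.exp(i_pre - m_t)
    f_gate = math.exp(f_pre + m - m_t)
    # Covariance update rule: C_t = f C_{t-1} + i v k^T and n_t = f n_{t-1} + i k.
    C_t = [[f_gate * C[r][s] + i_gate * v[r] * k[s] for s in range(len(k))]
           for r in range(len(v))]
    n_t = [f_gate * n_s + i_gate * k_s for n_s, k_s in zip(n, k)]
    # The paper divides C q by max(|n^T q|, 1); because the states here are
    # rescaled by exp(-m_t), the lower bound 1 becomes exp(-m_t).
    denom = max(abs(sum(a * b for a, b in zip(n_t, q))), math.exp(-m_t))
    h_t = [o_r * val / denom for o_r, val in zip(o, matvec(C_t, q))]
    return h_t, (C_t, n_t, m_t)


# Usage: one mLSTM step from a zero state.
h_m, _ = mlstm_step(q=[2.0, 0.0], k=[1.0, 0.0], v=[3.0, 4.0], i_pre=0.0, f_pre=0.0,
                    o=[1.0, 1.0], state=mlstm_zero_state(2, 2))
print(h_m)  # [3.0, 4.0]
```

Because both steps work through m, an input-gate pre-activation such as 1000.0 still produces finite outputs, where computing math.exp(1000.0) directly would raise OverflowError.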
To install PyxLSTM, you can use pip:
pip install PyxLSTM
For a development installation with testing dependencies (quoting the argument keeps shells such as zsh from interpreting the square brackets):
pip install "PyxLSTM[dev]"
Alternatively, you can clone the repository and install it manually:
git clone https://github.com/muditbhargava66/PyxLSTM.git
cd PyxLSTM
pip install -r requirements.txt
pip install -e .
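After either installation route, one quick way to confirm which PyxLSTM version the current Python environment sees is the standard library's importlib.metadata (Python 3.8+). The helper name installed_version is hypothetical; the distribution name PyxLSTM is the one used in the pip commands.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "PyxLSTM") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version()
print(f"PyxLSTM {found}" if found else "PyxLSTM is not installed in this environment")
```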
Here's a basic example of how to use PyxLSTM for language modeling:
import torch
from xLSTM.model import xLSTM
from xLSTM.data import LanguageModelingDataset, Tokenizer
from xLSTM.utils import load_config, set_seed, get_device
from xLSTM.training import train  # Assumes a train() function is defined in the training module
# Load configuration
config = load_config("path/to/config.yaml")
set_seed(config.seed)
device = get_device()
# Initialize tokenizer and dataset
tokenizer = Tokenizer(config.vocab_file)
train_dataset = LanguageModelingDataset(config.train_data, tokenizer, config.max_length)
# Create xLSTM model
model = xLSTM(len(tokenizer), config.embedding_size, config.hidden_size,
              config.num_layers, config.num_blocks, config.dropout,
              config.bidirectional, config.lstm_type)
model.to(device)
# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
train(model, train_dataset, optimizer, criterion, config, device)
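The example reads a fixed set of attributes from the object returned by load_config. As a sketch, assuming those attributes are all the model, tokenizer, and dataset need (train itself may read further fields), a types.SimpleNamespace can stand in for a config file during quick experiments. The values below are illustrative rather than library defaults, the accepted lstm_type strings are an assumption, and missing_config_fields is a hypothetical helper.

```python
from types import SimpleNamespace

# Hypothetical stand-in for load_config("path/to/config.yaml"); values are illustrative.
config = SimpleNamespace(
    seed=42,
    vocab_file="path/to/vocab_file",  # placeholder path
    train_data="path/to/train_data",  # placeholder path
    max_length=128,
    embedding_size=128,
    hidden_size=256,
    num_layers=2,
    num_blocks=2,
    dropout=0.1,
    bidirectional=False,  # listed in the TODO as not yet implemented
    lstm_type="slstm",    # assumed to select the sLSTM variant ("mlstm" for mLSTM)
    learning_rate=1e-3,
)

# Every attribute the usage example reads from config.
REQUIRED_FIELDS = (
    "seed", "vocab_file", "train_data", "max_length", "embedding_size",
    "hidden_size", "num_layers", "num_blocks", "dropout", "bidirectional",
    "lstm_type", "learning_rate",
)


def missing_config_fields(cfg) -> list:
    """Return the names from REQUIRED_FIELDS that cfg does not provide as attributes."""
    return [name for name in REQUIRED_FIELDS if not hasattr(cfg, name)]


print(missing_config_fields(config))
```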
For more detailed usage instructions and examples, please refer to the documentation.
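Language modeling trains the model to predict each next token. A minimal sketch of how one tokenized sequence can become an (input, target) pair, assuming right-padding with the tokenizer's pad_token_id so that CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) in the example skips the padded positions. make_lm_example is a hypothetical helper; LanguageModelingDataset may build its samples differently.

```python
from typing import List, Tuple


def make_lm_example(token_ids: List[int], max_length: int,
                    pad_token_id: int) -> Tuple[List[int], List[int]]:
    """Turn one tokenized sequence into an (input, target) pair for next-token prediction.

    The target is the input shifted left by one token, so position t is trained
    to predict token t + 1. Both lists are truncated to max_length and
    right-padded with pad_token_id, which CrossEntropyLoss(ignore_index=...) skips.
    """
    ids = token_ids[: max_length + 1]  # keep one extra token to form the last target
    inputs, targets = ids[:-1], ids[1:]
    padding = [pad_token_id] * (max_length - len(inputs))
    return inputs + padding, targets + padding


inputs, targets = make_lm_example([5, 6, 7, 8], max_length=5, pad_token_id=0)
print(inputs)   # [5, 6, 7, 0, 0]
print(targets)  # [6, 7, 8, 0, 0]
```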
PyxLSTM/
│
├── xLSTM/
│ ├── __init__.py
│ ├── slstm.py
│ ├── mlstm.py
│ ├── block.py
│ └── model.py
│
├── utils/
│ ├── config.py
│ ├── logging.py
│ └── utils.py
│
├── tests/
│ ├── test_slstm.py
│ ├── test_mlstm.py
│ ├── test_block.py
│ └── test_model.py
│
├── docs/
│ ├── slstm.md
│ ├── mlstm.md
│ └── training.md
│
├── examples/
│ ├── language_modeling.py
│ └── xLSTM_shape_verification.py
│
├── .gitignore
├── pyproject.toml
├── MANIFEST.in
├── requirements.txt
├── README.md
└── LICENSE
- xLSTM/: The main Python package containing the implementation.
  - slstm.py: Implementation of the sLSTM module.
  - mlstm.py: Implementation of the mLSTM module.
  - block.py: Implementation of the xLSTM blocks (pre and post up-projection).
  - model.py: High-level xLSTM model definition.
- utils/: Utility modules.
  - config.py: Configuration management.
  - logging.py: Logging setup.
  - utils.py: Miscellaneous utility functions.
- tests/: Unit tests for different modules.
  - test_slstm.py: Tests for the sLSTM module.
  - test_mlstm.py: Tests for the mLSTM module.
  - test_block.py: Tests for the xLSTM blocks.
  - test_model.py: Tests for the overall xLSTM model.
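- docs/: Documentation files.
  - README.md: Main documentation file.
  - slstm.md: Documentation for sLSTM.
  - mlstm.md: Documentation for mLSTM.
  - training.md: Training guide.
- examples/: Example scripts (language_modeling.py, xLSTM_shape_verification.py).
- .gitignore: Git ignore file that excludes unnecessary files and directories.
- pyproject.toml: Package build configuration and metadata.
- MANIFEST.in: Specifies additional files to include in source distributions.
- requirements.txt: List of required Python dependencies.
- README.md: Project README file.
- LICENSE: Project license file.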
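block.py provides the two block layouts named in the features list. They differ in where the up-projection sits relative to the recurrent cell. The following dependency-free sketch loosely follows the pre-LayerNorm residual blocks described in the xLSTM paper, with convolutions, gating, and learnable skip connections omitted. The function names and the callables passed in are hypothetical placeholders, not PyxLSTM's API.

```python
import math
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learnable scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a: Vector, b: Vector) -> Vector:
    return [p + q for p, q in zip(a, b)]


def post_up_projection_block(x: Vector, cell: Layer, up_mlp: Layer) -> Vector:
    """sLSTM-style: run the cell at model width, then a residual up/down-projecting MLP."""
    y = add(x, cell(layer_norm(x)))
    return add(y, up_mlp(layer_norm(y)))


def pre_up_projection_block(x: Vector, up: Layer, cell: Layer, down: Layer) -> Vector:
    """mLSTM-style: up-project first, run the cell in the wider space, project back down."""
    return add(x, down(cell(up(layer_norm(x)))))


# Toy layers: width 2 -> 4 -> 2 around an identity "cell".
def up(v: Vector) -> Vector:
    return v + v  # duplicate the vector: 2 -> 4


def down(v: Vector) -> Vector:
    return [(v[0] + v[2]) / 2, (v[1] + v[3]) / 2]  # average back: 4 -> 2


def identity(v: Vector) -> Vector:
    return v


print(pre_up_projection_block([1.0, -1.0], up, identity, down))
```

In both layouts the residual connection keeps the block's output at the model width, so blocks of either kind can be stacked freely.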
To run and test the PyxLSTM codebase, follow these steps:
- Clone the PyxLSTM repository:
  git clone https://github.com/muditbhargava66/PyxLSTM.git
- Navigate to the cloned directory:
  cd PyxLSTM
- Install the required dependencies:
  pip install -r requirements.txt
- Run the unit tests:
  python -m unittest discover tests
This command runs all the unit tests in the tests directory: test_slstm.py, test_mlstm.py, test_block.py, and test_model.py.
If you encounter any issues or have further questions, please refer to the PyxLSTM documentation or reach out to the maintainers for assistance.
The documentation for PyxLSTM can be found in the docs directory. It provides detailed information about the library's components, usage guidelines, and examples.
If you use PyxLSTM in your research or projects, please cite the original xLSTM paper:
@article{Beck2024xLSTM,
title={xLSTM: Extended Long Short-Term Memory},
author={Beck, Maximilian and Pöppel, Korbinian and Spanring, Markus and Auer, Andreas and Prudnikova, Oleksandra and Kopp, Michael and Klambauer, Günter and Brandstetter, Johannes and Hochreiter, Sepp},
journal={arXiv preprint arXiv:2405.04517},
year={2024}
}
Paper link: https://arxiv.org/abs/2405.04517
Contributions to PyxLSTM are welcome! If you find any issues or have suggestions for improvements, please open an issue or submit a pull request on the GitHub repository.
PyxLSTM is released under the MIT License. See the LICENSE file for more information.
We would like to acknowledge the original authors of the xLSTM architecture for their valuable research and contributions to the field of sequence modeling.
For any questions or inquiries, please contact the project maintainer:
- Name: Mudit Bhargava
- GitHub: @muditbhargava66
We hope you find PyxLSTM useful for your sequence modeling projects!
- Add support for Python 3.10
- Add support for macOS MPS
- Add support for Windows MPS
- Add support for Linux MPS
- Provide more examples on time series prediction
- Include reinforcement learning examples
- Add examples for modeling physical systems
- Enhance documentation with advanced usage scenarios
- Improve unit tests for new features
- Add support for bidirectional parameter as it's not implemented in the current xLSTM model