Merge pull request #145 from basf/efficiency
Include original mamba-version, MambAttention and QuantileRegression
AnFreTh authored Oct 24, 2024
2 parents ed5a0f3 + 10881eb commit 95e87dd
Showing 18 changed files with 1,215 additions and 363 deletions.
33 changes: 24 additions & 9 deletions README.md
@@ -53,15 +53,17 @@ Mambular is a Python package that brings the power of advanced deep learning architectures

# 🤖 Models

| Model            | Description                                                                                                                                                |
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `Mambular`       | A sequential model using Mamba blocks ([Gu and Dao](https://arxiv.org/pdf/2312.00752)), specifically designed for various tabular data tasks.              |
| `FTTransformer`  | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data.                          |
| `MLP`            | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks.                                                                            |
| `ResNet`         | An adaptation of the ResNet architecture for tabular data applications.                                                                                    |
| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities.        |
| `MambaTab`       | A tabular model using a Mamba block on a joint input representation, as described [here](https://arxiv.org/abs/2401.08867). Not a sequential model.        |
| `TabulaRNN`      | A recurrent neural network for tabular data. Not yet included in the benchmarks.                                                                           |
| `MambAttention`  | A combination of Mamba and Transformer blocks, similar to Jamba by [Lieber et al.](https://arxiv.org/abs/2403.19887). Not yet included in the benchmarks.  |

All models are available for `regression`, `classification`, and distributional regression, denoted by `LSS`.
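
Each task variant is exposed through an sklearn-style estimator class. The snippet below is a minimal sketch, assuming a `MambularClassifier` class under a `mambular.models` import path (both names are assumptions based on the model table above); the regression and LSS variants would follow the same pattern:

```python
# Minimal sketch, assuming sklearn-style estimators under mambular.models
# (class name and import path are assumptions based on the table above).
from sklearn.datasets import load_iris

from mambular.models import MambularClassifier  # assumed import path

X, y = load_iris(return_X_y=True, as_frame=True)

model = MambularClassifier()  # default hyperparameters
model.fit(X, y)               # preprocessing is handled internally
print(model.predict(X)[:5])   # predicted class labels
```
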
@@ -135,6 +137,19 @@ Install Mambular using pip:
```sh
pip install mambular
```

If you want to use the original Mamba and Mamba2 implementations, additionally install `mamba-ssm` via:

```sh
pip install mamba-ssm
```

Be careful to use matching `torch` and CUDA versions, for example:

```sh
pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html
pip install mamba-ssm
```
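
As a quick sanity check before installing `mamba-ssm`, you can confirm that the installed `torch` build actually ships with the CUDA version you targeted. A minimal sketch (the version strings in the comments are examples for the cu118 build above):

```python
# Verify the torch build matches the CUDA toolkit you installed against.
import torch

print(torch.__version__)          # e.g. "2.0.0+cu118"
print(torch.version.cuda)         # e.g. "11.8"
print(torch.cuda.is_available())  # should be True on a CUDA-enabled machine
```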

# 🚀 Usage

<h2> Preprocessing </h2>
