diff --git a/README.md b/README.md
index cbe205e..63ee098 100644
--- a/README.md
+++ b/README.md
@@ -16,21 +16,21 @@
-
Mambular: Tabular Deep Learning (with Mamba)
+ Mambular: Tabular Deep Learning
-Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP inspired tabular models.
+Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM, and tabular ResNets. Check out our paper `Mambular: A Sequential Model for Tabular Deep Learning`, available [here](https://arxiv.org/abs/2408.06291). Also check out our paper introducing [TabulaRNN](https://arxiv.org/pdf/2411.17207) and analyzing the efficiency of NLP-inspired tabular models.
Table of Contents
- [🏃 Quickstart](#-quickstart)
- [📖 Introduction](#-introduction)
- [🤖 Models](#-models)
-- [🏆 Results](#-results)
- [📚 Documentation](#-documentation)
- [🛠️ Installation](#️-installation)
- [🚀 Usage](#-usage)
- [💻 Implement Your Own Model](#-implement-your-own-model)
+- [Custom Training](#custom-training)
- [🏷️ Citation](#️-citation)
- [License](#license)
@@ -53,18 +53,18 @@ Mambular is a Python package that brings the power of advanced deep learning arc
# 🤖 Models
-| Model | Description |
-| ---------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
-| `TabM` | Batch Ensembling for a MLP as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210) |
-| `NODE` | Neural Oblivious Decision Ensembles as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312) |
-| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
-| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
-| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
-| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
-| `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. |
-| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks introduced [here](https://arxiv.org/pdf/2411.17207). |
-| `MambAttention` | A combination between Mamba and Transformers, similar to Jamba by [Lieber et al.](https://arxiv.org/abs/2403.19887). Not yet included in the benchmarks |
+| Model | Description |
+| ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Mambular`       | A sequential model using Mamba blocks, specifically designed for various tabular data tasks, introduced [here](https://arxiv.org/abs/2408.06291).    |
+| `TabM`           | Batch ensembling for an MLP, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2410.24210).                                                  |
+| `NODE`           | Neural Oblivious Decision Ensembles, as introduced by [Popov et al.](https://arxiv.org/abs/1909.06312).                                              |
+| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
+| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
+| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
+| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
+| `MambaTab`       | A tabular model using a Mamba block on a joint input representation, described [here](https://arxiv.org/abs/2401.08867). Not a sequential model.     |
+| `TabulaRNN`      | A recurrent neural network for tabular data, introduced [here](https://arxiv.org/pdf/2411.17207).                                                    |
+| `MambAttention`  | A combination of Mamba and Transformers, also introduced [here](https://arxiv.org/pdf/2411.17207).                                                   |
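+
+Each of these models is exposed through sklearn-style wrapper classes for regression, classification, and distributional (LSS) tasks. A minimal sketch, assuming the ``MambularClassifier`` wrapper from ``mambular.models`` (the hyperparameter values are illustrative, not tuned defaults):
+
+```python
+from mambular.models import MambularClassifier
+
+# Illustrative hyperparameters; see mambular.configs for the defaults.
+model = MambularClassifier(d_model=64, n_layers=4)
+model.fit(X_train, y_train, max_epochs=50)
+preds = model.predict(X_test)
+```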
@@ -145,6 +145,59 @@ preds = model.predict(X)
preds = model.predict_proba(X)
```
+## Hyperparameter Optimization
+Since all of the models are sklearn-style base estimators, you can use the built-in hyperparameter optimization from scikit-learn.
+
+```python
+from scipy.stats import randint, uniform
+from sklearn.model_selection import RandomizedSearchCV
+
+param_dist = {
+ 'd_model': randint(32, 128),
+ 'n_layers': randint(2, 10),
+ 'lr': uniform(1e-5, 1e-3)
+}
+
+random_search = RandomizedSearchCV(
+ estimator=model,
+ param_distributions=param_dist,
+ n_iter=50, # Number of parameter settings sampled
+ cv=5, # 5-fold cross-validation
+ scoring='accuracy', # Metric to optimize
+ random_state=42
+)
+
+fit_params = {"max_epochs": 5, "rebuild": False}
+
+# Fit the model
+random_search.fit(X, y, **fit_params)
+
+# Best parameters and score
+print("Best Parameters:", random_search.best_params_)
+print("Best Score:", random_search.best_score_)
+```
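+After the search finishes, the tuned estimator is available as ``random_search.best_estimator_``.
+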
+Note that with this approach you can also optimize the preprocessing: just use the prefix ``prepro__`` when specifying the preprocessor arguments you want to optimize:
+```python
+param_dist = {
+ 'd_model': randint(32, 128),
+ 'n_layers': randint(2, 10),
+ 'lr': uniform(1e-5, 1e-3),
+ "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"]
+}
+```
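+The available preprocessor arguments correspond to the keys returned by ``Preprocessor.get_params()``.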
+
+
+Since early stopping is integrated and the best model with respect to the validation loss is returned, setting ``max_epochs`` to a large number is sensible.
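+
+For example, a sketch reusing the ``fit_params`` from above with a larger budget (the value is illustrative):
+
+```python
+# With early stopping, a generous epoch budget is safe.
+fit_params = {"max_epochs": 200, "rebuild": False}
+random_search.fit(X, y, **fit_params)
+```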
+
+
+Alternatively, use the built-in Bayesian hyperparameter optimization by simply running:
+
+```python
+best_params = model.optimize_hparams(X, y)
+```
+
+This automatically sets the search space based on the default config from ``mambular.configs``. See the documentation for all parameters of ``optimize_hparams()``. Note, however, that the preprocessor arguments are fixed and cannot be optimized this way.
+
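+A minimal sketch of applying the tuned values afterwards, assuming the returned dict maps parameter names to values as ``set_params`` expects:
+
+```python
+model.set_params(**best_params)
+model.fit(X, y, max_epochs=200)  # max_epochs value is illustrative
+```
+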
# ⚖️ Distributional Regression with MambularLSS
diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py
index dc759bb..8ff0aa7 100644
--- a/mambular/models/sklearn_base_classifier.py
+++ b/mambular/models/sklearn_base_classifier.py
@@ -87,7 +87,7 @@ def get_params(self, deep=True):
if deep:
preprocessor_params = {
- "preprocessor__" + key: value
+ "prepro__" + key: value
for key, value in self.preprocessor.get_params().items()
}
params.update(preprocessor_params)
@@ -109,12 +109,12 @@ def set_params(self, **parameters):
Estimator instance.
"""
config_params = {
- k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ k: v for k, v in parameters.items() if not k.startswith("prepro__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
- if k.startswith("preprocessor__")
+ if k.startswith("prepro__")
}
if config_params:
diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py
index d97eab6..8178045 100644
--- a/mambular/models/sklearn_base_lss.py
+++ b/mambular/models/sklearn_base_lss.py
@@ -109,7 +109,7 @@ def get_params(self, deep=True):
if deep:
preprocessor_params = {
- "preprocessor__" + key: value
+ "prepro__" + key: value
for key, value in self.preprocessor.get_params().items()
}
params.update(preprocessor_params)
@@ -131,12 +131,12 @@ def set_params(self, **parameters):
Estimator instance.
"""
config_params = {
- k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ k: v for k, v in parameters.items() if not k.startswith("prepro__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
- if k.startswith("preprocessor__")
+ if k.startswith("prepro__")
}
if config_params:
diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py
index b77d11b..61fa01c 100644
--- a/mambular/models/sklearn_base_regressor.py
+++ b/mambular/models/sklearn_base_regressor.py
@@ -88,7 +88,7 @@ def get_params(self, deep=True):
if deep:
preprocessor_params = {
- "preprocessor__" + key: value
+ "prepro__" + key: value
for key, value in self.preprocessor.get_params().items()
}
params.update(preprocessor_params)
@@ -110,12 +110,12 @@ def set_params(self, **parameters):
Estimator instance.
"""
config_params = {
- k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ k: v for k, v in parameters.items() if not k.startswith("prepro__")
}
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
- if k.startswith("preprocessor__")
+ if k.startswith("prepro__")
}
if config_params:
diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py
index 3d7e83e..095da4b 100644
--- a/mambular/preprocessing/preprocessor.py
+++ b/mambular/preprocessing/preprocessor.py
@@ -131,7 +131,48 @@ def __init__(
self.degree = degree
self.n_knots = knots
+ def get_params(self, deep=True):
+ """
+ Get parameters for the preprocessor.
+
+ Parameters
+ ----------
+ deep : bool, default=True
+ If True, will return parameters of subobjects that are estimators.
+
+ Returns
+ -------
+ params : dict
+ Parameter names mapped to their values.
+ """
+ params = {
+ "n_bins": self.n_bins,
+ "numerical_preprocessing": self.numerical_preprocessing,
+ "categorical_preprocessing": self.categorical_preprocessing,
+ "use_decision_tree_bins": self.use_decision_tree_bins,
+ "binning_strategy": self.binning_strategy,
+ "task": self.task,
+ "cat_cutoff": self.cat_cutoff,
+ "treat_all_integers_as_numerical": self.treat_all_integers_as_numerical,
+ "degree": self.degree,
+ "knots": self.n_knots,
+ }
+ return params
+
def set_params(self, **params):
+ """
+ Set parameters for the preprocessor.
+
+ Parameters
+ ----------
+ **params : dict
+ Parameter names mapped to their new values.
+
+ Returns
+ -------
+ self : object
+ Preprocessor instance.
+ """
for key, value in params.items():
setattr(self, key, value)
return self
@@ -222,9 +263,11 @@ def fit(self, X, y=None):
(
"discretizer",
KBinsDiscretizer(
- n_bins=bins
- if isinstance(bins, int)
- else len(bins) - 1,
+ n_bins=(
+ bins
+ if isinstance(bins, int)
+ else len(bins) - 1
+ ),
encode="ordinal",
strategy=self.binning_strategy,
subsample=200_000 if len(X) > 200_000 else None,