Minor fix #99

Merged (7 commits) on Aug 5, 2024
34 changes: 26 additions & 8 deletions mambular/arch_utils/embedding_layer.py
@@ -12,6 +12,7 @@ def __init__(
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
cat_encoding="int",
):
"""
Embedding layer that handles numerical and categorical embeddings.
@@ -56,15 +57,23 @@ def __init__(
]
)

self.cat_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
self.cat_embeddings = nn.ModuleList()
for feature_name, num_categories in cat_feature_info.items():
if cat_encoding == "int":
self.cat_embeddings.append(
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
)
elif cat_encoding == "one-hot":
self.cat_embeddings.append(
nn.Sequential(
OneHotEncoding(num_categories),
nn.Linear(num_categories, d_model, bias=False),
self.embedding_activation,
)
)
for feature_name, num_categories in cat_feature_info.items()
]
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -143,3 +152,12 @@ def forward(self, num_features=None, cat_features=None):
)

return x


class OneHotEncoding(nn.Module):
def __init__(self, num_categories):
super(OneHotEncoding, self).__init__()
self.num_categories = num_categories

def forward(self, x):
return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
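
Note on the hunk above: with cat_encoding="int" each categorical feature gets a learned nn.Embedding table (with one extra row for unseen categories), while cat_encoding="one-hot" one-hot encodes the indices and projects them to d_model through a bias-free linear layer. The following is a minimal sketch of the two paths, not part of the PR; the shapes are illustrative and nn.ReLU stands in for the configured embedding_activation.

```python
# Minimal sketch of the two categorical encoding paths (illustrative only).
import torch
import torch.nn as nn


class OneHotEncoding(nn.Module):
    """One-hot encodes integer category indices (same module as added above)."""

    def __init__(self, num_categories):
        super().__init__()
        self.num_categories = num_categories

    def forward(self, x):
        return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()


d_model, num_categories = 8, 5
x = torch.tensor([0, 3, 4])  # one integer-encoded categorical feature, shape (batch,)

# cat_encoding="int": learned lookup table with an extra row for unknown categories.
int_path = nn.Sequential(nn.Embedding(num_categories + 1, d_model), nn.ReLU())
print(int_path(x).shape)  # torch.Size([3, 8])

# cat_encoding="one-hot": one-hot vectors projected to d_model by a bias-free linear layer.
onehot_path = nn.Sequential(
    OneHotEncoding(num_categories),
    nn.Linear(num_categories, d_model, bias=False),
    nn.ReLU(),
)
print(onehot_path(x).shape)  # torch.Size([3, 8])
```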
5 changes: 4 additions & 1 deletion mambular/base_models/ft_transformer.py
@@ -132,9 +132,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=True,
cls_position=0,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
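
The layer_norm_after_embedding change above matters because dict-style .get without a second argument returns None when the key is missing, so the embedding layer would receive None instead of the configured default. A minimal sketch of the fallback pattern, with a plain dict standing in for Lightning's hparams and an ad-hoc Config class standing in for the real config object:

```python
# Minimal sketch of the .get fallback (plain dict stands in for self.hparams).
class Config:
    layer_norm_after_embedding = False


config = Config()
hparams = {}  # the user did not override the setting

print(hparams.get("layer_norm_after_embedding"))  # None -> silently wrong before the fix
print(hparams.get("layer_norm_after_embedding", config.layer_norm_after_embedding))  # False
```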
42 changes: 6 additions & 36 deletions mambular/base_models/lightning_wrapper.py
@@ -82,7 +82,7 @@ def __init__(
else:
output_dim = num_classes

self.model = model_class(
self.base_model = model_class(
config=config,
num_feature_info=num_feature_info,
cat_feature_info=cat_feature_info,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
Model output.
"""

return self.model.forward(num_features, cat_features)
return self.base_model.forward(num_features, cat_features)

def compute_loss(self, predictions, y_true):
"""
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(loss)
self.log(
"train_rmse",
rmse,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
self.log(
"val_loss",
val_loss,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -218,17 +208,7 @@ def validation_step(self, batch, batch_idx):
self.log(
"val_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(val_loss)
self.log(
"val_rmse",
rmse,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
self.log(
"test_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(test_loss)
self.log(
"test_rmse",
rmse,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
self.model.parameters(),
self.base_model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
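
Two things change in this wrapper: the wrapped network is stored as self.base_model (so forward and the optimizer reference it consistently), and validation/test metrics are now logged with on_step=False, so Lightning reports only per-epoch aggregates; the per-step RMSE logging is dropped. A minimal standalone sketch of that logging pattern follows; the toy model and the import path are assumptions, not the library's own code.

```python
# Minimal sketch of per-epoch logging in a LightningModule (toy module, not mambular's wrapper).
import pytorch_lightning as pl  # or `import lightning.pytorch as pl`, depending on the install
import torch
import torch.nn as nn


class TinyWrapper(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(4, 1)  # mirrors the self.model -> self.base_model rename
        self.loss_fct = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fct(self.base_model(x), y)
        # Training loss is still logged per step and per epoch.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = self.loss_fct(self.base_model(x), y)
        # on_step=False: only the epoch-level average is reported for validation.
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.base_model.parameters(), lr=1e-3)
```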
9 changes: 6 additions & 3 deletions mambular/base_models/mambular.py
@@ -150,9 +150,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
use_cls=True,
cls_position=0,
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=False,
cls_position=-1,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
5 changes: 4 additions & 1 deletion mambular/base_models/tabtransformer.py
@@ -139,9 +139,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=True,
cls_position=0,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
4 changes: 3 additions & 1 deletion mambular/configs/fttransformer_config.py
@@ -58,6 +58,8 @@ class DefaultFTTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
@@ -84,4 +86,4 @@ class DefaultFTTransformerConfig:
transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 256
numerical_embedding: str = "ple"
cat_encoding: str = "int"
7 changes: 5 additions & 2 deletions mambular/configs/mambular_config.py
@@ -76,9 +76,11 @@ class DefaultMambularConfig:
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
whether weight decay is also applied to A-D matrices
whether weight decay is also applied to A-D matrices.
BC_layer_norm: bool, default=True
whether to apply layer normalization to B-C matrices
whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
@@ -116,3 +118,4 @@ class DefaultMambularConfig:
layer_norm_eps: float = 1e-05
AD_weight_decay: bool = False
BC_layer_norm: bool = True
cat_encoding: str = "int"
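
Assuming these config classes are dataclasses (the field defaults shown above suggest so), the new option can be overridden at construction time; a hedged sketch:

```python
# Hedged sketch: override the new field when building a config.
# Import path taken from the file shown above; dataclass-style construction is an assumption.
from mambular.configs.mambular_config import DefaultMambularConfig

config = DefaultMambularConfig(cat_encoding="one-hot")  # the default stays "int"
print(config.cat_encoding)  # -> "one-hot"
```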
3 changes: 3 additions & 0 deletions mambular/configs/tabtransformer_config.py
@@ -58,6 +58,8 @@ class DefaultTabTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
Expand All @@ -84,3 +86,4 @@ class DefaultTabTransformerConfig:
transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 512
cat_encoding: str = "int"
6 changes: 6 additions & 0 deletions mambular/models/fttransformer.py
@@ -64,6 +64,8 @@ class FTTransformerRegressor(SklearnBaseRegressor):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -171,6 +173,8 @@ class FTTransformerClassifier(SklearnBaseClassifier):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -278,6 +282,8 @@ class FTTransformerLSS(SklearnBaseLSS):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
24 changes: 24 additions & 0 deletions mambular/models/mambular.py
@@ -79,6 +79,14 @@ class MambularRegressor(SklearnBaseRegressor):
Whether to append a cls to the end of each 'sequence'.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -198,6 +206,14 @@ class MambularClassifier(SklearnBaseClassifier):
Whether to use learnable feature interactions before passing through mamba blocks.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -320,6 +336,14 @@ class MambularLSS(SklearnBaseLSS):
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
numerical_preprocessing : str, default="ple"
The preprocessing strategy for numerical features. Valid options are
'binning', 'one_hot', 'standardization', and 'normalization'.
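
For end users of the sklearn-style wrappers documented above, cat_encoding should be passable like any other keyword argument. A usage sketch with synthetic data; the column names, import path, and the fit/predict signature beyond X and y are assumptions, not confirmed by this diff.

```python
# Usage sketch with synthetic data (column names and import path are assumptions).
import numpy as np
import pandas as pd
from mambular.models import MambularRegressor

X = pd.DataFrame({
    "num_feature": np.random.rand(200),
    "cat_feature": np.random.choice(["a", "b", "c"], size=200),
})
y = np.random.rand(200)

model = MambularRegressor(cat_encoding="one-hot")  # "int" remains the default
model.fit(X, y)
preds = model.predict(X)
```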