Merge pull request #99 from basf/minor_fix
Minor fix
AnFreTh authored Aug 5, 2024
2 parents 8933cf7 + 71f35e6 commit e6b90dc
Showing 14 changed files with 133 additions and 90 deletions.
34 changes: 26 additions & 8 deletions mambular/arch_utils/embedding_layer.py
@@ -12,6 +12,7 @@ def __init__(
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
cat_encoding="int",
):
"""
Embedding layer that handles numerical and categorical embeddings.
@@ -56,15 +57,23 @@ def __init__(
]
)

self.cat_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
self.cat_embeddings = nn.ModuleList()
for feature_name, num_categories in cat_feature_info.items():
if cat_encoding == "int":
self.cat_embeddings.append(
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
)
elif cat_encoding == "one-hot":
self.cat_embeddings.append(
nn.Sequential(
OneHotEncoding(num_categories),
nn.Linear(num_categories, d_model, bias=False),
self.embedding_activation,
)
)
for feature_name, num_categories in cat_feature_info.items()
]
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -143,3 +152,12 @@ def forward(self, num_features=None, cat_features=None):
)

return x


class OneHotEncoding(nn.Module):
def __init__(self, num_categories):
super(OneHotEncoding, self).__init__()
self.num_categories = num_categories

def forward(self, x):
return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
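For readers skimming the diff: a minimal sketch of the two categorical-encoding paths the new cat_encoding argument selects between. The cardinality, d_model, and the ReLU standing in for embedding_activation are illustrative assumptions, not values from the repository.

```python
import torch
import torch.nn as nn

d_model = 16
num_categories = 5                      # assumed cardinality of one categorical feature
x = torch.tensor([0, 3, 4])             # a batch of integer-encoded category indices

# cat_encoding="int": a learnable embedding table (one spare row for unseen categories)
int_path = nn.Sequential(nn.Embedding(num_categories + 1, d_model), nn.ReLU())

# cat_encoding="one-hot": fixed one-hot vectors projected into d_model by a bias-free linear layer
one_hot = nn.functional.one_hot(x, num_classes=num_categories).float()
one_hot_path = nn.Linear(num_categories, d_model, bias=False)

print(int_path(x).shape)                # torch.Size([3, 16])
print(one_hot_path(one_hot).shape)      # torch.Size([3, 16])
```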
5 changes: 4 additions & 1 deletion mambular/base_models/ft_transformer.py
@@ -132,9 +132,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=True,
cls_position=0,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
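The extra default argument now passed to hparams.get above is the substance of this fix: without it, a missing key silently yields None instead of the config value. A tiny illustration with a plain dict (Lightning's hparams behaves the same way here; the values are made up):

```python
config_default = True                     # stands in for config.layer_norm_after_embedding
hparams = {}                              # the key was never set explicitly

print(hparams.get("layer_norm_after_embedding"))                  # None -> layer norm silently off
print(hparams.get("layer_norm_after_embedding", config_default))  # True -> falls back to the config value
```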
42 changes: 6 additions & 36 deletions mambular/base_models/lightning_wrapper.py
@@ -82,7 +82,7 @@ def __init__(
else:
output_dim = num_classes

self.model = model_class(
self.base_model = model_class(
config=config,
num_feature_info=num_feature_info,
cat_feature_info=cat_feature_info,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
Model output.
"""

return self.model.forward(num_features, cat_features)
return self.base_model.forward(num_features, cat_features)

def compute_loss(self, predictions, y_true):
"""
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(loss)
self.log(
"train_rmse",
rmse,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
self.log(
"val_loss",
val_loss,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -218,17 +208,7 @@ def test_step(self, batch, batch_idx):
self.log(
"val_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(val_loss)
self.log(
"val_rmse",
rmse,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
self.log(
"test_acc",
acc,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
)
elif isinstance(self.loss_fct, nn.MSELoss):
rmse = torch.sqrt(test_loss)
self.log(
"test_rmse",
rmse,
on_step=True,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
self.model.parameters(),
self.base_model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
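Summarising the wrapper changes above: self.model becomes self.base_model, the per-step RMSE logging is dropped, and validation/test metrics are now logged per epoch (on_step=False, on_epoch=True) rather than per step. A stripped-down sketch of that logging pattern, with a placeholder linear model instead of the real model_class; this is not the repository's actual wrapper:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TinyTaskModule(pl.LightningModule):
    """Toy module mirroring the renamed self.base_model attribute and epoch-level logging."""

    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(4, 1)   # placeholder for model_class(config=..., ...)
        self.loss_fct = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fct(self.base_model(x), y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = self.loss_fct(self.base_model(x), y)
        # on_step=False / on_epoch=True: one aggregated value per validation epoch
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.base_model.parameters(), lr=1e-3)
```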
9 changes: 6 additions & 3 deletions mambular/base_models/mambular.py
@@ -150,9 +150,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
use_cls=True,
cls_position=0,
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=False,
cls_position=-1,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
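For context on the use_cls / cls_position arguments that the Mambular base model now sets to False / -1: the embedding layer can attach a learnable [CLS] token at either end of the embedded feature sequence. A generic sketch of the two placements, with illustrative shapes that are not taken from the code:

```python
import torch

batch, n_features, d_model = 2, 5, 8
embeddings = torch.randn(batch, n_features, d_model)
cls_token = torch.zeros(1, 1, d_model).expand(batch, -1, -1)

prepended = torch.cat([cls_token, embeddings], dim=1)  # cls_position=0: token first
appended = torch.cat([embeddings, cls_token], dim=1)   # cls_position=-1: token last
print(prepended.shape, appended.shape)                 # both: torch.Size([2, 6, 8])
```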
5 changes: 4 additions & 1 deletion mambular/base_models/tabtransformer.py
@@ -139,9 +139,12 @@ def __init__(
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=True,
cls_position=0,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)
4 changes: 3 additions & 1 deletion mambular/configs/fttransformer_config.py
@@ -58,6 +58,8 @@ class DefaultFTTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
@@ -84,4 +86,4 @@ class DefaultFTTransformerConfig:
transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 256
numerical_embedding: str = "ple"
cat_encoding: str = "int"
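A short usage sketch of the new field, assuming the default configs are dataclasses (their annotated fields with default values suggest so) and that the import path mirrors the file path shown in this diff:

```python
from mambular.configs.fttransformer_config import DefaultFTTransformerConfig

# Assumed dataclass-style construction: override only the field of interest
cfg = DefaultFTTransformerConfig(cat_encoding="one-hot")
print(cfg.cat_encoding)   # "one-hot"
```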
7 changes: 5 additions & 2 deletions mambular/configs/mambular_config.py
@@ -76,9 +76,11 @@ class DefaultMambularConfig:
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
whether weight decay is also applied to A-D matrices
whether weight decay is also applied to A-D matrices.
BC_layer_norm: bool, default=True
whether to apply layer normalization to B-C matrices
whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
@@ -116,3 +118,4 @@ class DefaultMambularConfig:
layer_norm_eps: float = 1e-05
AD_weight_decay: bool = False
BC_layer_norm: bool = True
cat_encoding: str = "int"
3 changes: 3 additions & 0 deletions mambular/configs/tabtransformer_config.py
@@ -58,6 +58,8 @@ class DefaultTabTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
"""

lr: float = 1e-04
Expand All @@ -84,3 +86,4 @@ class DefaultTabTransformerConfig:
transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 512
cat_encoding: str = "int"
6 changes: 6 additions & 0 deletions mambular/models/fttransformer.py
@@ -64,6 +64,8 @@ class FTTransformerRegressor(SklearnBaseRegressor):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -171,6 +173,8 @@ class FTTransformerClassifier(SklearnBaseClassifier):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -278,6 +282,8 @@ class FTTransformerLSS(SklearnBaseLSS):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
24 changes: 24 additions & 0 deletions mambular/models/mambular.py
@@ -79,6 +79,14 @@ class MambularRegressor(SklearnBaseRegressor):
Whether to append a cls to the end of each 'sequence'.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -198,6 +206,14 @@ class MambularClassifier(SklearnBaseClassifier):
Whether to use learnable feature interactions before passing through mamba blocks.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -320,6 +336,14 @@ class MambularLSS(SklearnBaseLSS):
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AD_weight_decay : bool, default=False
Whether weight decay is also applied to A-D matrices.
BC_layer_norm : bool, default=True
Whether to apply layer normalization to B-C matrices.
cat_encoding : str, default="int"
Whether to use integer encoding or one-hot encoding for categorical features.
numerical_preprocessing : str, default="ple"
The preprocessing strategy for numerical features. Valid options are
'binning', 'one_hot', 'standardization', and 'normalization'.
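Finally, a hedged end-to-end sketch of how the documented cat_encoding option might be used through the sklearn-style wrapper. The import path, the fit signature (max_epochs), and the forwarding of keyword arguments into the default config are assumptions based on the docstrings above; the data is synthetic.

```python
import numpy as np
import pandas as pd
from mambular.models import MambularRegressor

# Synthetic toy data: one numerical and one categorical column (illustrative only)
X = pd.DataFrame({
    "num": np.random.randn(200),
    "cat": np.random.choice(["a", "b", "c"], size=200),
})
y = np.random.randn(200)

# Assumes the wrapper forwards keyword arguments such as cat_encoding into the config
model = MambularRegressor(cat_encoding="one-hot")
model.fit(X, y, max_epochs=5)
predictions = model.predict(X)
```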
