include tabularRNN #102

Merged (1 commit) on Aug 12, 2024
153 changes: 153 additions & 0 deletions mambular/base_models/tabularnn.py
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
from ..arch_utils.mlp_utils import MLP
from ..configs.tabularnn_config import DefaultTabulaRNNConfig
from .basemodel import BaseModel
from ..arch_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.normalization_layers import (
RMSNorm,
LayerNorm,
LearnableLayerScaling,
BatchNorm,
InstanceNorm,
GroupNorm,
)


class TabulaRNN(BaseModel):
def __init__(
self,
cat_feature_info,
num_feature_info,
num_classes=1,
config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(),
**kwargs,
):
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info

norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(
1, self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
else:
self.norm_f = None

        rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type]
        rnn_kwargs = dict(
            input_size=self.hparams.get("d_model", config.d_model),
            hidden_size=self.hparams.get("dim_feedforward", config.dim_feedforward),
            num_layers=self.hparams.get("n_layers", config.n_layers),
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            batch_first=True,
            dropout=self.hparams.get("rnn_dropout", config.rnn_dropout),
            bias=self.hparams.get("bias", config.bias),
        )
        # Only plain nn.RNN accepts a `nonlinearity` argument; LSTM and GRU have fixed gate activations.
        if config.model_type == "RNN":
            rnn_kwargs["nonlinearity"] = self.hparams.get(
                "rnn_activation", config.rnn_activation
            )
        self.rnn = rnn_layer(**rnn_kwargs)

self.embedding_layer = EmbeddingLayer(
num_feature_info=num_feature_info,
cat_feature_info=cat_feature_info,
d_model=self.hparams.get("d_model", config.d_model),
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=False,
cls_position=-1,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)

self.tabular_head = MLP(
self.hparams.get("dim_feedforward", config.dim_feedforward),
hidden_units_list=self.hparams.get(
"head_layer_sizes", config.head_layer_sizes
),
dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
use_skip_layers=self.hparams.get(
"head_skip_layers", config.head_skip_layers
),
activation_fn=head_activation,
use_batch_norm=self.hparams.get(
"head_use_batch_norm", config.head_use_batch_norm
),
n_output_units=num_classes,
)

        self.linear = nn.Linear(
            self.hparams.get("d_model", config.d_model),
            self.hparams.get("dim_feedforward", config.dim_feedforward),
        )

def forward(self, num_features, cat_features):
"""
Defines the forward pass of the model.

Parameters
----------
num_features : Tensor
Tensor containing the numerical features.
cat_features : Tensor
Tensor containing the categorical features.

Returns
-------
Tensor
The output predictions of the model.
"""

x = self.embedding_layer(num_features, cat_features)
# RNN forward pass
out, _ = self.rnn(x)
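        # Skip path: project the mean-pooled embeddings up to the RNN hidden size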
z = self.linear(torch.mean(x, dim=1))

if self.pooling_method == "avg":
x = torch.mean(out, dim=1)
elif self.pooling_method == "max":
x, _ = torch.max(out, dim=1)
elif self.pooling_method == "sum":
x = torch.sum(out, dim=1)
elif self.pooling_method == "last":
            x = out[:, -1, :]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
x = x + z
if self.norm_f is not None:
x = self.norm_f(x)
preds = self.tabular_head(x)

return preds
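In the forward pass, the RNN states are pooled over the sequence dimension and a linear projection of the mean-pooled embeddings is added as a skip connection before normalization and the MLP head. A minimal sketch of that pooling-plus-residual step with plain tensors (batch and sequence sizes are illustrative; d_model and dim_feedforward follow the config defaults):

```python
import torch
import torch.nn as nn

# Illustrative sizes; d_model=128 and dim_feedforward=256 match the config defaults.
batch, seq_len, d_model, dim_feedforward = 32, 10, 128, 256

x = torch.randn(batch, seq_len, d_model)   # embeddings as produced by EmbeddingLayer
rnn = nn.GRU(d_model, dim_feedforward, num_layers=4, batch_first=True, dropout=0.2)
linear = nn.Linear(d_model, dim_feedforward)

out, _ = rnn(x)                 # (batch, seq_len, dim_feedforward)
z = linear(x.mean(dim=1))       # skip path: mean-pooled embeddings projected up
pooled = out.mean(dim=1)        # "avg" pooling over the sequence dimension
h = pooled + z                  # representation passed to norm_f and the MLP head
print(h.shape)                  # torch.Size([32, 256])
```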
83 changes: 83 additions & 0 deletions mambular/configs/tabularnn_config.py
@@ -0,0 +1,83 @@
from dataclasses import dataclass
import torch.nn as nn


@dataclass
class DefaultTabulaRNNConfig:
"""
Configuration class for the default TabulaRNN model with predefined hyperparameters.

Parameters
----------
    lr : float, default=1e-04
        Learning rate for the optimizer.
    model_type : str, default="RNN"
        Type of recurrent cell to use; one of "RNN", "LSTM", or "GRU".
    lr_patience : int, default=10
        Number of epochs with no improvement after which the learning rate will be reduced.
    weight_decay : float, default=1e-06
        Weight decay (L2 penalty) for the optimizer.
    lr_factor : float, default=0.1
        Factor by which the learning rate will be reduced.
    d_model : int, default=128
        Dimensionality of the embeddings fed into the RNN.
    n_layers : int, default=4
        Number of stacked RNN layers.
    rnn_dropout : float, default=0.2
        Dropout rate applied between stacked RNN layers.
    norm : str, default="RMSNorm"
        Normalization method to be used.
    activation : callable, default=nn.SELU()
        Activation function for the model.
    embedding_activation : callable, default=nn.Identity()
        Activation function for numerical embeddings.
    head_layer_sizes : list, default=()
        Sizes of the layers in the head of the model.
    head_dropout : float, default=0.5
        Dropout rate for the head layers.
    head_skip_layers : bool, default=False
        Whether to use skip connections in the head layers.
    head_activation : callable, default=nn.SELU()
        Activation function for the head layers.
    head_use_batch_norm : bool, default=False
        Whether to use batch normalization in the head layers.
    layer_norm_after_embedding : bool, default=False
        Whether to apply layer normalization after embedding.
    pooling_method : str, default="avg"
        Pooling method applied to the RNN outputs ('avg', 'max', 'sum', or 'last').
    norm_first : bool, default=False
        Whether to apply normalization before other operations in each block.
    bias : bool, default=True
        Whether to use bias in the linear layers.
    rnn_activation : str, default="relu"
        Nonlinearity of the plain RNN cells; only used when model_type="RNN".
    layer_norm_eps : float, default=1e-05
        Epsilon used by the normalization layers.
    dim_feedforward : int, default=256
        Hidden size of the RNN and input size of the prediction head.
    numerical_embedding : str, default="ple"
        Embedding method for numerical features.
    bidirectional : bool, default=False
        Whether to process the sequence bidirectionally.
    cat_encoding : str, default="int"
        Encoding method for categorical features.
"""

lr: float = 1e-04
model_type: str = "RNN"
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
d_model: int = 128
n_layers: int = 4
rnn_dropout: float = 0.2
norm: str = "RMSNorm"
activation: callable = nn.SELU()
embedding_activation: callable = nn.Identity()
head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
head_use_batch_norm: bool = False
layer_norm_after_embedding: bool = False
pooling_method: str = "avg"
norm_first: bool = False
bias: bool = True
rnn_activation: str = "relu"
layer_norm_eps: float = 1e-05
dim_feedforward: int = 256
numerical_embedding: str = "ple"
bidirectional: bool = False
cat_encoding: str = "int"
4 changes: 4 additions & 0 deletions mambular/models/__init__.py
@@ -16,6 +16,7 @@
)

from .mambatab import MambaTabClassifier, MambaTabRegressor, MambaTabLSS
from .tabularnn import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS


__all__ = [
@@ -40,4 +41,7 @@
"MambaTabRegressor",
"MambaTabClassifier",
"MambaTabLSS",
"TabulaRNNClassifier",
"TabulaRNNRegressor",
"TabulaRNNLSS",
]
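With this export, the new estimators are importable from the package root like the existing Mamba-based models; the wrapper classes themselves live in mambular/models/tabularnn.py, which is not shown in this excerpt:

```python
# The three wrappers added to __all__ can now be imported directly.
from mambular.models import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS
```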