From bdac22be88f1672acaa5e97f81b026bbf796e0bc Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 10:50:15 +0000 Subject: [PATCH 01/11] add ntdf config in init --- mambular/configs/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mambular/configs/__init__.py b/mambular/configs/__init__.py index a990e05..31ee342 100644 --- a/mambular/configs/__init__.py +++ b/mambular/configs/__init__.py @@ -6,6 +6,7 @@ from .mambatab_config import DefaultMambaTabConfig from .tabularnn_config import DefaultTabulaRNNConfig from .mambattention_config import DefaultMambAttentionConfig +from .ndtf_config import DefaultNDTFConfig __all__ = [ @@ -17,4 +18,5 @@ "DefaultMambaTabConfig", "DefaultTabulaRNNConfig", "DefaultMambAttentionConfig", + "DefaultNDTFConfig", ] From 9eb5d421b071f50b27db3ed0cd0e4105231deff0 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:51:30 +0000 Subject: [PATCH 02/11] add sparsemax --- mambular/arch_utils/layer_utils/sparsemax.py | 117 +++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 mambular/arch_utils/layer_utils/sparsemax.py diff --git a/mambular/arch_utils/layer_utils/sparsemax.py b/mambular/arch_utils/layer_utils/sparsemax.py new file mode 100644 index 0000000..cfcc00f --- /dev/null +++ b/mambular/arch_utils/layer_utils/sparsemax.py @@ -0,0 +1,117 @@ +import torch +from torch.autograd import Function + + +def _make_ix_like(input, dim=0): + """ + Creates a tensor of indices like the input tensor along the specified dimension. + + Parameters + ---------- + input : torch.Tensor + Input tensor whose shape will be used to determine the shape of the output tensor. + dim : int, optional + Dimension along which to create the index tensor. Default is 0. + + Returns + ------- + torch.Tensor + A tensor containing indices along the specified dimension. + """ + d = input.size(dim) + rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) + view = [1] * input.dim() + view[0] = -1 + return rho.view(view).transpose(0, dim) + + +class SparsemaxFunction(Function): + """ + Implements the sparsemax function, a sparse alternative to softmax. + + References + ---------- + Martins, A. F., & Astudillo, R. F. (2016). "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification." + """ + + @staticmethod + def forward(ctx, input, dim=-1): + """ + Forward pass of sparsemax: a normalizing, sparse transformation. + + Parameters + ---------- + input : torch.Tensor + The input tensor on which sparsemax will be applied. + dim : int, optional + Dimension along which to apply sparsemax. Default is -1. + + Returns + ------- + torch.Tensor + A tensor with the same shape as the input, with sparsemax applied. + """ + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val # Numerical stability trick, as with softmax. + tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) + output = torch.clamp(input - tau, min=0) + ctx.save_for_backward(supp_size, output) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass of sparsemax, calculating gradients. + + Parameters + ---------- + grad_output : torch.Tensor + Gradient of the loss with respect to the output of sparsemax. + + Returns + ------- + tuple + Gradients of the loss with respect to the input of sparsemax and None for the dimension argument. + """ + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + return grad_input, None + + @staticmethod + def _threshold_and_support(input, dim=-1): + """ + Computes the threshold and support for sparsemax. + + Parameters + ---------- + input : torch.Tensor + The input tensor on which to compute the threshold and support. + dim : int, optional + Dimension along which to compute the threshold and support. Default is -1. + + Returns + ------- + tuple + - torch.Tensor : The threshold value for sparsemax. + - torch.Tensor : The support size tensor. + """ + input_srt, _ = torch.sort(input, descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + return tau, support_size + + +sparsemax = lambda input, dim=-1: SparsemaxFunction.apply(input, dim) +sparsemoid = lambda input: (0.5 * input + 0.5).clamp_(0, 1) From ab3abbf87ce92b2fc9e833a56eada1b3704df9d5 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:51:46 +0000 Subject: [PATCH 03/11] data-aware initialization module --- .../arch_utils/data_aware_initialization.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 mambular/arch_utils/data_aware_initialization.py diff --git a/mambular/arch_utils/data_aware_initialization.py b/mambular/arch_utils/data_aware_initialization.py new file mode 100644 index 0000000..00e58a7 --- /dev/null +++ b/mambular/arch_utils/data_aware_initialization.py @@ -0,0 +1,29 @@ +import torch.nn as nn +import torch + + +class ModuleWithInit(nn.Module): + """Base class for pytorch module with data-aware initializer on first batch + See https://github.com/yandex-research/rtdl-revisiting-models/tree/main/lib/node + + Helps to avoid nans in feature logits before being passed to sparsemax""" + + def __init__(self): + super().__init__() + self._is_initialized_tensor = nn.Parameter( + torch.tensor(0, dtype=torch.uint8), requires_grad=False + ) + self._is_initialized_bool = None + + def initialize(self, *args, **kwargs): + """initialize module tensors using first batch of data""" + raise NotImplementedError("Please implement ") + + def __call__(self, *args, **kwargs): + if self._is_initialized_bool is None: + self._is_initialized_bool = bool(self._is_initialized_tensor.item()) + if not self._is_initialized_bool: + self.initialize(*args, **kwargs) + self._is_initialized_tensor.data[...] = 1 + self._is_initialized_bool = True + return super().__call__(*args, **kwargs) From 473db6be67ef6573764aeeb07b8ac8c9b19fe0ae Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:52:01 +0000 Subject: [PATCH 04/11] utils func for checking if tensor or np.array --- mambular/arch_utils/numpy_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 mambular/arch_utils/numpy_utils.py diff --git a/mambular/arch_utils/numpy_utils.py b/mambular/arch_utils/numpy_utils.py new file mode 100644 index 0000000..82098dc --- /dev/null +++ b/mambular/arch_utils/numpy_utils.py @@ -0,0 +1,11 @@ +import torch +import numpy as np + + +def check_numpy(x): + """Makes sure x is a numpy array""" + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + x = np.asarray(x) + assert isinstance(x, np.ndarray) + return x From e59bfcb35f3f29a615cbefee65dadccedfb4bc63 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:52:10 +0000 Subject: [PATCH 05/11] add ODST and DenseBlock --- mambular/arch_utils/node_utils.py | 370 ++++++++++++++++++++++++++++++ 1 file changed, 370 insertions(+) create mode 100644 mambular/arch_utils/node_utils.py diff --git a/mambular/arch_utils/node_utils.py b/mambular/arch_utils/node_utils.py new file mode 100644 index 0000000..03fc2e6 --- /dev/null +++ b/mambular/arch_utils/node_utils.py @@ -0,0 +1,370 @@ +# Source: https://github.com/Qwicen/node +from warnings import warn + +import numpy as np +import torch +import torch.nn as nn +from .layer_utils.sparsemax import sparsemax, sparsemoid +import torch.functional as F +from .data_aware_initialization import ModuleWithInit +from .numpy_utils import check_numpy + + +class ODST(ModuleWithInit): + def __init__( + self, + in_features, + num_trees, + depth=6, + tree_dim=1, + flatten_output=True, + choice_function=sparsemax, + bin_function=sparsemoid, + initialize_response_=nn.init.normal_, + initialize_selection_logits_=nn.init.uniform_, + threshold_init_beta=1.0, + threshold_init_cutoff=1.0, + ): + """ + Oblivious Differentiable Sparsemax Trees (ODST). + + ODST is a differentiable module for decision tree-based models, where each tree + is trained using sparsemax to compute feature weights and sparsemoid to compute + binary leaf weights. This class is designed as a drop-in replacement for `nn.Linear` layers. + + Parameters + ---------- + in_features : int + Number of features in the input tensor. + num_trees : int + Number of trees in this layer. + depth : int, optional + Number of splits (depth) in each tree. Default is 6. + tree_dim : int, optional + Number of output channels for each tree's response. Default is 1. + flatten_output : bool, optional + If True, returns output in a flattened shape of [..., num_trees * tree_dim]; + otherwise returns [..., num_trees, tree_dim]. Default is True. + choice_function : callable, optional + Function that computes feature weights as a simplex, such that + `choice_function(tensor, dim).sum(dim) == 1`. Default is `sparsemax`. + bin_function : callable, optional + Function that computes tree leaf weights as values in the range [0, 1]. + Default is `sparsemoid`. + initialize_response_ : callable, optional + In-place initializer for the response tensor in each tree. Default is `nn.init.normal_`. + initialize_selection_logits_ : callable, optional + In-place initializer for the feature selection logits. Default is `nn.init.uniform_`. + threshold_init_beta : float, optional + Initializes thresholds based on quantiles of the data using a Beta distribution. + Controls the initial threshold distribution; values > 1 make thresholds closer to the median. + Default is 1.0. + threshold_init_cutoff : float, optional + Initializer for log-temperatures, with values > 1.0 adding margin between data points + and sparse-sigmoid cutoffs. Default is 1.0. + + Attributes + ---------- + response : torch.nn.Parameter + Parameter for tree responses. + feature_selection_logits : torch.nn.Parameter + Logits that select features for the trees. + feature_thresholds : torch.nn.Parameter + Threshold values for feature splits in the trees. + log_temperatures : torch.nn.Parameter + Log-temperatures for threshold adjustments. + bin_codes_1hot : torch.nn.Parameter + One-hot encoded binary codes for leaf mapping. + + Methods + ------- + forward(input) + Forward pass through the ODST model. + initialize(input, eps=1e-6) + Data-aware initialization of thresholds and log-temperatures based on input data. + """ + + super().__init__() + self.depth, self.num_trees, self.tree_dim, self.flatten_output = ( + depth, + num_trees, + tree_dim, + flatten_output, + ) + self.choice_function, self.bin_function = choice_function, bin_function + self.threshold_init_beta, self.threshold_init_cutoff = ( + threshold_init_beta, + threshold_init_cutoff, + ) + + self.response = nn.Parameter( + torch.zeros([num_trees, tree_dim, 2**depth]), requires_grad=True + ) + initialize_response_(self.response) + + self.feature_selection_logits = nn.Parameter( + torch.zeros([in_features, num_trees, depth]), requires_grad=True + ) + initialize_selection_logits_(self.feature_selection_logits) + + self.feature_thresholds = nn.Parameter( + torch.full([num_trees, depth], float("nan"), dtype=torch.float32), + requires_grad=True, + ) # nan values will be initialized on first batch (data-aware init) + + self.log_temperatures = nn.Parameter( + torch.full([num_trees, depth], float("nan"), dtype=torch.float32), + requires_grad=True, + ) + + # binary codes for mapping between 1-hot vectors and bin indices + with torch.no_grad(): + indices = torch.arange(2**self.depth) + offsets = 2 ** torch.arange(self.depth) + bin_codes = (indices.view(1, -1) // offsets.view(-1, 1) % 2).to( + torch.float32 + ) + bin_codes_1hot = torch.stack([bin_codes, 1.0 - bin_codes], dim=-1) + self.bin_codes_1hot = nn.Parameter(bin_codes_1hot, requires_grad=False) + # ^-- [depth, 2 ** depth, 2] + + def forward(self, input): + """ + Forward pass through ODST model. + + Parameters + ---------- + input : torch.Tensor + Input tensor of shape [batch_size, in_features] or higher dimensions. + + Returns + ------- + torch.Tensor + Output tensor of shape [batch_size, num_trees * tree_dim] if `flatten_output` is True, + otherwise [batch_size, num_trees, tree_dim]. + """ + assert len(input.shape) >= 2 + if len(input.shape) > 2: + return self.forward(input.view(-1, input.shape[-1])).view( + *input.shape[:-1], -1 + ) + # new input shape: [batch_size, in_features] + + feature_logits = self.feature_selection_logits + feature_selectors = self.choice_function(feature_logits, dim=0) + # ^--[in_features, num_trees, depth] + + feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors) + # ^--[batch_size, num_trees, depth] + + threshold_logits = (feature_values - self.feature_thresholds) * torch.exp( + -self.log_temperatures + ) + + threshold_logits = torch.stack([-threshold_logits, threshold_logits], dim=-1) + # ^--[batch_size, num_trees, depth, 2] + + bins = self.bin_function(threshold_logits) + # ^--[batch_size, num_trees, depth, 2], approximately binary + + bin_matches = torch.einsum("btds,dcs->btdc", bins, self.bin_codes_1hot) + # ^--[batch_size, num_trees, depth, 2 ** depth] + + response_weights = torch.prod(bin_matches, dim=-2) + # ^-- [batch_size, num_trees, 2 ** depth] + + response = torch.einsum("bnd,ncd->bnc", response_weights, self.response) + # ^-- [batch_size, num_trees, tree_dim] + + return response.flatten(1, 2) if self.flatten_output else response + + def initialize(self, input, eps=1e-6): + """ + Data-aware initialization of thresholds and log-temperatures based on input data. + + Parameters + ---------- + input : torch.Tensor + Tensor of shape [batch_size, in_features] used for threshold initialization. + eps : float, optional + Small value added to avoid log(0) errors in temperature initialization. Default is 1e-6. + """ + # data-aware initializer + assert len(input.shape) == 2 + if input.shape[0] < 1000: + warn( + "Data-aware initialization is performed on less than 1000 data points. This may cause instability." + "To avoid potential problems, run this model on a data batch with at least 1000 data samples." + "You can do so manually before training. Use with torch.no_grad() for memory efficiency." + ) + with torch.no_grad(): + feature_selectors = self.choice_function( + self.feature_selection_logits, dim=0 + ) + # ^--[in_features, num_trees, depth] + + feature_values = torch.einsum("bi,ind->bnd", input, feature_selectors) + # ^--[batch_size, num_trees, depth] + + # initialize thresholds: sample random percentiles of data + percentiles_q = 100 * np.random.beta( + self.threshold_init_beta, + self.threshold_init_beta, + size=[self.num_trees, self.depth], + ) + self.feature_thresholds.data[...] = torch.as_tensor( + list( + map( + np.percentile, + check_numpy(feature_values.flatten(1, 2).t()), + percentiles_q.flatten(), + ) + ), + dtype=feature_values.dtype, + device=feature_values.device, + ).view(self.num_trees, self.depth) + + # init temperatures: make sure enough data points are in the linear region of sparse-sigmoid + temperatures = np.percentile( + check_numpy(abs(feature_values - self.feature_thresholds)), + q=100 * min(1.0, self.threshold_init_cutoff), + axis=0, + ) + + # if threshold_init_cutoff > 1, scale everything down by it + temperatures /= max(1.0, self.threshold_init_cutoff) + self.log_temperatures.data[...] = torch.log( + torch.as_tensor(temperatures) + eps + ) + + def __repr__(self): + return "{}(in_features={}, num_trees={}, depth={}, tree_dim={}, flatten_output={})".format( + self.__class__.__name__, + self.feature_selection_logits.shape[0], + self.num_trees, + self.depth, + self.tree_dim, + self.flatten_output, + ) + + +class DenseBlock(nn.Sequential): + """ + DenseBlock is a multi-layer module that sequentially stacks instances of `Module`, + typically decision tree models like `ODST`. Each layer in the block produces additional + features, enabling the model to learn complex representations. + + Parameters + ---------- + input_dim : int + Dimensionality of the input features. + layer_dim : int + Dimensionality of each layer in the block. + num_layers : int + Number of layers to stack in the block. + tree_dim : int, optional + Dimensionality of the output channels from each tree. Default is 1. + max_features : int, optional + Maximum dimensionality for feature expansion. If None, feature expansion is unrestricted. + Default is None. + input_dropout : float, optional + Dropout rate applied to the input features of each layer during training. Default is 0.0. + flatten_output : bool, optional + If True, flattens the output along the tree dimension. Default is True. + Module : nn.Module, optional + Module class to use for each layer in the block, typically a decision tree model. + Default is `ODST`. + **kwargs : dict + Additional keyword arguments for the `Module` instances. + + Attributes + ---------- + num_layers : int + Number of layers in the block. + layer_dim : int + Dimensionality of each layer. + tree_dim : int + Dimensionality of each tree's output in the layer. + max_features : int or None + Maximum feature dimensionality allowed for expansion. + flatten_output : bool + Determines whether to flatten the output. + input_dropout : float + Dropout rate applied to each layer's input. + + Methods + ------- + forward(x) + Performs the forward pass through the block, producing feature-expanded outputs. + """ + + def __init__( + self, + input_dim, + layer_dim, + num_layers, + tree_dim=1, + max_features=None, + input_dropout=0.0, + flatten_output=True, + Module=ODST, + **kwargs + ): + layers = [] + for i in range(num_layers): + oddt = Module( + input_dim, layer_dim, tree_dim=tree_dim, flatten_output=True, **kwargs + ) + input_dim = min( + input_dim + layer_dim * tree_dim, max_features or float("inf") + ) + layers.append(oddt) + + super().__init__(*layers) + self.num_layers, self.layer_dim, self.tree_dim = num_layers, layer_dim, tree_dim + self.max_features, self.flatten_output = max_features, flatten_output + self.input_dropout = input_dropout + + def forward(self, x): + """ + Forward pass through the DenseBlock. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape [batch_size, input_dim] or higher dimensions. + + Returns + ------- + torch.Tensor + Output tensor with expanded features, where shape depends on `flatten_output`. + If `flatten_output` is True, returns tensor of shape + [..., num_layers * layer_dim * tree_dim]. + Otherwise, returns [..., num_layers * layer_dim, tree_dim]. + """ + initial_features = x.shape[-1] + for layer in self: + layer_inp = x + if self.max_features is not None: + tail_features = ( + min(self.max_features, layer_inp.shape[-1]) - initial_features + ) + if tail_features != 0: + layer_inp = torch.cat( + [ + layer_inp[..., :initial_features], + layer_inp[..., -tail_features:], + ], + dim=-1, + ) + if self.training and self.input_dropout: + layer_inp = F.dropout(layer_inp, self.input_dropout) + h = layer(layer_inp) + x = torch.cat([x, h], dim=-1) + + outputs = x[..., initial_features:] + if not self.flatten_output: + outputs = outputs.view( + *outputs.shape[:-1], self.num_layers * self.layer_dim, self.tree_dim + ) + return outputs From 8d362dfbb32b4df4f445dab0c478cf9cc6e81f50 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:52:23 +0000 Subject: [PATCH 06/11] add node into basemodels - includes tabular MLP head --- mambular/base_models/node.py | 159 +++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 mambular/base_models/node.py diff --git a/mambular/base_models/node.py b/mambular/base_models/node.py new file mode 100644 index 0000000..e61a817 --- /dev/null +++ b/mambular/base_models/node.py @@ -0,0 +1,159 @@ +from .basemodel import BaseModel +from ..configs.node_config import DefaultNODEConfig +import torch +from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer +from ..arch_utils.node_utils import DenseBlock +from ..arch_utils.mlp_utils import MLP + + +class NODE(BaseModel): + """ + Neural Oblivious Decision Ensemble (NODE) Model. Slightly different with a MLP as a tabular task specific head. + + NODE is a neural decision tree model that processes both categorical and numerical features. + This class combines embedding layers, a dense decision tree block, and an MLP head for tabular + data prediction tasks. + + Parameters + ---------- + cat_feature_info : dict + Dictionary mapping categorical feature names to their input shapes. + num_feature_info : dict + Dictionary mapping numerical feature names to their input shapes. + num_classes : int, optional + Number of output classes. Default is 1. + config : DefaultNODEConfig, optional + Configuration object that holds model hyperparameters. Default is `DefaultNODEConfig`. + **kwargs : dict + Additional arguments for the base model. + + Attributes + ---------- + lr : float + Learning rate for the optimizer. + lr_patience : int + Number of epochs without improvement before reducing the learning rate. + weight_decay : float + Weight decay factor for regularization. + lr_factor : float + Factor by which to reduce the learning rate. + cat_feature_info : dict + Information about categorical features. + num_feature_info : dict + Information about numerical features. + use_embeddings : bool + Whether to use embeddings for categorical and numerical features. + embedding_layer : EmbeddingLayer, optional + Embedding layer for feature transformation. + d_out : int + Output dimensionality. + block : DenseBlock + DenseBlock layer that implements the decision tree ensemble. + tabular_head : MLP + MLP layer that serves as the output head of the model. + + Methods + ------- + forward(num_features, cat_features) + Performs the forward pass, processing numerical and categorical features to produce predictions. + """ + + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes: int = 1, + config: DefaultNODEConfig = DefaultNODEConfig(), + **kwargs, + ): + super().__init__(**kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + self.lr = self.hparams.get("lr", config.lr) + self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) + self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) + self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) + self.cat_feature_info = cat_feature_info + self.num_feature_info = num_feature_info + self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings) + + input_dim = 0 + for feature_name, input_shape in num_feature_info.items(): + input_dim += input_shape + for feature_name, input_shape in cat_feature_info.items(): + input_dim += 1 + + if self.use_embeddings: + input_dim = ( + len(num_feature_info) * config.d_model + + len(cat_feature_info) * config.d_model + ) + + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + d_model=self.hparams.get("d_model", config.d_model), + embedding_activation=self.hparams.get( + "embedding_activation", config.embedding_activation + ), + layer_norm_after_embedding=self.hparams.get( + "layer_norm_after_embedding" + ), + use_cls=False, + ) + + self.d_out = num_classes + self.block = DenseBlock( + input_dim=input_dim, + num_layers=config.num_layers, + layer_dim=config.layer_dim, + depth=config.depth, + tree_dim=config.tree_dim, + flatten_output=True, + ) + + head_activation = self.hparams.get("head_activation", config.head_activation) + + self.tabular_head = MLP( + config.num_layers * config.layer_dim, + hidden_units_list=self.hparams.get( + "head_layer_sizes", config.head_layer_sizes + ), + dropout_rate=self.hparams.get("head_dropout", config.head_dropout), + use_skip_layers=self.hparams.get( + "head_skip_layers", config.head_skip_layers + ), + activation_fn=head_activation, + use_batch_norm=self.hparams.get( + "head_use_batch_norm", config.head_use_batch_norm + ), + n_output_units=num_classes, + ) + + def forward(self, num_features, cat_features): + """ + Forward pass through the NODE model. + + Parameters + ---------- + num_features : torch.Tensor + Numerical features tensor of shape [batch_size, num_numerical_features]. + cat_features : torch.Tensor + Categorical features tensor of shape [batch_size, num_categorical_features]. + + Returns + ------- + torch.Tensor + Model output of shape [batch_size, num_classes]. + """ + if self.use_embeddings: + x = self.embedding_layer(num_features, cat_features) + B, S, D = x.shape + x = x.reshape(B, S * D) + else: + x = num_features + cat_features + x = torch.cat(x, dim=1) + + x = self.block(x).squeeze(-1) + x = self.tabular_head(x) + return x From 8b12c61342cf562ec333de0a6a6047d42a36e0a8 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:52:35 +0000 Subject: [PATCH 07/11] add default config for NODE model --- mambular/configs/node_config.py | 69 +++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 mambular/configs/node_config.py diff --git a/mambular/configs/node_config.py b/mambular/configs/node_config.py new file mode 100644 index 0000000..d574f25 --- /dev/null +++ b/mambular/configs/node_config.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultNODEConfig: + """ + Configuration class for the default Neural Oblivious Decision Ensemble (NODE) model. + + This class provides default hyperparameters for training and configuring a NODE model. + + Attributes + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + """ + + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + norm: str = None + use_embeddings: bool = False + embedding_activation: callable = nn.Identity() + layer_norm_after_embedding: bool = False + d_model: int = 32 + num_layers: int = 4 + layer_dim: int = 128 + tree_dim: int = 1 + depth: int = 6 + head_layer_sizes: list = () + head_dropout: float = 0.5 + head_skip_layers: bool = False + head_activation: callable = nn.SELU() + head_use_batch_norm: bool = False From 90e1476bcf75a7620f45fca6ee201d4a3c531f01 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:52:50 +0000 Subject: [PATCH 08/11] add Node to models and __init__ --- mambular/models/__init__.py | 4 + mambular/models/node.py | 287 ++++++++++++++++++++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 mambular/models/node.py diff --git a/mambular/models/__init__.py b/mambular/models/__init__.py index fc7d27f..f9d82f6 100644 --- a/mambular/models/__init__.py +++ b/mambular/models/__init__.py @@ -24,6 +24,7 @@ ) from .ndtf import NDTFClassifier, NDTFRegressor, NDTFLSS +from .node import NODEClassifier, NODERegressor, NODELSS __all__ = [ @@ -57,4 +58,7 @@ "NDTFClassifier", "NDTFRegressor", "NDTFLSS", + "NODEClassifier", + "NODERegressor", + "NODELSS", ] diff --git a/mambular/models/node.py b/mambular/models/node.py new file mode 100644 index 0000000..cfd9d52 --- /dev/null +++ b/mambular/models/node.py @@ -0,0 +1,287 @@ +from .sklearn_base_regressor import SklearnBaseRegressor +from .sklearn_base_classifier import SklearnBaseClassifier +from .sklearn_base_lss import SklearnBaseLSS +from ..base_models.node import NODE +from ..configs.node_config import DefaultNODEConfig + + +class NODERegressor(SklearnBaseRegressor): + """ + Neural Oblivious Decision Ensemble (NODE) Regressor. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseRegressor class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODERegressor class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODERegressor class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODERegressor uses SklearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseRegressor : The parent class for NODERegressor. + + Examples + -------- + >>> from mambular.models import NODERegressor + >>> model = NODERegressor(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) + + +class NODEClassifier(SklearnBaseClassifier): + """ + Neural Oblivious Decision Ensemble (NODE) Classifier. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseClassifier class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODEClassifier class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODEClassifier class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODEClassifier uses SklearnBaseClassifieras the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseClassifier : The parent class for NODEClassifier. + + Examples + -------- + >>> from mambular.models import NODEClassifier + >>> model = NODEClassifier(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) + + +class NODELSS(SklearnBaseLSS): + """ + Neural Oblivious Decision Ensemble (NODE) for disrtibutional regression. Slightly different with a MLP as a tabular task specific head. This class extends the SklearnBaseLSS class and uses the NODE model + with the default NODE configuration. + + The accepted arguments to the NODELSS class include both the attributes in the DefaultNODEConfig dataclass + and the parameters for the Preprocessor class. + + Parameters + ---------- + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs without improvement after which the learning rate will be reduced. Default is 10. + weight_decay : float, optional + Weight decay (L2 regularization penalty) applied by the optimizer. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate is reduced when there is no improvement. Default is 0.1. + norm : str, optional + Type of normalization to use. Default is None. + use_embeddings : bool, optional + Whether to use embedding layers for categorical features. Default is False. + embedding_activation : callable, optional + Activation function to apply to embeddings. Default is `nn.Identity`. + layer_norm_after_embedding : bool, optional + Whether to apply layer normalization after embedding layers. Default is False. + d_model : int, optional + Dimensionality of the embedding space. Default is 32. + num_layers : int, optional + Number of dense layers in the model. Default is 4. + layer_dim : int, optional + Dimensionality of each dense layer. Default is 128. + tree_dim : int, optional + Dimensionality of the output from each tree leaf. Default is 1. + depth : int, optional + Depth of each decision tree in the ensemble. Default is 6. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + n_bins : int, default=50 + The number of bins to use for numerical feature binning. This parameter is relevant + only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + numerical_preprocessing : str, default="ple" + The preprocessing strategy for numerical features. Valid options are + 'binning', 'one_hot', 'standardization', and 'normalization'. + use_decision_tree_bins : bool, default=False + If True, uses decision tree regression/classification to determine + optimal bin edges for numerical feature binning. This parameter is + relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'. + binning_strategy : str, default="uniform" + Defines the strategy for binning numerical features. Options include 'uniform', + 'quantile', or other sklearn-compatible strategies. + task : str, default="regression" + Indicates the type of machine learning task ('regression' or 'classification'). This can + influence certain preprocessing behaviors, especially when using decision tree-based binning as ple. + cat_cutoff : float or int, default=0.03 + Indicates the cutoff after which integer values are treated as categorical. + If float, it's treated as a percentage. If int, it's the maximum number of + unique values for a column to be considered categorical. + treat_all_integers_as_numerical : bool, default=False + If True, all integer columns will be treated as numerical, regardless + of their unique value count or proportion. + degree : int, default=3 + The degree of the polynomial features to be used in preprocessing. + knots : int, default=12 + The number of knots to be used in spline transformations. + + Notes + ----- + - The accepted arguments to the NODELSS class are the same as the attributes in the DefaultNODEConfig dataclass. + - NODELSS uses SklearnBaseLSS as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. + + See Also + -------- + mambular.models.SklearnBaseLSS : The parent class for NODELSS. + + Examples + -------- + >>> from mambular.models import NODELSS + >>> model = NODELSS(layer_sizes=[128, 128, 64], activation=nn.ReLU()) + >>> model.fit(X_train, y_train) + >>> preds = model.predict(X_test) + >>> model.evaluate(X_test, y_test) + """ + + def __init__(self, **kwargs): + super().__init__(model=NODE, config=DefaultNODEConfig, **kwargs) From 9bdbff3fe5a83262cae0d53131da3ea3187814b7 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:55:24 +0000 Subject: [PATCH 09/11] refactor normalization layer -> get_normalization_layer included in __init__ --- mambular/base_models/ft_transformer.py | 27 ++----------------- mambular/base_models/mlp.py | 25 ++--------------- mambular/base_models/tabtransformer.py | 26 +++--------------- mambular/base_models/tabularnn.py | 37 ++------------------------ 4 files changed, 9 insertions(+), 106 deletions(-) diff --git a/mambular/base_models/ft_transformer.py b/mambular/base_models/ft_transformer.py index 2af362b..c349e02 100644 --- a/mambular/base_models/ft_transformer.py +++ b/mambular/base_models/ft_transformer.py @@ -1,14 +1,7 @@ import torch import torch.nn as nn from ..arch_utils.mlp_utils import MLP -from ..arch_utils.layer_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.get_norm_fn import get_normalization_layer from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer from ..configs.fttransformer_config import DefaultFTTransformerConfig @@ -101,23 +94,7 @@ def __init__( bias=self.hparams.get("bias", config.bias), ) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("d_model", config.d_model) - ) - else: - self.norm_f = None + self.norm_f = get_normalization_layer(config) self.encoder = nn.TransformerEncoder( encoder_layer, diff --git a/mambular/base_models/mlp.py b/mambular/base_models/mlp.py index d9e24e3..4aebee5 100644 --- a/mambular/base_models/mlp.py +++ b/mambular/base_models/mlp.py @@ -2,14 +2,7 @@ import torch.nn as nn from ..configs.mlp_config import DefaultMLPConfig from .basemodel import BaseModel -from ..arch_utils.layer_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.get_norm_fn import get_normalization_layer from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer @@ -74,21 +67,7 @@ def __init__( if config.batch_norm: self.layers.append(nn.BatchNorm1d(self.layer_sizes[0])) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(self.layer_sizes[0]) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(self.layer_sizes[0]) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(self.layer_sizes[0]) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(self.layer_sizes[0]) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, self.layer_sizes[0]) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling(self.layer_sizes[0]) - else: - self.norm_f = None + self.norm_f = get_normalization_layer(config) if self.norm_f is not None: self.layers.append(self.norm_f(self.layer_sizes[0])) diff --git a/mambular/base_models/tabtransformer.py b/mambular/base_models/tabtransformer.py index 2229faa..48fd8ed 100644 --- a/mambular/base_models/tabtransformer.py +++ b/mambular/base_models/tabtransformer.py @@ -1,18 +1,12 @@ import torch import torch.nn as nn from ..arch_utils.mlp_utils import MLP -from ..arch_utils.layer_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.get_norm_fn import get_normalization_layer from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer from ..configs.tabtransformer_config import DefaultTabTransformerConfig from .basemodel import BaseModel from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer +from ..arch_utils.layer_utils.normalization_layers import LayerNorm class TabTransformer(BaseModel): @@ -109,21 +103,7 @@ def __init__( bias=self.hparams.get("bias", config.bias), ) - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(layer_norm_dim) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(layer_norm_dim) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(layer_norm_dim) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(layer_norm_dim) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, layer_norm_dim) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling(layer_norm_dim) - else: - self.norm_f = None + self.norm_f = get_normalization_layer(config) self.norm_embedding = LayerNorm(self.hparams.get("d_model", config.d_model)) self.encoder = nn.TransformerEncoder( diff --git a/mambular/base_models/tabularnn.py b/mambular/base_models/tabularnn.py index 3cc5fc3..4ce5a64 100644 --- a/mambular/base_models/tabularnn.py +++ b/mambular/base_models/tabularnn.py @@ -5,14 +5,7 @@ from .basemodel import BaseModel from ..arch_utils.layer_utils.embedding_layer import EmbeddingLayer from ..arch_utils.rnn_utils import ConvRNN -from ..arch_utils.layer_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) +from ..arch_utils.get_norm_fn import get_normalization_layer class TabulaRNN(BaseModel): @@ -35,33 +28,7 @@ def __init__( self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm( - 1, self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("dim_feedforward", config.dim_feedforward) - ) - else: - self.norm_f = None + self.norm_f = get_normalization_layer(config) self.rnn = ConvRNN( model_type=self.hparams.get("model_type", config.model_type), From b0c0bf4d2860b81eb61055bb1b74c2a0669f7e1e Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:55:35 +0000 Subject: [PATCH 10/11] add nodeconfig in __init__ --- mambular/configs/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mambular/configs/__init__.py b/mambular/configs/__init__.py index 31ee342..45dc765 100644 --- a/mambular/configs/__init__.py +++ b/mambular/configs/__init__.py @@ -7,6 +7,7 @@ from .tabularnn_config import DefaultTabulaRNNConfig from .mambattention_config import DefaultMambAttentionConfig from .ndtf_config import DefaultNDTFConfig +from .node_config import DefaultNODEConfig __all__ = [ @@ -19,4 +20,5 @@ "DefaultTabulaRNNConfig", "DefaultMambAttentionConfig", "DefaultNDTFConfig", + "DefaultNODEConfig", ] From bea4bc3097ce255c45a44e71ef586ffbab8e2ce1 Mon Sep 17 00:00:00 2001 From: AnFreTh Date: Tue, 5 Nov 2024 14:55:54 +0000 Subject: [PATCH 11/11] fix typo in docstrings --- mambular/models/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mambular/models/mlp.py b/mambular/models/mlp.py index 60d77e3..b286c45 100644 --- a/mambular/models/mlp.py +++ b/mambular/models/mlp.py @@ -172,7 +172,7 @@ class MLPClassifier(SklearnBaseClassifier): See Also -------- - mambular.models.SklearnBaseRegressor : The parent class for MLPClassifier. + mambular.models.SklearnBaseClassifier : The parent class for MLPClassifier. Examples --------