Version 0.2.1 #106

Merged
69 commits merged on Aug 13, 2024

Commits
823763b
adapt configs to better default models
AnFreTh Jul 15, 2024
3004241
include ReGLU activation
AnFreTh Jul 15, 2024
a9f5e4c
include ReGLU in FT-Transformer
AnFreTh Jul 15, 2024
c3d4c01
include cls token at end of sequence
AnFreTh Jul 15, 2024
fcb17c1
include ReGLU in TabTransformer
AnFreTh Jul 15, 2024
5f0608c
remove unnecessary code
AnFreTh Jul 15, 2024
e16ba03
Merge pull request #75 from basf/models
AnFreTh Jul 15, 2024
e7ca3a2
add build_model and get_num_params methods
AnFreTh Jul 15, 2024
3f482fa
Merge pull request #78 from basf/models
AnFreTh Jul 15, 2024
dc4c313
add learnable ple encodings
AnFreTh Jul 15, 2024
f6aa9ac
add rotary embeddings and attention_net utils
AnFreTh Jul 15, 2024
04101f9
add mambatab, rotaryft and rnn configs
AnFreTh Jul 15, 2024
fe40c47
adapt mambatab config to paper hparams
AnFreTh Jul 15, 2024
30d35e4
include RNN and MambaTab
AnFreTh Jul 15, 2024
21b0afb
add mambatab, rnn, rotary and basisexpandFT to init
AnFreTh Jul 15, 2024
8b19f92
include basemodel basisexpansion
AnFreTh Jul 15, 2024
bad3db6
adding util embedding layer for decluttering
AnFreTh Jul 15, 2024
a2c3e71
adapted embedding layer for when no cat features
AnFreTh Jul 15, 2024
5f15d5f
adapted mambular, FT and TabTransformer for new embedding layer class
AnFreTh Jul 15, 2024
fa8cc00
included possible embeddings into MLP and ResNet
AnFreTh Jul 15, 2024
6efc326
adapted configs of models
AnFreTh Jul 15, 2024
ae0bf97
adapt documentation to new configs
AnFreTh Jul 15, 2024
6be79d1
add missing line in MLP
AnFreTh Jul 15, 2024
3d840b9
added seq len to embedding layer
AnFreTh Jul 16, 2024
ae84257
added shuffling embeddings before being passed to mamba blocks
AnFreTh Jul 16, 2024
e1f5c40
shuffle embeddings option in config
AnFreTh Jul 16, 2024
8af86a2
shuffling along second axis
AnFreTh Jul 16, 2024
2029700
add .built attr to model classes
AnFreTh Jul 16, 2024
c7459c4
delete expansion
AnFreTh Jul 17, 2024
96e9449
delete basis expansion
AnFreTh Jul 17, 2024
864d4f0
delete self.paper attribute from mambatab
AnFreTh Jul 17, 2024
d636b28
adapt build model func
AnFreTh Jul 17, 2024
5a5a7cd
adapt build_method
AnFreTh Jul 17, 2024
b246d5c
fix label shape for lss regression
AnFreTh Jul 25, 2024
6da2cf8
Merge pull request #83 from basf/lss_fix
AnFreTh Jul 25, 2024
257acfb
self.model.parameters instead of self.parameters in optimizer
AnFreTh Jul 25, 2024
be1879d
depth=2 for summary and self.trainer attribute
AnFreTh Jul 25, 2024
b16af74
Merge pull request #84 from basf/trainer_fix
AnFreTh Jul 25, 2024
bd941d2
fixed set and get_params functionality
AnFreTh Jul 25, 2024
cc92798
Merge pull request #85 from basf/param_fix
AnFreTh Jul 25, 2024
283a10b
add embedding layer and delete unused models
AnFreTh Jul 26, 2024
19b760c
Merge branch 'develop' into layer_improvement
AnFreTh Jul 26, 2024
56801dd
Merge pull request #90 from basf/layer_improvement
AnFreTh Jul 26, 2024
22dad68
adding AB layernorm and weight decay to Mamba
AnFreTh Aug 2, 2024
0d4442a
Merge pull request #97 from basf/AB_layer
AnFreTh Aug 2, 2024
7c2d343
adjust names of matrices
AnFreTh Aug 2, 2024
8933cf7
Merge pull request #98 from basf/AB_layer
AnFreTh Aug 2, 2024
029c6d8
adding one-hot encoding to embedding_layer
AnFreTh Aug 5, 2024
d413fd8
adding option to one-hot encode cat features in embedding layer
AnFreTh Aug 5, 2024
07164f5
adjusting configs
AnFreTh Aug 5, 2024
71cc68e
renaming sklearn class attributes
AnFreTh Aug 5, 2024
53b77c5
adjusting class attribute in lightning wrapper
AnFreTh Aug 5, 2024
4f6b088
adjusting config docstrings
AnFreTh Aug 5, 2024
71f35e6
adjusting docstrings for documentation
AnFreTh Aug 5, 2024
e6b90dc
Merge pull request #99 from basf/minor_fix
AnFreTh Aug 5, 2024
5703d3b
adjusting docstrings for true defaults
AnFreTh Aug 5, 2024
a996e6e
Merge pull request #100 from basf/fastfix
AnFreTh Aug 5, 2024
16e18f4
adding score function to lss models
AnFreTh Aug 5, 2024
91ab62c
Merge pull request #101 from basf/lss_crossval
AnFreTh Aug 5, 2024
43d2758
include tabularRNN
AnFreTh Aug 12, 2024
330c1a0
Merge pull request #102 from basf/rnn_branch
AnFreTh Aug 12, 2024
08a8797
adapt readme and include citation
AnFreTh Aug 13, 2024
545af38
adjust readme with results
AnFreTh Aug 13, 2024
bd690bc
remove colored text in readme
AnFreTh Aug 13, 2024
2c653be
Merge pull request #103 from basf/doc-adaption
AnFreTh Aug 13, 2024
a0775af
new version
AnFreTh Aug 13, 2024
6949f26
Merge pull request #104 from basf/version_adaption
AnFreTh Aug 13, 2024
f466b80
deleting old paper.pdf
AnFreTh Aug 13, 2024
a383411
Merge pull request #105 from basf/version_adaption
AnFreTh Aug 13, 2024
22 changes: 0 additions & 22 deletions .github/workflows/draft-pdf.yml

This file was deleted.

615 changes: 352 additions & 263 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mambular/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.1.7"
__version__ = "0.2.1"
102 changes: 102 additions & 0 deletions mambular/arch_utils/attention_net_arch_utils.py
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
def __init__(self, j, dim, method="linear"):
super(Reshape, self).__init__()
self.j = j
self.dim = dim
self.method = method

if self.method == "linear":
# Use nn.Linear approach
self.layer = nn.Linear(dim, j * dim)
elif self.method == "embedding":
# Use nn.Embedding approach
self.layer = nn.Embedding(dim, j * dim)
elif self.method == "conv1d":
# Use nn.Conv1d approach
self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
else:
raise ValueError(f"Unsupported method '{method}' for reshaping.")

def forward(self, x):
batch_size = x.shape[0]

if self.method == "linear" or self.method == "embedding":
x_reshaped = self.layer(x) # shape: (batch_size, j * dim)
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)
elif self.method == "conv1d":
# For Conv1d, add dummy dimension and reshape
x = x.unsqueeze(-1) # Add dummy dimension for convolution
x_reshaped = self.layer(x) # shape: (batch_size, j * dim, 1)
x_reshaped = x_reshaped.squeeze(-1) # Remove dummy dimension
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)

return x_reshaped


class AttentionNetBlock(nn.Module):
def __init__(
self,
channels,
in_channels,
d_model,
n_heads,
n_layers,
dim_feedforward,
transformer_activation,
output_dim,
attn_dropout,
layer_norm_eps,
norm_first,
bias,
activation,
embedding_activation,
norm_f,
method,
):
super(AttentionNetBlock, self).__init__()

self.reshape = Reshape(channels, in_channels, method)

encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_heads,
batch_first=True,
dim_feedforward=dim_feedforward,
dropout=attn_dropout,
activation=transformer_activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
bias=bias,
)

self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=n_layers,
norm=norm_f,
)

self.linear = nn.Linear(d_model, output_dim)
self.activation = activation
self.embedding_activation = embedding_activation

def forward(self, x):
z = self.reshape(x)
x = self.embedding_activation(z)
x = self.encoder(x)
x = z + x
x = torch.sum(x, dim=1)
x = self.linear(x)
x = self.activation(x)
return x
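
The block above reshapes a flat feature vector into `channels` pseudo-tokens, passes them through a standard nn.TransformerEncoder with a residual connection, sums over the token axis, and applies a final linear projection. A minimal usage sketch, assuming a PyTorch version whose TransformerEncoderLayer accepts the `bias` argument; all hyperparameter values and input shapes below are illustrative assumptions, not defaults from this PR.

import torch
import torch.nn as nn

# Illustrative configuration; in_channels must match d_model because the
# encoder consumes the reshaped (batch, channels, in_channels) tensor directly.
block = AttentionNetBlock(
    channels=4,                     # j: number of pseudo-tokens produced by Reshape
    in_channels=64,                 # size of the flat input vector
    d_model=64,
    n_heads=4,
    n_layers=2,
    dim_feedforward=128,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=True,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=nn.LayerNorm(64),
    method="linear",
)

x = torch.randn(32, 64)             # (batch_size, in_channels)
out = block(x)                      # (batch_size, output_dim) == (32, 1)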
97 changes: 97 additions & 0 deletions mambular/arch_utils/attention_utils.py
@@ -0,0 +1,97 @@
import torch.nn as nn
import torch
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange
import torch.nn.functional as F
import numpy as np


class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.rotary = rotary
dim = np.int64(dim / 2)
self.rotary_embedding = RotaryEmbedding(dim=dim)

def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
if self.rotary:
q = self.rotary_embedding.rotate_queries_or_keys(q)
k = self.rotary_embedding.rotate_queries_or_keys(k)
q = q * self.scale

sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

attn = sim.softmax(dim=-1)
dropped_attn = self.dropout(attn)

out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out(out)

return out, attn


class Transformer(nn.Module):
def __init__(
self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
rotary=rotary,
),
FeedForward(dim, dropout=ff_dropout),
]
)
)

def forward(self, x, return_attn=False):
post_softmax_attns = []

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)
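
A short sketch of the Transformer stack above with rotary position embeddings enabled. Since Attention builds its RotaryEmbedding with a rotary dimension of dim / 2, dim_head is chosen here to be at least that large so the rotation fits within each head; all shapes and hyperparameters are assumptions for illustration only.

import torch

model = Transformer(
    dim=64,            # per-token embedding size
    depth=2,           # number of attention + feed-forward blocks
    heads=8,
    dim_head=32,       # chosen >= dim / 2 so the rotary embedding can be applied per head
    attn_dropout=0.1,
    ff_dropout=0.1,
    rotary=True,
)

x = torch.randn(16, 10, 64)                # (batch, tokens, dim)
out = model(x)                             # (16, 10, 64)
out, attn = model(x, return_attn=True)     # attn: (depth, batch, heads, tokens, tokens)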
163 changes: 163 additions & 0 deletions mambular/arch_utils/embedding_layer.py
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
def __init__(
self,
num_feature_info,
cat_feature_info,
d_model,
embedding_activation=nn.Identity(),
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
cat_encoding="int",
):
"""
Embedding layer that handles numerical and categorical embeddings.

Parameters
----------
num_feature_info : dict
Dictionary where keys are numerical feature names and values are their respective input dimensions.
cat_feature_info : dict
Dictionary where keys are categorical feature names and values are the number of categories for each feature.
d_model : int
Dimensionality of the embeddings.
embedding_activation : nn.Module, optional
Activation function to apply after embedding. Default is `nn.Identity()`.
layer_norm_after_embedding : bool, optional
If True, applies layer normalization after embeddings. Default is `False`.
use_cls : bool, optional
If True, includes a class token in the embeddings. Default is `False`.
cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
        cat_encoding : str, optional
            Encoding scheme for categorical features, either `"int"` (integer embeddings) or `"one-hot"` (one-hot encoding followed by a linear projection). Default is `"int"`.

Methods
-------
forward(num_features=None, cat_features=None)
Defines the forward pass of the model.
"""
super(EmbeddingLayer, self).__init__()

self.d_model = d_model
self.embedding_activation = embedding_activation
self.layer_norm_after_embedding = layer_norm_after_embedding
self.use_cls = use_cls
self.cls_position = cls_position

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, d_model, bias=False),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)

self.cat_embeddings = nn.ModuleList()
for feature_name, num_categories in cat_feature_info.items():
if cat_encoding == "int":
self.cat_embeddings.append(
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
)
elif cat_encoding == "one-hot":
self.cat_embeddings.append(
nn.Sequential(
OneHotEncoding(num_categories),
nn.Linear(num_categories, d_model, bias=False),
self.embedding_activation,
)
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
if layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(d_model)

self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

def forward(self, num_features=None, cat_features=None):
"""
Defines the forward pass of the model.

Parameters
----------
num_features : Tensor, optional
Tensor containing the numerical features.
cat_features : Tensor, optional
Tensor containing the categorical features.

Returns
-------
Tensor
The output embeddings of the model.

Raises
------
ValueError
If no features are provided to the model.
"""
if self.use_cls:
batch_size = (
cat_features[0].size(0)
                if cat_features is not None and len(cat_features) > 0
else num_features[0].size(0)
)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

if self.cat_embeddings and cat_features is not None:
cat_embeddings = [
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None

if self.num_embeddings and num_features is not None:
num_embeddings = [
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None

if cat_embeddings is not None and num_embeddings is not None:
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
elif cat_embeddings is not None:
x = cat_embeddings
elif num_embeddings is not None:
x = num_embeddings
else:
raise ValueError("No features provided to the model.")

if self.use_cls:
if self.cls_position == 0:
x = torch.cat([cls_tokens, x], dim=1)
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1)
else:
raise ValueError(
"Invalid cls_position value. It should be either 0 or 1."
)

return x


class OneHotEncoding(nn.Module):
def __init__(self, num_categories):
super(OneHotEncoding, self).__init__()
self.num_categories = num_categories

def forward(self, x):
return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
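
EmbeddingLayer expects the numerical and categorical inputs as lists of per-feature tensors, ordered like the corresponding feature-info dictionaries: each numerical tensor has shape (batch, input_dim) and each categorical tensor holds integer category indices of shape (batch, 1). A hypothetical usage sketch; the feature names, cardinalities, and shapes below are assumptions for illustration, not values taken from this PR.

import torch
import torch.nn as nn

# Hypothetical feature metadata: name -> input dimension / number of categories.
num_feature_info = {"age": 1, "income": 1}
cat_feature_info = {"city": 10, "gender": 2}

layer = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=32,
    embedding_activation=nn.ReLU(),
    layer_norm_after_embedding=True,
    use_cls=True,
    cls_position=0,
    cat_encoding="int",
)

batch_size = 8
num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]   # one tensor per numerical feature
cat_features = [
    torch.randint(0, 10, (batch_size, 1)),                              # "city" category indices
    torch.randint(0, 2, (batch_size, 1)),                               # "gender" category indices
]

out = layer(num_features=num_features, cat_features=cat_features)
# (batch, cls token + 2 categorical + 2 numerical, d_model) == (8, 5, 32)

With cat_encoding="one-hot", each categorical feature is instead one-hot encoded by OneHotEncoding and projected with a bias-free linear layer, so the output shape stays the same.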