Merge pull request #106 from basf/develop
Version 0.2.1
Showing 41 changed files with 2,801 additions and 862 deletions.
@@ -1,4 +1,4 @@
 """Version information."""

 # The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.7"
+__version__ = "0.2.1"
@@ -0,0 +1,102 @@
import torch.nn as nn
import torch

class Reshape(nn.Module):
    """Reshape a (batch_size, dim) input into (batch_size, j, dim) token embeddings."""

    def __init__(self, j, dim, method="linear"):
        super(Reshape, self).__init__()
        self.j = j
        self.dim = dim
        self.method = method

        if self.method == "linear":
            # Use nn.Linear approach
            self.layer = nn.Linear(dim, j * dim)
        elif self.method == "embedding":
            # Use nn.Embedding approach
            self.layer = nn.Embedding(dim, j * dim)
        elif self.method == "conv1d":
            # Use nn.Conv1d approach
            self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
        else:
            raise ValueError(f"Unsupported method '{method}' for reshaping.")

    def forward(self, x):
        batch_size = x.shape[0]

        if self.method == "linear" or self.method == "embedding":
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim)
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)
        elif self.method == "conv1d":
            # For Conv1d, add dummy dimension and reshape
            x = x.unsqueeze(-1)  # Add dummy dimension for convolution
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim, 1)
            x_reshaped = x_reshaped.squeeze(-1)  # Remove dummy dimension
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)

        return x_reshaped

class AttentionNetBlock(nn.Module):
    """Reshape a flat feature vector into a sequence, encode it with a Transformer
    encoder, and project the pooled result to the output dimension."""

    def __init__(
        self,
        channels,
        in_channels,
        d_model,
        n_heads,
        n_layers,
        dim_feedforward,
        transformer_activation,
        output_dim,
        attn_dropout,
        layer_norm_eps,
        norm_first,
        bias,
        activation,
        embedding_activation,
        norm_f,
        method,
    ):
        super(AttentionNetBlock, self).__init__()

        self.reshape = Reshape(channels, in_channels, method)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            batch_first=True,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,
            activation=transformer_activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
            bias=bias,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            norm=norm_f,
        )

        self.linear = nn.Linear(d_model, output_dim)
        self.activation = activation
        self.embedding_activation = embedding_activation

    def forward(self, x):
        z = self.reshape(x)  # shape: (batch_size, channels, in_channels)
        x = self.embedding_activation(z)
        x = self.encoder(x)
        x = z + x  # residual connection around the encoder
        x = torch.sum(x, dim=1)  # pool over the sequence dimension
        x = self.linear(x)
        x = self.activation(x)
        return x
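
To show how the new block is wired together, here is a minimal usage sketch; the parameter values below are illustrative only and are not defaults taken from this repository. Note that in_channels must match d_model, since Reshape feeds (batch_size, channels, in_channels) tokens directly into the encoder.

import torch
import torch.nn as nn

# Illustrative values only; in practice these come from the model configuration.
block = AttentionNetBlock(
    channels=4,               # j tokens produced by Reshape
    in_channels=16,           # input feature dimension (equals d_model here)
    d_model=16,
    n_heads=4,
    n_layers=2,
    dim_feedforward=64,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=None,
    method="linear",
)
x = torch.randn(32, 16)       # (batch_size, in_channels)
out = block(x)                # (batch_size, output_dim) -> torch.Size([32, 1])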
@@ -0,0 +1,97 @@
import torch.nn as nn
import torch
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange
import torch.nn.functional as F
import numpy as np

class GEGLU(nn.Module):
    def forward(self, x):
        # Split the last dimension into values and gates, then gate with GELU.
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim),
    )

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.rotary = rotary
        # Rotary embeddings are built over half the model dimension.
        dim = np.int64(dim / 2)
        self.rotary_embedding = RotaryEmbedding(dim=dim)

    def forward(self, x):
        h = self.heads
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        if self.rotary:
            q = self.rotary_embedding.rotate_queries_or_keys(q)
            k = self.rotary_embedding.rotate_queries_or_keys(k)
        q = q * self.scale

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

        attn = sim.softmax(dim=-1)
        dropped_attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)

        return out, attn

class Transformer(nn.Module):
    def __init__(
        self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim,
                            heads=heads,
                            dim_head=dim_head,
                            dropout=attn_dropout,
                            rotary=rotary,
                        ),
                        FeedForward(dim, dropout=ff_dropout),
                    ]
                )
            )

    def forward(self, x, return_attn=False):
        post_softmax_attns = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)

            x = attn_out + x
            x = ff(x) + x

        if not return_attn:
            return x

        return x, torch.stack(post_softmax_attns)
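
A minimal usage sketch of the Transformer stack added here; the sizes are illustrative, and einops plus rotary_embedding_torch must be installed for the imports above. Post-softmax attention maps are only returned when return_attn=True.

import torch

# Illustrative sizes only.
model = Transformer(
    dim=32, depth=2, heads=4, dim_head=16, attn_dropout=0.1, ff_dropout=0.1, rotary=False
)
x = torch.randn(8, 10, 32)               # (batch, sequence length, dim)
out = model(x)                           # (8, 10, 32)
out, attns = model(x, return_attn=True)  # attns: (depth, batch, heads, 10, 10)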
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn

class EmbeddingLayer(nn.Module):
    def __init__(
        self,
        num_feature_info,
        cat_feature_info,
        d_model,
        embedding_activation=nn.Identity(),
        layer_norm_after_embedding=False,
        use_cls=False,
        cls_position=0,
        cat_encoding="int",
    ):
        """
        Embedding layer that handles numerical and categorical embeddings.

        Parameters
        ----------
        num_feature_info : dict
            Dictionary where keys are numerical feature names and values are their respective input dimensions.
        cat_feature_info : dict
            Dictionary where keys are categorical feature names and values are the number of categories for each feature.
        d_model : int
            Dimensionality of the embeddings.
        embedding_activation : nn.Module, optional
            Activation function to apply after embedding. Default is `nn.Identity()`.
        layer_norm_after_embedding : bool, optional
            If True, applies layer normalization after embeddings. Default is `False`.
        use_cls : bool, optional
            If True, includes a class token in the embeddings. Default is `False`.
        cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
        cat_encoding : str, optional
            Encoding used for categorical features, either "int" (embedding lookup) or "one-hot". Default is `"int"`.

        Methods
        -------
        forward(num_features=None, cat_features=None)
            Defines the forward pass of the model.
        """
        super(EmbeddingLayer, self).__init__()

        self.d_model = d_model
        self.embedding_activation = embedding_activation
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.use_cls = use_cls
        self.cls_position = cls_position

        self.num_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_shape, d_model, bias=False),
                    self.embedding_activation,
                )
                for feature_name, input_shape in num_feature_info.items()
            ]
        )

        self.cat_embeddings = nn.ModuleList()
        for feature_name, num_categories in cat_feature_info.items():
            if cat_encoding == "int":
                self.cat_embeddings.append(
                    nn.Sequential(
                        nn.Embedding(num_categories + 1, d_model),
                        self.embedding_activation,
                    )
                )
            elif cat_encoding == "one-hot":
                self.cat_embeddings.append(
                    nn.Sequential(
                        OneHotEncoding(num_categories),
                        nn.Linear(num_categories, d_model, bias=False),
                        self.embedding_activation,
                    )
                )

        if self.use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        if layer_norm_after_embedding:
            self.embedding_norm = nn.LayerNorm(d_model)

        self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

    def forward(self, num_features=None, cat_features=None):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor, optional
            Tensor containing the numerical features.
        cat_features : Tensor, optional
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output embeddings of the model.

        Raises
        ------
        ValueError
            If no features are provided to the model.
        """
        if self.use_cls:
            batch_size = (
                cat_features[0].size(0)
                if cat_features != []
                else num_features[0].size(0)
            )
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        if self.cat_embeddings and cat_features is not None:
            cat_embeddings = [
                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
            ]
            cat_embeddings = torch.stack(cat_embeddings, dim=1)
            cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
            if self.layer_norm_after_embedding:
                cat_embeddings = self.embedding_norm(cat_embeddings)
        else:
            cat_embeddings = None

        if self.num_embeddings and num_features is not None:
            num_embeddings = [
                emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
            ]
            num_embeddings = torch.stack(num_embeddings, dim=1)
            if self.layer_norm_after_embedding:
                num_embeddings = self.embedding_norm(num_embeddings)
        else:
            num_embeddings = None

        if cat_embeddings is not None and num_embeddings is not None:
            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
        elif cat_embeddings is not None:
            x = cat_embeddings
        elif num_embeddings is not None:
            x = num_embeddings
        else:
            raise ValueError("No features provided to the model.")

        if self.use_cls:
            if self.cls_position == 0:
                x = torch.cat([cls_tokens, x], dim=1)
            elif self.cls_position == 1:
                x = torch.cat([x, cls_tokens], dim=1)
            else:
                raise ValueError(
                    "Invalid cls_position value. It should be either 0 or 1."
                )

        return x

class OneHotEncoding(nn.Module):
    def __init__(self, num_categories):
        super(OneHotEncoding, self).__init__()
        self.num_categories = num_categories

    def forward(self, x):
        return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
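
A minimal usage sketch of the new EmbeddingLayer; the feature names, dictionary values, and sizes below are hypothetical and only illustrate the expected input format (one tensor per feature, categorical features as integer codes).

import torch
import torch.nn as nn

# Hypothetical feature dictionaries, not taken from the repository.
num_feature_info = {"age": 1, "income": 1}   # each numerical feature has input dim 1
cat_feature_info = {"city": 5}               # categorical feature with 5 categories
layer = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=32,
    embedding_activation=nn.ReLU(),
    use_cls=True,
)
num_features = [torch.randn(16, 1), torch.randn(16, 1)]  # one tensor per numerical feature
cat_features = [torch.randint(0, 5, (16, 1))]            # one integer tensor per categorical feature
tokens = layer(num_features=num_features, cat_features=cat_features)
print(tokens.shape)  # torch.Size([16, 4, 32]): [CLS] + 1 categorical + 2 numerical tokens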