
Commit

Update transformer.py
KOSASIH authored Jul 28, 2024
1 parent 080abae commit b6b68cc
Showing 1 changed file with 101 additions and 39 deletions.

ai/models/transformer.py:
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, num_heads, dropout):
        super(TransformerModel, self).__init__()
        self.encoder = TransformerEncoder(input_dim, num_heads, dropout)
        # Bridge the encoder width (input_dim) to the decoder width (output_dim);
        # without this projection the decoder's linear layers would receive
        # tensors of the wrong size whenever input_dim != output_dim.
        self.proj = nn.Linear(input_dim, output_dim)
        self.decoder = TransformerDecoder(output_dim, num_heads, dropout)

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj(x)
        x = self.decoder(x)
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, input_dim, num_heads, dropout):
        super(TransformerEncoder, self).__init__()
        self.self_attn = MultiHeadAttention(input_dim, num_heads, dropout)
        self.feed_forward = FeedForward(input_dim, dropout)

    def forward(self, x):
        x = self.self_attn(x)
        x = self.feed_forward(x)
        return x

class TransformerDecoder(nn.Module):
    def __init__(self, output_dim, num_heads, dropout):
        super(TransformerDecoder, self).__init__()
        self.self_attn = MultiHeadAttention(output_dim, num_heads, dropout)
        self.feed_forward = FeedForward(output_dim, dropout)

    def forward(self, x):
        x = self.self_attn(x)
        x = self.feed_forward(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        # num_heads is kept for the interface, but the projections below are
        # shared, so this is effectively single-head scaled dot-product attention.
        self.num_heads = num_heads
        self.query_linear = nn.Linear(input_dim, input_dim)
        self.key_linear = nn.Linear(input_dim, input_dim)
        self.value_linear = nn.Linear(input_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        queries = self.query_linear(x)
        keys = self.key_linear(x)
        values = self.value_linear(x)
        # Scaled dot-product attention: scores are scaled by sqrt(d_k) and
        # normalized with softmax before mixing the values.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(keys.size(-1))
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context = torch.matmul(attention_weights, values)
        return context

class FeedForward(nn.Module):
    def __init__(self, input_dim, dropout):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        # A nonlinearity between the two linear layers; without it they
        # would collapse into a single linear map.
        x = torch.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class TransformerDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

def train(model, device, loader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss()  # created once, reused across batches
    for batch_idx, (data, labels) in enumerate(loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f'Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {loss.item():.4f}')

def test(model, device, loader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, labels in loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            test_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
    accuracy = correct / len(loader.dataset)
    print(f'Test Loss: {test_loss / len(loader):.4f}')
    print(f'Test Accuracy: {accuracy:.2%}')
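
The file defines the model, the dataset wrapper, and the train/test helpers, but no driver that connects them. A minimal sketch of how they could be wired together, assuming random toy data and illustrative dimensions (input_dim=32, num_classes=4, batch size 8) that are not part of the commit:

# Hypothetical driver (not in the commit): exercises the classes above on
# random toy data. All dimensions and hyperparameters here are assumptions.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    input_dim, num_classes = 32, 4
    data = torch.randn(256, input_dim)              # 256 samples of 32 features
    labels = torch.randint(0, num_classes, (256,))  # random class labels

    loader = DataLoader(TransformerDataset(data, labels), batch_size=8, shuffle=True)

    model = TransformerModel(input_dim, num_classes, num_heads=4, dropout=0.1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(3):
        train(model, device, loader, optimizer, epoch)
        test(model, device, loader)

Note that each sample here is a single feature vector rather than a token sequence, since the model as written applies no pooling and feeds its raw output straight into the cross-entropy loss.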
