Skip to content

Commit

Permalink
implementations and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schauppi committed May 22, 2024
1 parent ededf0a commit 588bbfb
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 300 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ Results on an sine wave with 5 hidden units and 500 epochs.

## sLSTM

![Figure 2](images/sLSTM_5.png)
![Figure 2](images/sLSTMCell_10.png)

## mLSTM

![Figure 3](images/mLSTM_5.png)
![Figure 3](images/mLSTMCell_10.png)

# To Do

- [ ] Check implementation of mLSTM - seems somewhat off
- [X] Check implementation of mLSTM - seems somewhat off
- [ ] Implement xLSTM - stack Cells together
Binary file added images/mLSTMCell_10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/sLSTMCell_10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
90 changes: 0 additions & 90 deletions src/mLSTM.py

This file was deleted.

7 changes: 6 additions & 1 deletion src/mLSTMCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(self, input_size, hidden_size, bias=True):
self.bk = nn.Parameter((torch.zeros(hidden_size)), requires_grad=True)
self.bv = nn.Parameter((torch.zeros(hidden_size)), requires_grad=True)

self.linear = nn.Linear(hidden_size, 1)

def init_hidden(self, batch_size):
return (
torch.zeros(batch_size, self.hidden_size, self.hidden_size),
Expand Down Expand Up @@ -137,4 +139,7 @@ def forward(self, x, hidden):
# Hidden state part 2 - equation (21) -> (batch_size, hidden_size)
ht = ot * h_tilde

return ht, (ct, nt, mt)
# Map the hidden state to the output -> (batch_size, 1)
output = self.linear(ht)

return output, (ct, nt, mt)
79 changes: 0 additions & 79 deletions src/sLSTM.py

This file was deleted.

9 changes: 7 additions & 2 deletions src/sLSTMCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __init__(self, input_size, hidden_size, bias=True):
if self.bias:
self.B = nn.Parameter((torch.zeros(4 * hidden_size)), requires_grad=True)

self.linear = nn.Linear(hidden_size, 1)

def init_hidden(self, batch_size):
return (
torch.zeros(batch_size, self.hidden_size),
Expand Down Expand Up @@ -67,5 +69,8 @@ def forward(self, x, hidden):
# Hidden state - equation (10) -> (batch_size, hidden_size)
ht = ot * (c / nt)

# -> (batch_size, hidden_size), (batch_size, hidden_size, batch_size, hidden_size, batch_size, hidden_size, batch_size, hidden_size)
return ht, (ht, ct, nt, mt)
# Map the hidden state to the output -> (batch_size, 1)
output = self.linear(ht)

# -> (batch_size, 1), (batch_size, hidden_size, batch_size, hidden_size, batch_size, hidden_size, batch_size, hidden_size)
return output, (ht, ct, nt, mt)
10 changes: 0 additions & 10 deletions src/utils/sine_wave.py

This file was deleted.

57 changes: 0 additions & 57 deletions src/utils/test_mLSTM.py

This file was deleted.

63 changes: 63 additions & 0 deletions src/utils/test_mLSTMCell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from src.mLSTMCell import mLSTMCell

# Generate a sine wave
num_points = 100
time = np.linspace(0, 4 * np.pi, num_points)
data = np.sin(time)
data = torch.tensor(data, dtype=torch.float32)

# Hyperparameters
input_size = 1
hidden_size = 10
output_size = 1
num_layers = 1
learning_rate = 0.01
num_epochs = 500

# Model, loss function, and optimizer
model = mLSTMCell(input_size, hidden_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):

hidden = model.init_hidden(1)

optimizer.zero_grad()
loss = 0

for i in range(num_points - 1):
inputs = data[i].view(1, 1)
targets = data[i + 1].view(1, 1)
outputs, hidden = model(inputs, hidden)
loss += criterion(outputs, targets)

loss.backward()
optimizer.step()

if epoch % 10 == 0:
print(f"Epoch {epoch}, Loss: {loss.item()}")

# Test the model
model.eval()
with torch.no_grad():
predictions = []
hidden = model.init_hidden(1)
for i in range(num_points - 1):
inputs = data[i].view(1, 1)
outputs, hidden = model(inputs, hidden)
predictions.append(outputs.item())

# Plot the results
plt.figure(figsize=(12, 6))
plt.title(f"mLSTM - Original vs Predicted Sine Wave, hidden_size={hidden_size}")
plt.plot(time[1:], data[1:], label="Original")
plt.plot(time[1:], predictions, label="Predicted")
plt.legend()
plt.savefig(f"images/mLSTMCell_{hidden_size}.png")
plt.show()
Loading

0 comments on commit 588bbfb

Please sign in to comment.