Merge pull request #10 from schauppi/mLSTM
implement and test mLSTM
schauppi authored May 15, 2024
2 parents c697f74 + 09315e2 commit 4c6d6d0
Showing 2 changed files with 79 additions and 17 deletions.
70 changes: 53 additions & 17 deletions src/mLSTM.py
@@ -1,33 +1,69 @@
 import torch
 import torch.nn as nn
+import math
 
 torch.manual_seed(42)
 
 
 class mLSTM(nn.Module):
-    def __init__(self, n_embed):
+    def __init__(self, input_size, hidden_size, mem_dim):
         super().__init__()
-        self.n_embeb = n_embed
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.mem_dim = mem_dim
 
         # Query
-        self.W_q = nn.Linear(n_embed, n_embed)
+        self.Wq = nn.Parameter(torch.randn(hidden_size, input_size))
+        self.bq = nn.Parameter(torch.randn(hidden_size, 1))
         # Key
-        self.W_k = nn.Linear(n_embed, n_embed)
+        self.Wk = nn.Parameter(torch.randn(mem_dim, input_size))
+        self.bk = nn.Parameter(torch.randn(mem_dim, 1))
         # Value
-        self.W_v = nn.Linear(n_embed, n_embed)
-        # Forget
-        self.W_f = nn.Linear(n_embed, n_embed)
+        self.Wv = nn.Parameter(torch.randn(mem_dim, input_size))
+        self.bv = nn.Parameter(torch.randn(mem_dim, 1))
         # Input
-        self.W_i = nn.Linear(n_embed, n_embed)
-        # Output projection
-        self.output_projection = nn.Linear(n_embed, n_embed)
+        self.Wi = nn.Parameter(torch.randn(1, input_size))
+        self.bi = nn.Parameter(torch.randn(1))
+        # Forget
+        self.Wf = nn.Parameter(torch.randn(1, input_size))
+        self.bf = nn.Parameter(torch.randn(1))
+        # Out
+        self.Wo = nn.Parameter(torch.randn(1, input_size))
+        self.bo = nn.Parameter(torch.randn(1))
 
-    def forward(self, x):
-        batch_size, seq_len, _ = x.size()
+    def forward(self, x, hidden_states):
+        c_prev, n_prev = hidden_states
 
-        q = self.W_q(x)
-        k = self.W_k(x)
-        v = self.W_q(x)
+        qt = torch.matmul(self.Wq, x) + self.bq
+        kt = (1 / math.sqrt(self.mem_dim)) * (torch.matmul(self.Wk, x) + self.bk)
+        vt = torch.matmul(self.Wv, x) + self.bv
 
-        f = torch.sigmoid(self.W_f(x))
-        i = torch.exp(self.W_i(x))
+        # Input gate
+        it_tilde = torch.matmul(self.Wi, x) + self.bi
+        it = torch.exp(it_tilde)
+
+        # Forget gate
+        ft_tilde = torch.matmul(self.Wf, x) + self.bf
+        # TODO: try torch.sigmoid as well - given as an alternative in formula 26 of the paper
+        ft = torch.exp(ft_tilde)
+
+        # Drop the trailing singleton dimension
+        vt = vt.squeeze()
+        kt = kt.squeeze()
+
+        # Cell state update via outer product of value and key
+        C = ft * c_prev + it * torch.outer(vt, kt)
+        # Normalizer state
+        n = ft * n_prev + it * kt.unsqueeze(1)
+
+        # Hidden state pre-activation - formula 21 in the paper
+        h_tilde = torch.matmul(C, qt) / (
+            torch.max(torch.abs(torch.matmul(n.T, qt)), torch.tensor(1.0))
+        )
+        # Output gate
+        ot = torch.sigmoid(torch.matmul(self.Wo, x) + self.bo)
+
+        # Hidden state
+        ht = ot * h_tilde
+
+        return ht, (C, n)
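
Per timestep, the new forward() computes C_t = f_t * C_{t-1} + i_t * (v_t k_t^T), n_t = f_t * n_{t-1} + i_t * k_t, and h_t = o_t * (C_t q_t / max(|n_t^T q_t|, 1)). Note that it now consumes one column vector per step rather than a (batch, seq, features) tensor, and that torch.matmul(C, qt) implicitly requires hidden_size == mem_dim. A minimal smoke-test sketch of driving the step-wise interface (the shapes are inferred from the code above, not documented in the commit):

import torch

from src.mLSTM import mLSTM

# Sketch, assuming hidden_size == mem_dim as in the committed test.
model = mLSTM(input_size=1, hidden_size=5, mem_dim=5)

# Initial memory: C is (mem_dim, mem_dim), n is (mem_dim, 1).
states = (torch.zeros(5, 5), torch.zeros(5, 1))

for t in range(10):
    x_t = torch.randn(1, 1)  # one (input_size, 1) column vector per step
    h_t, states = model(x_t, states)

print(h_t.shape)  # torch.Size([5, 1])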
26 changes: 26 additions & 0 deletions src/utils/test_mLSTM.py
@@ -0,0 +1,26 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from src.mLSTM import mLSTM
+
+
+def generate_sine_wave(seq_len, num_sequences):
+    x = np.linspace(0, 2 * np.pi, seq_len)
+    y = np.sin(x)
+    return torch.tensor(y).float().view(-1, 1).repeat(1, num_sequences).unsqueeze(0)
+
+
+input_size = 1
+hidden_size = 5
+mem_dim = 5
+seq_len = 100
+num_sequences = 2
+
+data = generate_sine_wave(seq_len=seq_len, num_sequences=num_sequences)
+model = mLSTM(input_size=input_size, hidden_size=hidden_size, mem_dim=mem_dim)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+criterion = nn.MSELoss()
+
+for epoch in range(200):
+    states = (torch.zeros(mem_dim, mem_dim), torch.zeros(mem_dim, 1))

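As committed, the loop only re-initializes the states each epoch; there is no unrolling, prediction, or backward pass yet. Continuing from the script above, one way to complete it for next-step prediction might look like the sketch below (the readout layer and the choice of feeding only the first of the two sine sequences are my assumptions, not part of this commit):

# Hypothetical completion of the loop above - not part of this commit.
readout = nn.Linear(hidden_size, 1)  # assumed readout from h_t to a scalar prediction
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(readout.parameters()), lr=0.01
)

for epoch in range(200):
    states = (torch.zeros(mem_dim, mem_dim), torch.zeros(mem_dim, 1))
    loss = torch.tensor(0.0)
    for t in range(seq_len - 1):
        x_t = data[0, t, :1].view(1, 1)         # current value, shape (input_size, 1)
        target = data[0, t + 1, :1].view(1, 1)  # next-step target
        h_t, states = model(x_t, states)
        loss = loss + criterion(readout(h_t.view(1, -1)), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()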