Skip to content

Commit

Permalink
start implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
schauppi committed May 14, 2024
1 parent fdb8895 commit 460bf45
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/mLSTM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn

torch.manual_seed(42)


class mLSTM(nn.Module):
def __init__(self, n_embed):
super().__init__()
self.n_embeb = n_embed

# Query
self.W_q = nn.Linear(n_embed, n_embed)
# Key
self.W_k = nn.Linear(n_embed, n_embed)
# Value
self.W_v = nn.Linear(n_embed, n_embed)
# Forget
self.W_f = nn.Linear(n_embed, n_embed)
# Input
self.W_i = nn.Linear(n_embed, n_embed)
# Output projection
self.output_projection = nn.Linear(n_embed, n_embed)

def forward(self, x):
batch_size, seq_len, _ = x.size()

q = self.W_q(x)
k = self.W_k(x)
v = self.W_q(x)

f = torch.sigmoid(self.W_f(x))
i = torch.exp(self.W_i(x))

0 comments on commit 460bf45

Please sign in to comment.