Change to the correct architecture of SAINT #6

Open · wants to merge 1 commit into base: main
130 changes: 71 additions & 59 deletions saint.py
@@ -24,51 +24,50 @@ class Encoder_block(nn.Module):
O = SkipConct(FFN(LayerNorm(M)))
"""

def __init__(self , dim_model, heads_en, total_ex ,total_cat, seq_len):
def __init__(self , dim_model, heads_en, total_ex ,total_cat, seq_len, dropout):
super().__init__()
self.dim_model = dim_model
self.seq_len = seq_len
self.embd_ex = nn.Embedding( total_ex , embedding_dim = dim_model ) # embeddings: q,k,v = E = exercise ID embedding, category embedding, and position embedding.
self.embd_cat = nn.Embedding( total_cat, embedding_dim = dim_model )
self.embd_pos = nn.Embedding( seq_len , embedding_dim = dim_model ) #positional embedding

self.multi_en = nn.MultiheadAttention( embed_dim= dim_model, num_heads= heads_en, ) # multihead attention ## todo add dropout, LayerNORM
self.ffn_en = Feed_Forward_block( dim_model ) # feedforward block ## todo dropout, LayerNorm
self.multi_en = nn.MultiheadAttention( embed_dim= dim_model, num_heads= heads_en, ) # multihead attention
self.ffn_en = Feed_Forward_block( dim_model ) # feedforward block
self.dropout = nn.Dropout(dropout)
self.layer_norm1 = nn.LayerNorm( dim_model )
self.layer_norm2 = nn.LayerNorm( dim_model )


def forward(self, in_ex, in_cat, first_block=True):

## todo create a positional encoding ( two options numeric, sine)
if first_block:
in_ex = self.embd_ex( in_ex )
in_cat = self.embd_cat( in_cat )
#in_pos = self.embd_pos( in_pos )
in_pos = position_embedding(in_ex.shape[0], self.seq_len, self.dim_model)
# combining the embeddings
out = in_ex + in_cat #+ in_pos # (b,n,d)
out = in_ex + in_cat + in_pos # (b,n,d)
else:
out = in_ex

in_pos = get_pos(self.seq_len)
in_pos = self.embd_pos( in_pos )
out = out + in_pos # Applying positional embedding

out = self.dropout(out)
out = out.permute(1,0,2) # (n,b,d) # print('pre multi', out.shape )

#Multihead attention
n,_,_ = out.shape
out = self.layer_norm1( out ) # Layer norm
skip_out = out
out, attn_wt = self.multi_en( out , out , out ,
attn_mask=get_mask(seq_len=n)) # attention mask upper triangular
out = self.dropout(out)
out = out + skip_out # skip connection
out = self.layer_norm1( out ) # Layer norm

#feed forward
out = out.permute(1,0,2) # (b,n,d)
out = self.layer_norm2( out ) # Layer norm
out = out.permute(1,0,2) # (b,n,d)
skip_out = out
out = self.ffn_en( out )
out = self.dropout(out)
out = out + skip_out # skip connection
out = self.layer_norm2( out ) # Layer norm

return out
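For reference, the revised encoder forward pass above wraps each sub-layer (self-attention, then the feed-forward block) in the same residual pattern: sub-layer, dropout, skip connection, then LayerNorm. A minimal standalone sketch of that pattern, with the helper name, tensor shapes, and dropout rate chosen purely for illustration:

import torch
import torch.nn as nn

def sublayer_residual(x, sublayer, norm, dropout):
    # sub-layer output -> dropout -> skip connection -> LayerNorm
    return norm(x + dropout(sublayer(x)))

dim_model = 128
x = torch.randn(100, 64, dim_model)                        # (n, b, d), as after the permute above
ffn = nn.Sequential(nn.Linear(dim_model, dim_model),
                    nn.ReLU(),
                    nn.Linear(dim_model, dim_model))
y = sublayer_residual(x, ffn, nn.LayerNorm(dim_model), nn.Dropout(0.1))
print(y.shape)                                             # torch.Size([100, 64, 128])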

Expand All @@ -80,59 +79,58 @@ class Decoder_block(nn.Module):
L = SkipConct(FFN(LayerNorm(M2)))
"""

def __init__(self,dim_model ,total_in, heads_de,seq_len ):
def __init__(self,dim_model, total_in, heads_de, seq_len, dropout):
super().__init__()
self.dim_model = dim_model
self.seq_len = seq_len
self.embd_in = nn.Embedding( total_in , embedding_dim = dim_model ) #interaction embedding
self.embd_pos = nn.Embedding( seq_len , embedding_dim = dim_model ) #positional embedding
self.multi_de1 = nn.MultiheadAttention( embed_dim= dim_model, num_heads= heads_de ) # M1 multihead for interaction embedding as q k v
self.multi_de2 = nn.MultiheadAttention( embed_dim= dim_model, num_heads= heads_de ) # M2 multihead for M1 out, encoder out, encoder out as q k v
self.ffn_en = Feed_Forward_block( dim_model ) # feed forward layer

self.dropout = nn.Dropout(dropout)
self.layer_norm1 = nn.LayerNorm( dim_model )
self.layer_norm2 = nn.LayerNorm( dim_model )
self.layer_norm3 = nn.LayerNorm( dim_model )


def forward(self, in_in, en_out,first_block=True):

## todo create a positional encoding ( two options numeric, sine)
if first_block:
in_in = self.embd_in( in_in )

in_pos = position_embedding(in_in.shape[0], self.seq_len, self.dim_model)
# combining the embeddings
out = in_in #+ in_cat #+ in_pos # (b,n,d)
out = in_in + in_pos # (b,n,d)
else:
out = in_in

in_pos = get_pos(self.seq_len)
in_pos = self.embd_pos( in_pos )
out = out + in_pos # Applying positional embedding

out = self.dropout(out)
out = out.permute(1,0,2) # (n,b,d)# print('pre multi', out.shape )
n,_,_ = out.shape

#Multihead attention M1 ## todo verify if E to passed as q,k,v
out = self.layer_norm1( out )
#Multihead attention M1
n,_,_ = out.shape
skip_out = out
out, attn_wt = self.multi_de1( out , out , out,
attn_mask=get_mask(seq_len=n)) # attention mask upper triangular
out = self.dropout(out)
out = skip_out + out # skip connection
out = self.layer_norm1( out )

#Multihead attention M2 ## todo verify if E to passed as q,k,v
#Multihead attention M2
en_out = en_out.permute(1,0,2) # (b,n,d)-->(n,b,d)
en_out = self.layer_norm2( en_out )
skip_out = out
out, attn_wt = self.multi_de2( out , en_out , en_out,
attn_mask=get_mask(seq_len=n)) # attention mask upper triangular
out = out + skip_out
en_out = self.layer_norm2( en_out )

#feed forward
out = out.permute(1,0,2) # (b,n,d)
out = self.layer_norm3( out ) # Layer norm
skip_out = out
out = self.ffn_en( out )
out = self.dropout(out)
out = out + skip_out # skip connection
out = self.layer_norm3( out ) # Layer norm

return out
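Both attention calls in these blocks pass attn_mask=get_mask(seq_len=n), the upper-triangular boolean mask defined further down in this file; for nn.MultiheadAttention, True entries mark positions a query is not allowed to attend to, so every position sees only itself and earlier positions. A small illustration of what the helper produces, with seq_len=4 chosen only for readability:

import numpy as np
import torch

seq_len = 4
mask = torch.from_numpy(np.triu(np.ones((seq_len, seq_len)), k=1).astype('bool'))
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])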

Expand All @@ -141,22 +139,32 @@ def get_clones(module, N):


def get_mask(seq_len):
##todo add this to device
return torch.from_numpy( np.triu(np.ones((seq_len ,seq_len)), k=1).astype('bool'))

def get_pos(seq_len):
# use sine positional embeddings
return torch.arange( seq_len ).unsqueeze(0)
def position_encoding(pos, dim_model):
# Encode one position with sin and cos
# Attention Is All You Need uses sinusoidal positional encodings; the SAINT paper does not specify
pos_enc = np.zeros(dim_model)
for i in range(0, dim_model, 2):
pos_enc[i] = np.sin(pos / (10000 ** (2 * i / dim_model)))
pos_enc[i + 1] = np.cos(pos / (10000 ** (2 * i / dim_model)))
return pos_enc


def position_embedding(bs, seq_len, dim_model):
# Return the position embedding for the whole sequence
pe_array = np.array([[position_encoding(pos, dim_model) for pos in range(seq_len)]] * bs)
return torch.from_numpy(pe_array).float()

class saint(nn.Module):
def __init__(self,dim_model,num_en, num_de ,heads_en, total_ex ,total_cat,total_in,heads_de,seq_len ):
def __init__(self,dim_model,num_en, num_de ,heads_en, total_ex ,total_cat,total_in,heads_de,seq_len, dropout):
super().__init__( )

self.num_en = num_en
self.num_de = num_de

self.encoder = get_clones( Encoder_block(dim_model, heads_en , total_ex ,total_cat,seq_len) , num_en)
self.decoder = get_clones( Decoder_block(dim_model ,total_in, heads_de,seq_len) , num_de)
self.encoder = get_clones( Encoder_block(dim_model, heads_en , total_ex ,total_cat,seq_len, dropout) , num_en)
self.decoder = get_clones( Decoder_block(dim_model ,total_in, heads_de,seq_len, dropout) , num_de)

self.out = nn.Linear(in_features= dim_model , out_features=1)

Expand All @@ -183,35 +191,39 @@ def forward(self,in_ex, in_cat, in_in ):
return in_in


## forward prop on dummy data
if __name__ == "__main__":
## forward prop on dummy data

seq_len = 100
total_ex = 1200
total_cat = 234
total_in = 2

seq_len = 100
total_ex = 1200
total_cat = 234
total_in = 2
dropout_rate = 0.1


def random_data(bs, seq_len , total_ex, total_cat, total_in = 2):
ex = torch.randint( 0 , total_ex ,(bs , seq_len) )
cat = torch.randint( 0 , total_cat ,(bs , seq_len) )
de = torch.randint( 0 , total_in ,(bs , seq_len) )
return ex,cat, de
def random_data(bs, seq_len , total_ex, total_cat, total_in = 2):
ex = torch.randint( 0 , total_ex ,(bs , seq_len) )
cat = torch.randint( 0 , total_cat ,(bs , seq_len) )
de = torch.randint( 0 , total_in ,(bs , seq_len) )
return ex,cat, de


in_ex, in_cat, in_de = random_data(64, seq_len , total_ex, total_cat, total_in)
in_ex, in_cat, in_de = random_data(64, seq_len , total_ex, total_cat, total_in)


model = saint(dim_model=128,
num_en=6,
num_de=6,
heads_en=8,
heads_de=8,
total_ex=total_ex,
total_cat=total_cat,
total_in=total_in,
seq_len=seq_len
)
model = saint(dim_model=128,
num_en=6,
num_de=6,
heads_en=8,
heads_de=8,
total_ex=total_ex,
total_cat=total_cat,
total_in=total_in,
seq_len=seq_len,
dropout=dropout_rate
)

outs = model(in_ex, in_cat, in_de)
outs = model(in_ex, in_cat, in_de)

print(outs.shape)
print(outs.shape)
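Since this change threads a dropout rate through every block, note that the new nn.Dropout layers are only active while the model is in training mode; as a short follow-on to the smoke test above (standard PyTorch behaviour, nothing specific to this PR):

model.eval()                          # nn.Dropout becomes a no-op in eval mode
with torch.no_grad():
    eval_outs = model(in_ex, in_cat, in_de)
print(eval_outs.shape)                # same shape as the training-mode output above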