# seq2seq.py
import torch
import torch.nn as nn
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg):
        # src = [batch size, src len]
        # trg = [batch size, trg len - 1] (<eos> token sliced off the end)

        # Calculate z^u (encoder_conved) and z^u + e (encoder_combined).
        # encoder_conved is the output from the final encoder conv block;
        # encoder_combined is encoder_conved plus (element-wise) the source
        # token embeddings plus the positional embeddings.
        encoder_conved, encoder_combined = self.encoder(src)
        # encoder_conved = [batch size, src len, emb dim]
        # encoder_combined = [batch size, src len, emb dim]

        # Calculate predictions of the next words.
        # output is a batch of predictions for each word in the trg sentence;
        # attention is a batch of attention scores over the src sentence for
        # each word in the trg sentence.
        output, attention = self.decoder(trg, encoder_conved, encoder_combined)
        # output = [batch size, trg len - 1, output dim]
        # attention = [batch size, trg len - 1, src len]

        return output, attention
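

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch of the interface this wrapper expects, assuming
# the encoder/decoder follow the convolutional seq2seq design (Gehring et al.,
# 2017): the encoder returns the (conved, combined) pair and the decoder takes
# the target plus both encoder tensors. StubEncoder, StubDecoder, and every
# size below are illustrative assumptions, not part of the original file; the
# stubs only reproduce the documented tensor shapes.
class StubEncoder(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)

    def forward(self, src):
        # src = [batch size, src len]
        conved = self.embedding(src)               # stands in for z^u
        combined = conved + self.embedding(src)    # stands in for z^u + e
        # both = [batch size, src len, emb dim]
        return conved, combined


class StubDecoder(nn.Module):
    def __init__(self, output_dim, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.fc_out = nn.Linear(emb_dim, output_dim)

    def forward(self, trg, encoder_conved, encoder_combined):
        # trg = [batch size, trg len - 1]
        emb = self.embedding(trg)  # [batch size, trg len - 1, emb dim]
        # Plain dot-product attention over the source, as a stand-in for the
        # paper's multi-step attention: scores = [batch, trg len - 1, src len].
        energy = torch.matmul(emb, encoder_conved.permute(0, 2, 1))
        attention = torch.softmax(energy, dim=-1)
        attended = torch.matmul(attention, encoder_combined)
        output = self.fc_out(emb + attended)
        # output = [batch size, trg len - 1, output dim]
        return output, attention


if __name__ == "__main__":
    # Assumed sizes for the demo only.
    EMB_DIM, INPUT_DIM, OUTPUT_DIM = 256, 8000, 10000

    model = Seq2Seq(StubEncoder(INPUT_DIM, EMB_DIM),
                    StubDecoder(OUTPUT_DIM, EMB_DIM))
    src = torch.randint(0, INPUT_DIM, (4, 11))    # [batch size=4, src len=11]
    trg = torch.randint(0, OUTPUT_DIM, (4, 7))    # trg with <eos> already sliced off

    output, attention = model(src, trg)
    print(output.shape)     # torch.Size([4, 7, 10000])
    print(attention.shape)  # torch.Size([4, 7, 11])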