Fix Attention in seq2seq_translation_tutorial AttnDecoderRNN (#2452)
* replace old decoder diagram with new one
* remove 1 from encoder1 and decoder1
* fix attention in AttnDecoderRNN
* Fix formatting going over max character count

---------

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
QasimKhan5x and Svetlana Karslioglu authored Jun 9, 2023
1 parent 203f567 commit a5376f7
Showing 2 changed files with 20 additions and 18 deletions.
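In short, the previous AttnDecoderRNN computed its attention weights from the current embedding and hidden state alone, through a single nn.Linear projecting onto max_length positions, so the weights never depended on the content of the encoder outputs. The revised decoder instead scores every encoder output with additive (Bahdanau-style) attention, roughly score(s, h_j) = v^T tanh(W_s s + W_enc h_j), where v, W_s, and W_enc correspond to alignment_vector, fc_hidden, and fc_encoder in the diff. It then normalizes the scores with a softmax over source positions and feeds the resulting context vector, together with the embedding, into the GRU. A small standalone sketch of this computation appears after the diff below.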
Binary file modified _static/img/seq-seq-images/attention-decoder-network.png
38 changes: 20 additions & 18 deletions intermediate_source/seq2seq_translation_tutorial.py
@@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
         self.max_length = max_length
 
         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
-        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
-        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
+        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
+        torch.nn.init.xavier_uniform_(self.alignment_vector)
         self.dropout = nn.Dropout(self.dropout_p)
-        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
+        self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
         self.out = nn.Linear(self.hidden_size, self.output_size)
 
     def forward(self, input, hidden, encoder_outputs):
-        embedded = self.embedding(input).view(1, 1, -1)
+        embedded = self.embedding(input).view(1, -1)
         embedded = self.dropout(embedded)
 
-        attn_weights = F.softmax(
-            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
-        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
-                                 encoder_outputs.unsqueeze(0))
-
-        output = torch.cat((embedded[0], attn_applied[0]), 1)
-        output = self.attn_combine(output).unsqueeze(0)
+        transformed_hidden = self.fc_hidden(hidden[0])
+        expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
+        alignment_scores = torch.tanh(expanded_hidden_state +
+                                      self.fc_encoder(encoder_outputs))
+        alignment_scores = self.alignment_vector.mm(alignment_scores.T)
+        attn_weights = F.softmax(alignment_scores, dim=1)
+        context_vector = attn_weights.mm(encoder_outputs)
 
-        output = F.relu(output)
+        output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
         output, hidden = self.gru(output, hidden)
 
         output = F.log_softmax(self.out(output[0]), dim=1)
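The new attention path can be exercised on its own. The following is a minimal sketch, not taken from the commit, that uses dummy tensors and illustrative sizes to show the shapes produced by the additive score, softmax, and context-vector steps:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 256, 10   # illustrative sizes

fc_hidden = nn.Linear(hidden_size, hidden_size, bias=False)
fc_encoder = nn.Linear(hidden_size, hidden_size, bias=False)
alignment_vector = nn.Parameter(torch.empty(1, hidden_size))
nn.init.xavier_uniform_(alignment_vector)

decoder_hidden = torch.randn(1, hidden_size)             # current decoder hidden state
encoder_outputs = torch.randn(max_length, hidden_size)   # one vector per source position

# Additive attention: score(s, h_j) = v^T tanh(W_s s + W_enc h_j) for every position j
scores = torch.tanh(fc_hidden(decoder_hidden).expand(max_length, -1) +
                    fc_encoder(encoder_outputs))          # (max_length, hidden_size)
scores = alignment_vector.mm(scores.T)                    # (1, max_length)

attn_weights = F.softmax(scores, dim=1)                   # weights sum to 1 over positions
context_vector = attn_weights.mm(encoder_outputs)         # (1, hidden_size) weighted sum

print(attn_weights.shape, context_vector.shape)

Inside the decoder, the context vector is concatenated with the embedded input token and passed to the GRU, which is why the GRU's input size changes to hidden_size * 2 in the diff above.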
@@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 hidden_size = 256
-encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
-attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
+encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
+attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
 
-trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
+trainIters(encoder, attn_decoder, 75000, print_every=5000)
 
 ######################################################################
 #
 
-evaluateRandomly(encoder1, attn_decoder1)
+evaluateRandomly(encoder, attn_decoder)
 
 
 ######################################################################
@@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
 #
 
 output_words, attentions = evaluate(
-    encoder1, attn_decoder1, "je suis trop froid .")
+    encoder, attn_decoder, "je suis trop froid .")
 plt.matshow(attentions.numpy())
 
 
@@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):
 
 def evaluateAndShowAttention(input_sentence):
     output_words, attentions = evaluate(
-        encoder1, attn_decoder1, input_sentence)
+        encoder, attn_decoder, input_sentence)
     print('input =', input_sentence)
     print('output =', ' '.join(output_words))
     showAttention(input_sentence, output_words, attentions)
