Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Attention in seq2seq_translation_tutorial AttnDecoderRNN #2452

Merged
merged 7 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified _static/img/seq-seq-images/attention-decoder-network.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
38 changes: 20 additions & 18 deletions intermediate_source/seq2seq_translation_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,25 +440,27 @@ def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGT
self.max_length = max_length

self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size))
torch.nn.init.xavier_uniform_(self.alignment_vector)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size)
self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size)
self.out = nn.Linear(self.hidden_size, self.output_size)

def forward(self, input, hidden, encoder_outputs):
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.embedding(input).view(1, -1)
embedded = self.dropout(embedded)

attn_weights = F.softmax(
self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))

output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
transformed_hidden = self.fc_hidden(hidden[0])
expanded_hidden_state = transformed_hidden.expand(self.max_length, -1)
alignment_scores = torch.tanh(expanded_hidden_state +
self.fc_encoder(encoder_outputs))
alignment_scores = self.alignment_vector.mm(alignment_scores.T)
attn_weights = F.softmax(alignment_scores, dim=1)
context_vector = attn_weights.mm(encoder_outputs)

output = F.relu(output)
output = torch.cat((embedded, context_vector), 1).unsqueeze(0)
output, hidden = self.gru(output, hidden)

output = F.log_softmax(self.out(output[0]), dim=1)
Expand Down Expand Up @@ -761,15 +763,15 @@ def evaluateRandomly(encoder, decoder, n=10):
#

hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)

trainIters(encoder1, attn_decoder1, 75000, print_every=5000)
trainIters(encoder, attn_decoder, 75000, print_every=5000)

######################################################################
#

evaluateRandomly(encoder1, attn_decoder1)
evaluateRandomly(encoder, attn_decoder)


######################################################################
Expand All @@ -787,7 +789,7 @@ def evaluateRandomly(encoder, decoder, n=10):
#

output_words, attentions = evaluate(
encoder1, attn_decoder1, "je suis trop froid .")
encoder, attn_decoder, "je suis trop froid .")
plt.matshow(attentions.numpy())


Expand Down Expand Up @@ -817,7 +819,7 @@ def showAttention(input_sentence, output_words, attentions):

def evaluateAndShowAttention(input_sentence):
output_words, attentions = evaluate(
encoder1, attn_decoder1, input_sentence)
encoder, attn_decoder, input_sentence)
print('input =', input_sentence)
print('output =', ' '.join(output_words))
showAttention(input_sentence, output_words, attentions)
Expand Down