From 970901679459da9ba9f8bc183fc8bb1c090ea71f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Castro=20Garc=C3=ADa?= Date: Tue, 13 Jun 2023 22:49:18 -0600 Subject: [PATCH 1/3] Bahdanau attention changed --- .../seq2seq_translation_tutorial.py | 369 +++++++++--------- 1 file changed, 188 insertions(+), 181 deletions(-) mode change 100644 => 100755 intermediate_source/seq2seq_translation_tutorial.py diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py old mode 100644 new mode 100755 index c2b0b722e5..81b3803b7a --- a/intermediate_source/seq2seq_translation_tutorial.py +++ b/intermediate_source/seq2seq_translation_tutorial.py @@ -45,7 +45,7 @@ :alt: To improve upon this model we'll use an `attention -mechanism `__, which lets the decoder +mechanism `__, which lets the decoder learn to focus over a specific range of the input sequence. **Recommended Reading:** @@ -66,8 +66,8 @@ Statistical Machine Translation `__ - `Sequence to Sequence Learning with Neural Networks `__ -- `Effective Approaches to Attention-based Neural Machine - Translation `__ +- `Neural Machine Translation by Jointly Learning to Align and + Translate `__ - `A Neural Conversational Model `__ You will also find the previous tutorials on @@ -78,9 +78,9 @@ **Requirements** """ +from __future__ import unicode_literals, print_function, division from io import open import unicodedata -import string import re import random @@ -89,6 +89,9 @@ from torch import optim import torch.nn.functional as F +import numpy as np +from torch.utils.data import TensorDataset, DataLoader, RandomSampler + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ###################################################################### @@ -144,7 +147,6 @@ SOS_token = 0 EOS_token = 1 - class Lang: def __init__(self, name): self.name = name @@ -182,13 +184,11 @@ def unicodeToAscii(s): ) # Lowercase, trim, and remove non-letter characters - - def normalizeString(s): s = unicodeToAscii(s.lower().strip()) s = re.sub(r"([.!?])", r" \1", s) - s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) - return s + s = re.sub(r"[^a-zA-Z!?]+", r" ", s) + return s.strip() ###################################################################### @@ -240,7 +240,6 @@ def readLangs(lang1, lang2, reverse=False): "they are", "they re " ) - def filterPair(p): return len(p[0].split(' ')) < MAX_LENGTH and \ len(p[1].split(' ')) < MAX_LENGTH and \ @@ -273,7 +272,6 @@ def prepareData(lang1, lang2, reverse=False): print(output_lang.name, output_lang.n_words) return input_lang, output_lang, pairs - input_lang, output_lang, pairs = prepareData('eng', 'fra', True) print(random.choice(pairs)) @@ -329,22 +327,19 @@ def prepareData(lang1, lang2, reverse=False): # class EncoderRNN(nn.Module): - def __init__(self, input_size, hidden_size): + def __init__(self, input_size, hidden_size, dropout_p=0.1): super(EncoderRNN, self).__init__() self.hidden_size = hidden_size self.embedding = nn.Embedding(input_size, hidden_size) - self.gru = nn.GRU(hidden_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True) + self.dropout = nn.Dropout(dropout_p) - def forward(self, input, hidden): - embedded = self.embedding(input).view(1, 1, -1) - output = embedded - output, hidden = self.gru(output, hidden) + def forward(self, input): + embedded = self.dropout(self.embedding(input)) + output, hidden = self.gru(embedded) return output, hidden - def initHidden(self): - return torch.zeros(1, 1, self.hidden_size, 
device=device) - ###################################################################### # The Decoder # ----------- @@ -374,25 +369,42 @@ def initHidden(self): # class DecoderRNN(nn.Module): + # Standard non-attentional decoder def __init__(self, hidden_size, output_size): super(DecoderRNN, self).__init__() - self.hidden_size = hidden_size - self.embedding = nn.Embedding(output_size, hidden_size) - self.gru = nn.GRU(hidden_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True) self.out = nn.Linear(hidden_size, output_size) - self.softmax = nn.LogSoftmax(dim=1) - def forward(self, input, hidden): - output = self.embedding(input).view(1, 1, -1) + def forward(self, encoder_outputs, encoder_hidden, target_tensor=None): + batch_size = encoder_outputs.size(0) + decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token) + decoder_hidden = encoder_hidden + decoder_outputs = [] + + for i in range(MAX_LENGTH): + decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden) + decoder_outputs.append(decoder_output) + + if target_tensor is not None: + # Teacher forcing: Feed the target as the next input + decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing + else: + # Without teacher forcing: use its own predictions as the next input + _, topi = decoder_output.topk(1) + decoder_input = topi.squeeze(-1).detach() # detach from history as input + + decoder_outputs = torch.cat(decoder_outputs, dim=1) + decoder_outputs = F.log_softmax(decoder_outputs, dim=-1) + return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop + + def forward_step(self, input, hidden): + output = self.embedding(input) output = F.relu(output) output, hidden = self.gru(output, hidden) - output = self.softmax(self.out(output[0])) + output = self.out(output) return output, hidden - def initHidden(self): - return torch.zeros(1, 1, self.hidden_size, device=device) - ###################################################################### # I encourage you to train and observe the results of this model, but to # save space we'll be going straight for the gold and introducing the @@ -431,43 +443,71 @@ def initHidden(self): # # +class BahdanauAttention(nn.Module): + def __init__(self, hidden_size): + super(BahdanauAttention, self).__init__() + self.Wa = nn.Linear(hidden_size, hidden_size) + self.Ua = nn.Linear(hidden_size, hidden_size) + self.Va = nn.Linear(hidden_size, 1) + + def forward(self, query, keys): + scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) + scores = scores.squeeze(2).unsqueeze(1) + + weights = F.softmax(scores, dim=-1) + context = torch.bmm(weights, keys) + + return context, weights + class AttnDecoderRNN(nn.Module): - def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH): + def __init__(self, hidden_size, output_size, dropout_p=0.1): super(AttnDecoderRNN, self).__init__() - self.hidden_size = hidden_size - self.output_size = output_size - self.dropout_p = dropout_p - self.max_length = max_length - - self.embedding = nn.Embedding(self.output_size, self.hidden_size) - self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.alignment_vector = nn.Parameter(torch.Tensor(1, hidden_size)) - torch.nn.init.xavier_uniform_(self.alignment_vector) - self.dropout = nn.Dropout(self.dropout_p) - self.gru = nn.GRU(self.hidden_size * 2, self.hidden_size) - 
self.out = nn.Linear(self.hidden_size, self.output_size) - - def forward(self, input, hidden, encoder_outputs): - embedded = self.embedding(input).view(1, -1) - embedded = self.dropout(embedded) - - transformed_hidden = self.fc_hidden(hidden[0]) - expanded_hidden_state = transformed_hidden.expand(self.max_length, -1) - alignment_scores = torch.tanh(expanded_hidden_state + - self.fc_encoder(encoder_outputs)) - alignment_scores = self.alignment_vector.mm(alignment_scores.T) - attn_weights = F.softmax(alignment_scores, dim=1) - context_vector = attn_weights.mm(encoder_outputs) - - output = torch.cat((embedded, context_vector), 1).unsqueeze(0) - output, hidden = self.gru(output, hidden) + self.embedding = nn.Embedding(output_size, hidden_size) + self.attention = BahdanauAttention(hidden_size) + self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True) + self.out = nn.Linear(hidden_size, output_size) + self.dropout = nn.Dropout(dropout_p) - output = F.log_softmax(self.out(output[0]), dim=1) - return output, hidden, attn_weights + def forward(self, encoder_outputs, encoder_hidden, target_tensor=None): + batch_size = encoder_outputs.size(0) + decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token) + decoder_hidden = encoder_hidden + decoder_outputs = [] + attentions = [] + + for i in range(MAX_LENGTH): + decoder_output, decoder_hidden, attn_weights = self.forward_step( + decoder_input, decoder_hidden, encoder_outputs + ) + decoder_outputs.append(decoder_output) + attentions.append(attn_weights) + + if target_tensor is not None: + # Teacher forcing: Feed the target as the next input + decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing + else: + # Without teacher forcing: use its own predictions as the next input + _, topi = decoder_output.topk(1) + decoder_input = topi.squeeze(-1).detach() # detach from history as input + + decoder_outputs = torch.cat(decoder_outputs, dim=1) + decoder_outputs = F.log_softmax(decoder_outputs, dim=-1) + attentions = torch.cat(attentions, dim=1) + + return decoder_outputs, decoder_hidden, attentions + + + def forward_step(self, input, hidden, encoder_outputs): + embedded = self.dropout(self.embedding(input)) - def initHidden(self): - return torch.zeros(1, 1, self.hidden_size, device=device) + query = hidden.permute(1, 0, 2) + context, attn_weights = self.attention(query, encoder_outputs) + input_gru = torch.cat((embedded, context), dim=2) + + output, hidden = self.gru(input_gru, hidden) + output = self.out(output) + + return output, hidden, attn_weights ###################################################################### @@ -491,18 +531,38 @@ def initHidden(self): def indexesFromSentence(lang, sentence): return [lang.word2index[word] for word in sentence.split(' ')] - def tensorFromSentence(lang, sentence): indexes = indexesFromSentence(lang, sentence) indexes.append(EOS_token) - return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1) - + return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1) def tensorsFromPair(pair): input_tensor = tensorFromSentence(input_lang, pair[0]) target_tensor = tensorFromSentence(output_lang, pair[1]) return (input_tensor, target_tensor) +def get_dataloader(batch_size): + input_lang, output_lang, pairs = prepareData('eng', 'fra', True) + + n = len(pairs) + input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32) + target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32) + + for idx, (inp, tgt) in enumerate(pairs): + inp_ids = 
indexesFromSentence(input_lang, inp) + tgt_ids = indexesFromSentence(output_lang, tgt) + inp_ids.append(EOS_token) + tgt_ids.append(EOS_token) + input_ids[idx, :len(inp_ids)] = inp_ids + target_ids[idx, :len(tgt_ids)] = tgt_ids + + train_data = TensorDataset(torch.LongTensor(input_ids).to(device), + torch.LongTensor(target_ids).to(device)) + + train_sampler = RandomSampler(train_data) + train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) + return input_lang, output_lang, train_dataloader + ###################################################################### # Training the Model @@ -531,59 +591,31 @@ def tensorsFromPair(pair): # ``teacher_forcing_ratio`` up to use more of it. # -teacher_forcing_ratio = 0.5 - - -def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH): - encoder_hidden = encoder.initHidden() - - encoder_optimizer.zero_grad() - decoder_optimizer.zero_grad() +def train_epoch(dataloader, encoder, decoder, encoder_optimizer, + decoder_optimizer, criterion): + + total_loss = 0 + for data in dataloader: + input_tensor, target_tensor = data - input_length = input_tensor.size(0) - target_length = target_tensor.size(0) - - encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) - - loss = 0 - - for ei in range(input_length): - encoder_output, encoder_hidden = encoder( - input_tensor[ei], encoder_hidden) - encoder_outputs[ei] = encoder_output[0, 0] - - decoder_input = torch.tensor([[SOS_token]], device=device) - - decoder_hidden = encoder_hidden - - use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False - - if use_teacher_forcing: - # Teacher forcing: Feed the target as the next input - for di in range(target_length): - decoder_output, decoder_hidden, decoder_attention = decoder( - decoder_input, decoder_hidden, encoder_outputs) - loss += criterion(decoder_output, target_tensor[di]) - decoder_input = target_tensor[di] # Teacher forcing - - else: - # Without teacher forcing: use its own predictions as the next input - for di in range(target_length): - decoder_output, decoder_hidden, decoder_attention = decoder( - decoder_input, decoder_hidden, encoder_outputs) - topv, topi = decoder_output.topk(1) - decoder_input = topi.squeeze().detach() # detach from history as input - - loss += criterion(decoder_output, target_tensor[di]) - if decoder_input.item() == EOS_token: - break + encoder_optimizer.zero_grad() + decoder_optimizer.zero_grad() - loss.backward() + encoder_outputs, encoder_hidden = encoder(input_tensor) + decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor) + + loss = criterion( + decoder_outputs.view(-1, decoder_outputs.size(-1)), + target_tensor.view(-1) + ) + loss.backward() - encoder_optimizer.step() - decoder_optimizer.step() + encoder_optimizer.step() + decoder_optimizer.step() - return loss.item() / target_length + total_loss += loss.item() + + return total_loss / len(dataloader) ###################################################################### @@ -594,13 +626,11 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, deco import time import math - def asMinutes(s): m = math.floor(s / 60) s -= m * 60 return '%dm %ds' % (m, s) - def timeSince(since, percent): now = time.time() s = now - since @@ -621,42 +651,35 @@ def timeSince(since, percent): # of examples, time so far, estimated time) and average loss. 
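#
# The discussion above mentions a ``teacher_forcing_ratio``, but the new
# decoders simply apply teacher forcing whenever a ``target_tensor`` is
# passed in. If you want to experiment with a partial ratio again, one
# minimal sketch (not part of this patch; the helper name and the 0.5
# default are illustrative assumptions) is to flip a coin per batch and
# withhold the target when the coin says so:

def train_epoch_with_ratio(dataloader, encoder, decoder, encoder_optimizer,
                           decoder_optimizer, criterion,
                           teacher_forcing_ratio=0.5):
    total_loss = 0
    for input_tensor, target_tensor in dataloader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)

        # Per-batch coin flip: pass the target (teacher forcing) or pass
        # ``None`` so the decoder feeds its own predictions back in.
        use_teacher_forcing = random.random() < teacher_forcing_ratio
        decoder_outputs, _, _ = decoder(
            encoder_outputs, encoder_hidden,
            target_tensor if use_teacher_forcing else None)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1))
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)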
# -def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01): +def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001, + print_every=100, plot_every=100): start = time.time() plot_losses = [] print_loss_total = 0 # Reset every print_every plot_loss_total = 0 # Reset every plot_every - - encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate) - decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate) - training_pairs = [tensorsFromPair(random.choice(pairs)) - for i in range(n_iters)] + + encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate) + decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate) criterion = nn.NLLLoss() - for iter in range(1, n_iters + 1): - training_pair = training_pairs[iter - 1] - input_tensor = training_pair[0] - target_tensor = training_pair[1] - - loss = train(input_tensor, target_tensor, encoder, - decoder, encoder_optimizer, decoder_optimizer, criterion) + for epoch in range(1, n_epochs + 1): + loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion) print_loss_total += loss plot_loss_total += loss - - if iter % print_every == 0: + + if epoch % print_every == 0: print_loss_avg = print_loss_total / print_every print_loss_total = 0 - print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters), - iter, iter / n_iters * 100, print_loss_avg)) + print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs), + epoch, epoch / n_epochs * 100, print_loss_avg)) - if iter % plot_every == 0: + if epoch % plot_every == 0: plot_loss_avg = plot_loss_total / plot_every plot_losses.append(plot_loss_avg) plot_loss_total = 0 - + showPlot(plot_losses) - ###################################################################### # Plotting results # ---------------- @@ -670,7 +693,6 @@ def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, lear import matplotlib.ticker as ticker import numpy as np - def showPlot(points): plt.figure() fig, ax = plt.subplots() @@ -691,40 +713,23 @@ def showPlot(points): # attention outputs for display later. 
# -def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH): +def evaluate(encoder, decoder, sentence, input_lang, output_lang): with torch.no_grad(): input_tensor = tensorFromSentence(input_lang, sentence) - input_length = input_tensor.size()[0] - encoder_hidden = encoder.initHidden() - encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device) - - for ei in range(input_length): - encoder_output, encoder_hidden = encoder(input_tensor[ei], - encoder_hidden) - encoder_outputs[ei] += encoder_output[0, 0] - - decoder_input = torch.tensor([[SOS_token]], device=device) # SOS - - decoder_hidden = encoder_hidden + encoder_outputs, encoder_hidden = encoder(input_tensor) + decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden) + _, topi = decoder_outputs.topk(1) + decoded_ids = topi.squeeze() + decoded_words = [] - decoder_attentions = torch.zeros(max_length, max_length) - - for di in range(max_length): - decoder_output, decoder_hidden, decoder_attention = decoder( - decoder_input, decoder_hidden, encoder_outputs) - decoder_attentions[di] = decoder_attention.data - topv, topi = decoder_output.data.topk(1) - if topi.item() == EOS_token: + for idx in decoded_ids: + if idx.item() == EOS_token: decoded_words.append('') break - else: - decoded_words.append(output_lang.index2word[topi.item()]) - - decoder_input = topi.squeeze().detach() - - return decoded_words, decoder_attentions[:di + 1] + decoded_words.append(output_lang.index2word[idx.item()]) + return decoded_words, decoder_attn ###################################################################### @@ -737,7 +742,7 @@ def evaluateRandomly(encoder, decoder, n=10): pair = random.choice(pairs) print('>', pair[0]) print('=', pair[1]) - output_words, attentions = evaluate(encoder, decoder, pair[0]) + output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang) output_sentence = ' '.join(output_words) print('<', output_sentence) print('') @@ -762,16 +767,20 @@ def evaluateRandomly(encoder, decoder, n=10): # encoder and decoder are initialized and run ``trainIters`` again. 
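#
# A small sketch of how the "interrupt and continue training later"
# workflow from the note above could be done. It is not part of the
# tutorial; the checkpoint file names are placeholder assumptions, and the
# lines are shown commented out because ``encoder`` and ``decoder`` are
# only created in the next section.
#
# After a training run:
#   torch.save(encoder.state_dict(), 'encoder.pt')
#   torch.save(decoder.state_dict(), 'decoder.pt')
#
# Before resuming, rebuild the models, restore the weights, and call
# ``train`` again:
#   encoder.load_state_dict(torch.load('encoder.pt'))
#   decoder.load_state_dict(torch.load('decoder.pt'))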
# -hidden_size = 256 +hidden_size = 128 +batch_size = 32 + +input_lang, output_lang, train_dataloader = get_dataloader(batch_size) + encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device) -attn_decoder = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device) +decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device) -trainIters(encoder, attn_decoder, 75000, print_every=5000) +train(train_dataloader, encoder, decoder, 100, print_every=5, plot_every=5) ###################################################################### # -evaluateRandomly(encoder, attn_decoder) +evaluateRandomly(encoder, decoder) ###################################################################### @@ -789,8 +798,8 @@ def evaluateRandomly(encoder, decoder, n=10): # output_words, attentions = evaluate( - encoder, attn_decoder, "je suis trop froid .") -plt.matshow(attentions.numpy()) + encoder, decoder, 'je suis trop froid', input_lang, output_lang) +plt.matshow(attentions.cpu().numpy()[0]) ###################################################################### @@ -799,10 +808,9 @@ def evaluateRandomly(encoder, decoder, n=10): # def showAttention(input_sentence, output_words, attentions): - # Set up figure with colorbar fig = plt.figure() ax = fig.add_subplot(111) - cax = ax.matshow(attentions.numpy(), cmap='bone') + cax = ax.matshow(attentions.cpu().numpy(), cmap='bone') fig.colorbar(cax) # Set up axes @@ -818,20 +826,19 @@ def showAttention(input_sentence, output_words, attentions): def evaluateAndShowAttention(input_sentence): - output_words, attentions = evaluate( - encoder, attn_decoder, input_sentence) + output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang) print('input =', input_sentence) print('output =', ' '.join(output_words)) - showAttention(input_sentence, output_words, attentions) + showAttention(input_sentence, output_words, attentions[0, :len(output_words), :]) -evaluateAndShowAttention("elle a cinq ans de moins que moi .") +evaluateAndShowAttention('il n est pas aussi grand que son pere') -evaluateAndShowAttention("elle est trop petit .") +evaluateAndShowAttention('je suis trop fatigue pour conduire') -evaluateAndShowAttention("je ne crains pas de mourir .") +evaluateAndShowAttention('je suis desole si c est une question idiote') -evaluateAndShowAttention("c est un jeune directeur plein de talent .") +evaluateAndShowAttention('je suis reellement fiere de vous') ###################################################################### From 48d2ae74a449a012f69352d159b881ee5e7931fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Castro=20Garc=C3=ADa?= Date: Wed, 14 Jun 2023 00:18:37 -0600 Subject: [PATCH 2/3] comments added --- .../seq2seq_translation_tutorial.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py index 81b3803b7a..3d8d9e4472 100755 --- a/intermediate_source/seq2seq_translation_tutorial.py +++ b/intermediate_source/seq2seq_translation_tutorial.py @@ -442,6 +442,21 @@ def forward_step(self, input, hidden): # :alt: # # +# Bahdanau attention, also known as additive attention, is a commonly used +# attention mechanism in sequence-to-sequence models, particularly in neural +# machine translation tasks. It was introduced by Dzmitry Bahdanau et al. in their +# paper titled `Neural Machine Translation by Jointly Learning to Align and Translate `__. 
+# This attention mechanism employs a learned alignment model to compute attention +# scores between the encoder and decoder hidden states. It utilizes a feed-forward +# neural network to calculate alignment scores. +# +# However, there are alternative attention mechanisms available, such as Luong attention, +# which computes attention scores by taking the dot product between the decoder hidden +# state and the encoder hidden states. It does not involve the non-linear transformation +# used in Bahdanau attention. +# +# In this tutorial, we will be using Bahdanau attention. However, it would be a valuable +# exercise to explore modifying the attention mechanism to use Luong attention. class BahdanauAttention(nn.Module): def __init__(self, hidden_size): @@ -775,7 +790,7 @@ def evaluateRandomly(encoder, decoder, n=10): encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device) decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device) -train(train_dataloader, encoder, decoder, 100, print_every=5, plot_every=5) +train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5) ###################################################################### # @@ -793,18 +808,8 @@ def evaluateRandomly(encoder, decoder, n=10): # at each time step. # # You could simply run ``plt.matshow(attentions)`` to see attention output -# displayed as a matrix, with the columns being input steps and rows being -# output steps: -# - -output_words, attentions = evaluate( - encoder, decoder, 'je suis trop froid', input_lang, output_lang) -plt.matshow(attentions.cpu().numpy()[0]) - - -###################################################################### -# For a better viewing experience we will do the extra work of adding axes -# and labels: +# displayed as a matrix. For a better viewing experience we will do the +# extra work of adding axes and labels: # def showAttention(input_sentence, output_words, attentions): From 36b57639e04fd4579ada01988bd5398add21a267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Luis=20Castro=20Garc=C3=ADa?= <81191337+JoseLuisC99@users.noreply.github.com> Date: Wed, 14 Jun 2023 06:09:45 -0600 Subject: [PATCH 3/3] Update seq2seq_translation_tutorial.py --- intermediate_source/seq2seq_translation_tutorial.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/intermediate_source/seq2seq_translation_tutorial.py b/intermediate_source/seq2seq_translation_tutorial.py index 3d8d9e4472..e3a4be064c 100755 --- a/intermediate_source/seq2seq_translation_tutorial.py +++ b/intermediate_source/seq2seq_translation_tutorial.py @@ -369,7 +369,6 @@ def forward(self, input): # class DecoderRNN(nn.Module): - # Standard non-attentional decoder def __init__(self, hidden_size, output_size): super(DecoderRNN, self).__init__() self.embedding = nn.Embedding(output_size, hidden_size) @@ -444,7 +443,7 @@ def forward_step(self, input, hidden): # # Bahdanau attention, also known as additive attention, is a commonly used # attention mechanism in sequence-to-sequence models, particularly in neural -# machine translation tasks. It was introduced by Dzmitry Bahdanau et al. in their +# machine translation tasks. It was introduced by Bahdanau et al. in their # paper titled `Neural Machine Translation by Jointly Learning to Align and Translate `__. # This attention mechanism employs a learned alignment model to compute attention # scores between the encoder and decoder hidden states. It utilizes a feed-forward
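#
# As a starting point for the Luong-attention exercise suggested above,
# here is one possible minimal sketch of a dot-product attention module.
# It is not part of this patch; the class name is illustrative, and it only
# implements the simple "dot" score. It keeps the same (context, weights)
# interface as ``BahdanauAttention``, so it could be tried by swapping
# ``self.attention = LuongDotAttention()`` into ``AttnDecoderRNN``:

class LuongDotAttention(nn.Module):
    def __init__(self):
        super(LuongDotAttention, self).__init__()

    def forward(self, query, keys):
        # query: (batch, 1, hidden), keys: (batch, seq_len, hidden)
        # Alignment scores are plain dot products, with no learned layers.
        scores = torch.bmm(query, keys.transpose(1, 2))   # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)                # (batch, 1, hidden)
        return context, weights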