import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from utils import *


class DecoderSimple(nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_sizeT,
        vocab_sizeN,
        embedding_sizeT,
        embedding_sizeN,
        dropout,
        num_layers,
        device='cuda'
    ):
        super(DecoderSimple, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.device = device
        self.dropout = dropout
        # the last id of each vocabulary is reserved as the padding index
        self.embeddingN = nn.Embedding(vocab_sizeN, embedding_sizeN, vocab_sizeN - 1)
        self.embeddingT = nn.Embedding(vocab_sizeT + 3, embedding_sizeT, vocab_sizeT - 1)
        self.lstm = nn.LSTM(
            embedding_sizeN + embedding_sizeT,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False
        )
        self.w_global = nn.Linear(hidden_size * 3, vocab_sizeT + 3)  # map into T

    def embedded_dropout(self, embed, words, scale=None):
        # word-level embedding dropout: drop whole embedding rows and rescale
        # the rest; only active in training mode
        dropout = self.dropout if self.training else 0
        if dropout > 0:
            mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
            masked_embed_weight = mask * embed.weight
        else:
            masked_embed_weight = embed.weight
        if scale is not None:
            masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
        padding_idx = embed.padding_idx
        if padding_idx is None:
            padding_idx = -1
        # ids outside the embedding table are mapped to the padding index
        words[words >= embed.weight.size(0)] = padding_idx
        X = F.embedding(
            words, masked_embed_weight, padding_idx, embed.max_norm,
            embed.norm_type, embed.scale_grad_by_freq, embed.sparse
        )
        return X

    def forward(self, input, hc, enc_out, mask, h_parent):
        # input: (n_ids, t_ids), each (batch_size,)
        # enc_out and mask are unused by this decoder
        # h_parent (batch_size, hidden_size)
        n_input, t_input = input
        batch_size = n_input.size(0)
        n_input = self.embedded_dropout(self.embeddingN, n_input)
        t_input = self.embedded_dropout(self.embeddingT, t_input)
        input = torch.cat([n_input, t_input], 1)
        out, (h, c) = self.lstm(input.unsqueeze(1), hc)
        hidden = h[-1]  # use only last layer hidden in attention
        out = out.squeeze(1)
        w_t = F.log_softmax(self.w_global(torch.cat([hidden, out, h_parent], dim=1)), dim=1)
        return w_t, (h, c)
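
# --- Illustrative usage sketch (not part of the original pipeline) ----------
# A minimal single-step shape check for DecoderSimple, assuming small
# arbitrary hyperparameters on CPU; the function name and the numbers below
# are ours, chosen only for illustration.
def _demo_decoder_simple():
    torch.manual_seed(0)
    dec = DecoderSimple(
        hidden_size=8, vocab_sizeT=10, vocab_sizeN=6,
        embedding_sizeT=4, embedding_sizeN=4,
        dropout=0.0, num_layers=1, device='cpu'
    )
    dec.eval()
    batch_size = 2
    n_ids = torch.randint(0, 6, (batch_size,))
    t_ids = torch.randint(0, 10, (batch_size,))
    h_parent = torch.zeros(batch_size, 8)
    # enc_out and mask are unused by DecoderSimple, so None is fine here
    w_t, (h, c) = dec((n_ids, t_ids), None, None, None, h_parent)
    assert w_t.shape == (batch_size, 10 + 3)  # log-probs over terminals (+3 special ids)
    assert h.shape == (1, batch_size, 8)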

class DecoderAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_sizeT,
        vocab_sizeN,
        embedding_sizeT,
        embedding_sizeN,
        dropout,
        num_layers,
        attn_size=50,
        pointer=True,
        device='cuda'
    ):
        super(DecoderAttention, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.pointer = pointer
        self.device = device
        self.dropout = dropout
        # the last id of each vocabulary is reserved as the padding index
        self.embeddingN = nn.Embedding(vocab_sizeN, embedding_sizeN, vocab_sizeN - 1)
        self.embeddingT = nn.Embedding(vocab_sizeT + attn_size + 2, embedding_sizeT, vocab_sizeT - 1)
        self.W_hidden = nn.Linear(hidden_size, hidden_size)
        self.W_mem2hidden = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1)
        self.W_context = nn.Linear(
            embedding_sizeN + embedding_sizeT + hidden_size,
            hidden_size
        )
        self.lstm = nn.LSTM(
            embedding_sizeN + embedding_sizeT,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False
        )
        self.w_global = nn.Linear(hidden_size * 3, vocab_sizeT + 2)  # map into T
        if self.pointer:
            self.w_switcher = nn.Linear(hidden_size * 2, 1)
            self.logsigmoid = torch.nn.LogSigmoid()

    def embedded_dropout(self, embed, words, scale=None):
        # word-level embedding dropout: drop whole embedding rows and rescale
        # the rest; only active in training mode
        dropout = self.dropout if self.training else 0
        if dropout > 0:
            mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
            masked_embed_weight = mask * embed.weight
        else:
            masked_embed_weight = embed.weight
        if scale is not None:
            masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight
        padding_idx = embed.padding_idx
        if padding_idx is None:
            padding_idx = -1
        # ids outside the embedding table are mapped to the padding index
        words[words >= embed.weight.size(0)] = padding_idx
        X = F.embedding(
            words, masked_embed_weight, padding_idx, embed.max_norm,
            embed.norm_type, embed.scale_grad_by_freq, embed.sparse
        )
        return X

    def forward(self, input, hc, enc_out, mask, h_parent):
        # input: (n_ids, t_ids), each (batch_size,)
        # enc_out (batch_size, mem_len, hidden_size) -- memory of previous hidden states
        # mask (batch_size, mem_len)
        # h_parent (batch_size, hidden_size)
        n_input, t_input = input
        batch_size = n_input.size(0)
        n_input = self.embedded_dropout(self.embeddingN, n_input)
        t_input = self.embedded_dropout(self.embeddingT, t_input)
        input = torch.cat([n_input, t_input], 1)
        out, (h, c) = self.lstm(input.unsqueeze(1), hc)
        hidden = h[-1]  # use only last layer hidden in attention
        out = out.squeeze(1)

        if enc_out.shape[1] > 0:
            # additive attention over the memory window
            scores = self.W_hidden(hidden).unsqueeze(1)                # (batch_size, 1, hidden_size)
            scores_mem = self.W_mem2hidden(enc_out)                    # (batch_size, mem_len, hidden_size)
            scores = scores.repeat(1, scores_mem.shape[1], 1) + scores_mem
            scores = torch.tanh(scores)
            scores = self.v(scores).squeeze(2)                         # (batch_size, mem_len)
            scores = scores.masked_fill(mask, -1e20)
            attn_weights = F.softmax(scores, dim=1).unsqueeze(1)       # (batch_size, 1, mem_len)
            context = torch.matmul(attn_weights, enc_out).squeeze(1)   # (batch_size, hidden_size)
        else:
            # empty memory (first step): no context vector and nothing to copy from
            scores = enc_out.new_zeros(batch_size, 0)
            context = out.new_zeros(batch_size, self.hidden_size)

        w_t = F.log_softmax(self.w_global(torch.cat([context, out, h_parent], dim=1)), dim=1)
        if self.pointer and scores.shape[1] > 0:
            copy_logp = F.log_softmax(scores, dim=1)
            w_s = self.w_switcher(torch.cat([context, out], dim=1))
            # [vocab log-probs scaled by sigmoid(w_s) | copy log-probs scaled by sigmoid(-w_s)]
            return torch.cat([self.logsigmoid(w_s) + w_t, self.logsigmoid(-w_s) + copy_logp], dim=1), (h, c)
        # pointer disabled, or nothing in the memory to point at: vocabulary distribution only
        return w_t, (h, c)
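
# --- Illustrative note on the pointer output layout (sketch) ----------------
# With pointer=True and a non-empty memory, DecoderAttention returns one
# log-probability vector per example: the first (vocab_sizeT + 2) entries are
# the global vocabulary distribution scaled by sigmoid(w_s), the remaining
# entries (one per memory slot) are copy positions scaled by sigmoid(-w_s).
# The helper below shows how a predicted id could be interpreted; its name is
# ours, not part of the original repo.
def _interpret_prediction(pred_id, vocab_sizeT, mem_len):
    """Return ('vocab', token_id) or ('copy', offset-in-attention-window)."""
    vocab_width = vocab_sizeT + 2
    if pred_id < vocab_width:
        return 'vocab', pred_id
    return 'copy', pred_id - vocab_width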

class MixtureAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_sizeT,
        vocab_sizeN,
        embedding_sizeT,
        embedding_sizeN,
        num_layers,
        dropout,
        device='cuda',
        label_smoothing=0.1,
        attn=True,
        pointer=True,
        attn_size=50,
        SOS_token=0
    ):
        super(MixtureAttention, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.eof_N_id = vocab_sizeN - 1
        self.eof_T_id = vocab_sizeT - 1
        self.unk_id = vocab_sizeT - 2
        self.SOS_token = SOS_token
        self.attn_size = attn_size
        self.vocab_sizeT = vocab_sizeT
        self.vocab_sizeN = vocab_sizeN
        self.W_out = nn.Linear(hidden_size * 2, hidden_size)
        if attn:
            self.decoder = DecoderAttention(
                hidden_size=hidden_size,
                vocab_sizeT=vocab_sizeT,
                vocab_sizeN=vocab_sizeN,
                embedding_sizeT=embedding_sizeT,
                embedding_sizeN=embedding_sizeN,
                num_layers=num_layers,
                attn_size=attn_size,
                dropout=dropout,
                pointer=pointer,
                device=device
            ).to(device)
        else:
            self.decoder = DecoderSimple(
                hidden_size=hidden_size,
                vocab_sizeT=vocab_sizeT,
                vocab_sizeN=vocab_sizeN,
                embedding_sizeT=embedding_sizeT,
                embedding_sizeN=embedding_sizeN,
                num_layers=num_layers,
                dropout=dropout,
                device=device
            ).to(device)
        if label_smoothing > 0:
            self.criterion = LabelSmoothingLoss(
                label_smoothing,
                tgt_vocab_size=vocab_sizeT + attn_size + 3,
                ignore_index=self.eof_T_id,
                device=self.device
            )  # ignore EOF ?!
        else:
            self.criterion = nn.NLLLoss(reduction='none', ignore_index=self.eof_T_id)
        # self.pointer = pointer

    def forward(self, n_tensor, t_tensor, p_tensor):
        batch_size = n_tensor.size(0)
        max_length = n_tensor.size(1)

        full_mask = (n_tensor == self.eof_N_id)
        input = (
            torch.ones(batch_size, dtype=torch.long, device=self.device) * self.SOS_token,
            torch.ones(batch_size, dtype=torch.long, device=self.device) * self.SOS_token
        )
        hs = torch.zeros(batch_size, max_length, self.hidden_size, requires_grad=False).to(self.device)
        hc = None
        parent = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        loss = torch.tensor(0.0, device=self.device)
        token_losses = torch.zeros(batch_size, max_length).to(self.device)
        ans = []
        for iter in range(max_length):
            # attention memory: at most attn_size previous hidden states
            memory = hs[:, max(iter - self.attn_size, 0) : iter]
            output, hc = self.decoder(
                input,
                hc,
                memory.clone().detach(),
                full_mask[:, max(iter - self.attn_size, 0) : iter],
                hs[torch.arange(batch_size), parent].squeeze(1).clone().detach()
            )
            hs[:, iter] = hc[0][-1]  # store last layer hidden state only
            topv, topi = output.topk(1)
            input = (n_tensor[:, iter].clone(), t_tensor[:, iter].clone())
            parent = p_tensor[:, iter]
            # print(output.shape[1])
            ans.append(topi.detach())
            # cond = (t_tensor[:, iter] < self.vocab_sizeT + self.attn_size).long()
            # masked_target = cond * t_tensor[:, iter] + (1 - cond) * self.eof_T_id
            # ids outside the current output width (e.g. copy locations that are
            # not attendable yet) are mapped to UNK; work on a copy so the
            # caller's t_tensor is not modified in place
            target = t_tensor[:, iter].clone()
            target[target >= output.shape[1]] = self.unk_id
            token_losses[:, iter] = self.criterion(output, target)
        loss = token_losses.sum()  # / batch_size
        return loss, torch.cat(ans, dim=1)
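
# --- Smoke test (sketch) -----------------------------------------------------
# A minimal CPU run with arbitrary toy hyperparameters and random data,
# assuming label_smoothing=0 so that only nn.NLLLoss is needed and the
# LabelSmoothingLoss from utils is not exercised. Not part of the original
# training code.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, seq_len = 2, 6
    vocab_sizeT, vocab_sizeN = 20, 10
    model = MixtureAttention(
        hidden_size=32, vocab_sizeT=vocab_sizeT, vocab_sizeN=vocab_sizeN,
        embedding_sizeT=16, embedding_sizeN=16, num_layers=1, dropout=0.0,
        device='cpu', label_smoothing=0, attn=True, pointer=True,
        attn_size=4, SOS_token=0
    )
    n = torch.randint(0, vocab_sizeN - 1, (batch_size, seq_len))  # avoid the EOF id
    t = torch.randint(0, vocab_sizeT - 1, (batch_size, seq_len))
    p = torch.zeros(batch_size, seq_len, dtype=torch.long)        # parent pointers (arbitrary)
    loss, preds = model(n, t, p)
    print('loss:', loss.item(), 'predictions:', tuple(preds.shape))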