diff --git a/esim/layers.py b/esim/layers.py
index 0c9e0f1..7543417 100644
--- a/esim/layers.py
+++ b/esim/layers.py
@@ -140,7 +140,8 @@ def forward(self,
                 premise_batch,
                 premise_mask,
                 hypothesis_batch,
-                hypothesis_mask):
+                hypothesis_mask,
+                output_attentions=False):
         """
         Args:
             premise_batch: A batch of sequences of vectors representing the
@@ -155,12 +156,23 @@ def forward(self,
             hypothesis_mask: A mask for the sequences in the hypotheses
                 batch, to ignore padding data in the sequences during the
                 computation of the attention.
+            output_attentions: If True, also return the masked-softmax attention
+                matrices for the premises and hypotheses. Defaults to False.
 
         Returns:
             attended_premises: The sequences of attention vectors for the
                 premises in the input batch.
             attended_hypotheses: The sequences of attention vectors for the
                 hypotheses in the input batch.
+
+            If output_attentions is True, the following are also returned:
+
+            hyp_prem_attn: Attention weights for each hypothesis token, softmaxed
+                over all premise tokens (masked softmax using the premise mask).
+
+            prem_hyp_attn: Attention weights for each premise token, softmaxed
+                over all hypothesis tokens (masked softmax using the hypothesis mask).
+
         """
         # Dot product between premises and hypotheses in each sequence of
         # the batch.
@@ -182,4 +194,6 @@ def forward(self,
                                            hyp_prem_attn,
                                            hypothesis_mask)
 
+        if output_attentions:
+            return attended_premises, attended_hypotheses, hyp_prem_attn, prem_hyp_attn
         return attended_premises, attended_hypotheses
diff --git a/esim/model.py b/esim/model.py
index 4facde2..53affb6 100644
--- a/esim/model.py
+++ b/esim/model.py
@@ -24,6 +24,7 @@ def __init__(self,
                  padding_idx=0,
                  dropout=0.5,
                  num_classes=3,
+                 output_attentions=False,
                  device="cpu"):
         """
         Args:
@@ -40,6 +41,8 @@ def __init__(self,
                 Defaults to 0.5.
             num_classes: The number of classes in the output of the network.
                 Defaults to 3.
+            output_attentions: If True, the model also returns the cross-attention
+                weights computed for the premises and hypotheses. Defaults to False.
             device: The name of the device on which the model is being
                 executed. Defaults to 'cpu'.
         """
@@ -50,6 +53,7 @@ def __init__(self,
         self.hidden_size = hidden_size
         self.num_classes = num_classes
         self.dropout = dropout
+        self.output_attn = output_attentions
         self.device = device
 
         self._word_embedding = nn.Embedding(self.vocab_size,
@@ -128,9 +132,17 @@ def forward(self,
         encoded_hypotheses = self._encoding(embedded_hypotheses,
                                             hypotheses_lengths)
 
-        attended_premises, attended_hypotheses =\
-            self._attention(encoded_premises, premises_mask,
-                            encoded_hypotheses, hypotheses_mask)
+        if self.output_attn:
+            attended_premises, attended_hypotheses, hyp_attn, prem_attn = \
+                self._attention(encoded_premises,
+                                premises_mask,
+                                encoded_hypotheses,
+                                hypotheses_mask,
+                                output_attentions=self.output_attn)
+        else:
+            attended_premises, attended_hypotheses = \
+                self._attention(encoded_premises, premises_mask,
+                                encoded_hypotheses, hypotheses_mask)
 
         enhanced_premises = torch.cat([encoded_premises,
                                        attended_premises,
@@ -170,9 +182,11 @@ def forward(self,
         logits = self._classification(v)
         probabilities = nn.functional.softmax(logits, dim=-1)
 
+        if self.output_attn:
+            return logits, probabilities, hyp_attn, prem_attn
         return logits, probabilities
 
 
 def _init_esim_weights(module):
     """
     Initialise the weights of the ESIM model.
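
Usage sketch (not part of the patch): a minimal example of how the new `output_attentions` flag would be exercised end to end. It assumes the upstream `ESIM` constructor defaults and the `forward(premises, premises_lengths, hypotheses, hypotheses_lengths)` signature; the vocabulary size, sequence lengths and tensor values below are made up for illustration, and the four-value return matches the patched `forward` above.

```python
import torch

from esim.model import ESIM

# Hypothetical sizes, chosen only for the example.
model = ESIM(vocab_size=100,
             embedding_dim=32,
             hidden_size=64,
             num_classes=3,
             output_attentions=True,
             device="cpu")
model.eval()

# Token ids (0 is the padding index, so real tokens start at 1) and lengths.
premises = torch.randint(1, 100, (2, 7))       # (batch, premise_len)
premises_lengths = torch.tensor([7, 5])
hypotheses = torch.randint(1, 100, (2, 5))     # (batch, hypothesis_len)
hypotheses_lengths = torch.tensor([5, 3])

with torch.no_grad():
    logits, probabilities, hyp_attn, prem_attn = model(premises,
                                                       premises_lengths,
                                                       hypotheses,
                                                       hypotheses_lengths)

# hyp_attn: (batch, hypothesis_len, premise_len), each hypothesis token's
# distribution over premise tokens (masked softmax with the premise mask).
# prem_attn: (batch, premise_len, hypothesis_len), the reverse direction.
print(probabilities.shape, hyp_attn.shape, prem_attn.shape)
```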