From 6612457434be4870dc8b952fdab0f1de427362dc Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Tue, 17 May 2022 20:43:21 +0200 Subject: [PATCH 01/10] Weighted CRF: scaled emission scores --- allennlp/modules/conditional_random_field.py | 22 +++++ .../modules/conditional_random_field_test.py | 81 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/allennlp/modules/conditional_random_field.py b/allennlp/modules/conditional_random_field.py index 78cde09f36a..1c8ef940590 100644 --- a/allennlp/modules/conditional_random_field.py +++ b/allennlp/modules/conditional_random_field.py @@ -174,6 +174,14 @@ class ConditionalRandomField(torch.nn.Module): start and end transitions are handled correctly for your tag type. include_start_end_transitions : `bool`, optional (default = `True`) Whether to include the start and end transition parameters. + label_weights : `List[float]`, optional (default=`None`) + An optional list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here was based on + the paper *Weighted conditional random fields for supervised interpatient heartbeat + classification* proposed by De Lannoy et. al (2019). + See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf """ def __init__( @@ -181,6 +189,7 @@ def __init__( num_tags: int, constraints: List[Tuple[int, int]] = None, include_start_end_transitions: bool = True, + label_weights: List[float] = None, ) -> None: super().__init__() self.num_tags = num_tags @@ -206,6 +215,11 @@ def __init__( self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) + # If label_weights is not given, use 1.0 for all weights. + if label_weights is None: + label_weights = [1.0] * num_tags + self.label_weights = torch.Tensor(label_weights) + self.reset_parameters() def reset_parameters(self): @@ -280,6 +294,8 @@ def _joint_likelihood( else: score = 0.0 + label_weights = self.label_weights + # Add up the scores for the observed transitions and all the inputs but the last for i in range(sequence_length - 1): # Each is shape (batch_size,) @@ -291,6 +307,9 @@ def _joint_likelihood( # The score for using current_tag emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) + # Weight emit scores by label. + emit_score *= label_weights[current_tag.view(-1)] + # Include transition score if next element is unmasked, # input_score if this element is unmasked. score = score + transition_score * mask[i + 1] + emit_score * mask[i] @@ -311,6 +330,9 @@ def _joint_likelihood( last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) last_input_score = last_input_score.squeeze() # (batch_size,) + # Weight last emit scores by label weights. + last_input_score = last_input_score * label_weights[last_tags.view(-1)] + score = score + last_transition_score + last_input_score * mask[-1] return score diff --git a/tests/modules/conditional_random_field_test.py b/tests/modules/conditional_random_field_test.py index 3b4c2655b5c..5a71cd12c45 100644 --- a/tests/modules/conditional_random_field_test.py +++ b/tests/modules/conditional_random_field_test.py @@ -382,3 +382,84 @@ def test_allowed_transitions(self): (8, 4), (8, 7), # Extra row for start tag } + + +class TestWeightedConditionalRandomField(TestConditionalRandomField): + def setup_method(self): + super().setup_method() + + self.label_weights = torch.FloatTensor([1.0, 1.0, 0.5, 0.5, 0.5]) + + # Use the CRF Module with labels weights. + self.crf.label_weights = self.label_weights + + def score_with_weights(self, logits, tags): + """ + Computes the likelihood score for the given sequence of tags, + given the provided logits, the transition weights in the CRF model + and the label weights. + """ + # Start with transitions from START and to END + total = self.transitions_from_start[tags[0]] + self.transitions_to_end[tags[-1]] + # Add in all the intermediate transitions + for tag, next_tag in zip(tags, tags[1:]): + total += self.transitions[tag, next_tag] + # Add in the logits for the observed tags + for logit, tag in zip(logits, tags): + total += logit[tag] * self.label_weights[tag] + return total + + + def test_forward_works_without_mask(self): + log_likelihood = self.crf(self.logits, self.tags).item() + + # Now compute the log-likelihood manually + manual_log_likelihood = 0.0 + + # For each instance, manually compute the numerator + # (which is just the score for the logits and actual tags) + # and the denominator + # (which is the log-sum-exp of the scores for the logits across all possible tags) + for logits_i, tags_i in zip(self.logits, self.tags): + numerator = self.score_with_weights(logits_i.detach(), tags_i.detach()) + all_scores = [ + self.score(logits_i.detach(), tags_j) + for tags_j in itertools.product(range(5), repeat=3) + ] + denominator = math.log(sum(math.exp(score) for score in all_scores)) + # And include them in the manual calculation. + manual_log_likelihood += numerator - denominator + + # The manually computed log likelihood should equal the result of crf.forward. + assert manual_log_likelihood.item() == approx(log_likelihood) + + def test_forward_works_with_mask(self): + # Use a non-trivial mask + mask = torch.tensor([[True, True, True], [True, True, False]]) + + log_likelihood = self.crf(self.logits, self.tags, mask).item() + + # Now compute the log-likelihood manually + manual_log_likelihood = 0.0 + + # For each instance, manually compute the numerator + # (which is just the score for the logits and actual tags) + # and the denominator + # (which is the log-sum-exp of the scores for the logits across all possible tags) + for logits_i, tags_i, mask_i in zip(self.logits, self.tags, mask): + # Find the sequence length for this input and only look at that much of each sequence. + sequence_length = torch.sum(mask_i.detach()) + logits_i = logits_i.data[:sequence_length] + tags_i = tags_i.data[:sequence_length] + + numerator = self.score_with_weights(logits_i, tags_i) + all_scores = [ + self.score(logits_i, tags_j) + for tags_j in itertools.product(range(5), repeat=sequence_length) + ] + denominator = math.log(sum(math.exp(score) for score in all_scores)) + # And include them in the manual calculation. + manual_log_likelihood += numerator - denominator + + # The manually computed log likelihood should equal the result of crf.forward. + assert manual_log_likelihood.item() == approx(log_likelihood) From 6ce3b1578933d0cc1d36408c62bdfe425d11ab4f Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Tue, 17 May 2022 20:43:21 +0200 Subject: [PATCH 02/10] Fixed bug in ConditionalRandomField self.label_weights is now created as a parameter so that it will be moved to GPU whenvever the model moves. --- allennlp/modules/conditional_random_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/conditional_random_field.py b/allennlp/modules/conditional_random_field.py index 1c8ef940590..beb72e26997 100644 --- a/allennlp/modules/conditional_random_field.py +++ b/allennlp/modules/conditional_random_field.py @@ -218,7 +218,7 @@ def __init__( # If label_weights is not given, use 1.0 for all weights. if label_weights is None: label_weights = [1.0] * num_tags - self.label_weights = torch.Tensor(label_weights) + self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) self.reset_parameters() From eed2eef2be00017c7804f86a411041b11a792902 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Tue, 17 May 2022 20:43:21 +0200 Subject: [PATCH 03/10] CRF weighting strategies --- allennlp/modules/__init__.py | 2 + allennlp/modules/conditional_random_field.py | 33 +- .../conditional_random_field_lannoy.py | 489 ++++++++++++++++++ .../conditional_random_field_wtrans.py | 466 +++++++++++++++++ .../modules/conditional_random_field_test.py | 6 +- 5 files changed, 980 insertions(+), 16 deletions(-) create mode 100644 allennlp/modules/conditional_random_field_lannoy.py create mode 100644 allennlp/modules/conditional_random_field_wtrans.py diff --git a/allennlp/modules/__init__.py b/allennlp/modules/__init__.py index 0e47f36d0f6..741f840cc54 100644 --- a/allennlp/modules/__init__.py +++ b/allennlp/modules/__init__.py @@ -8,6 +8,8 @@ from allennlp.modules.backbones import Backbone from allennlp.modules.bimpm_matching import BiMpmMatching from allennlp.modules.conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans +from allennlp.modules.conditional_random_field_lannoy import ConditionalRandomFieldLannoy from allennlp.modules.elmo import Elmo from allennlp.modules.feedforward import FeedForward from allennlp.modules.gated_sum import GatedSum diff --git a/allennlp/modules/conditional_random_field.py b/allennlp/modules/conditional_random_field.py index beb72e26997..277eb2162a9 100644 --- a/allennlp/modules/conditional_random_field.py +++ b/allennlp/modules/conditional_random_field.py @@ -178,10 +178,9 @@ class ConditionalRandomField(torch.nn.Module): An optional list of weights to be used in the loss function in order to give different weights for each token depending on its label. `len(label_weights)` must be equal to `num_tags`. This is useful to - deal with highly unbalanced datasets. The method implemented here was based on - the paper *Weighted conditional random fields for supervised interpatient heartbeat - classification* proposed by De Lannoy et. al (2019). - See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf + deal with highly unbalanced datasets. The method implemented here is + based on the simple idea of weighting emission and transition scores + using the weight given for the corresponding tag. """ def __init__( @@ -239,41 +238,49 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> tor mask = mask.transpose(0, 1).contiguous() logits = logits.transpose(0, 1).contiguous() + # insert the batch dimesion to be broadcasted + label_weights = self.label_weights.view(1, num_tags) + + # emit_scores.shape = (batch_size, num_tags) + emit_scores = logits[0] * label_weights + # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the # transitions to the initial states and the logits for the first timestep. if self.include_start_end_transitions: - alpha = self.start_transitions.view(1, num_tags) + logits[0] + log_alpha = self.start_transitions.view(1, num_tags) + emit_scores else: - alpha = logits[0] + log_alpha = emit_scores # For each i we compute logits for the transitions from timestep i-1 to timestep i. # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are # (instance, current_tag, next_tag) for i in range(1, sequence_length): + # multiply the logits by the label weights + # logits[i].shape: (batch_size, num_tags) + emit_scores = logits[i] * label_weights + # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. - emit_scores = logits[i].view(batch_size, 1, num_tags) + emit_scores = emit_scores.view(batch_size, 1, num_tags) # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. transition_scores = self.transitions.view(1, num_tags, num_tags) # Alpha is for the current_tag, so we broadcast along the next_tag axis. - broadcast_alpha = alpha.view(batch_size, num_tags, 1) + broadcast_alpha = log_alpha.view(batch_size, num_tags, 1) # Add all the scores together and logexp over the current_tag axis. inner = broadcast_alpha + emit_scores + transition_scores # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. - alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * ( + log_alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + log_alpha * ( ~mask[i] ).view(batch_size, 1) # Every sequence needs to end with a transition to the stop_tag. if self.include_start_end_transitions: - stops = alpha + self.end_transitions.view(1, num_tags) - else: - stops = alpha + log_alpha = log_alpha + self.end_transitions.view(1, num_tags) # Finally we log_sum_exp along the num_tags dim, result is (batch_size,) - return util.logsumexp(stops) + return util.logsumexp(log_alpha, 1) def _joint_likelihood( self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor diff --git a/allennlp/modules/conditional_random_field_lannoy.py b/allennlp/modules/conditional_random_field_lannoy.py new file mode 100644 index 00000000000..5c1c4942cdc --- /dev/null +++ b/allennlp/modules/conditional_random_field_lannoy.py @@ -0,0 +1,489 @@ +""" +Conditional random field +""" +from typing import List, Tuple, Dict, Union + +import torch + +from allennlp.common.checks import ConfigurationError +import allennlp.nn.util as util + +VITERBI_DECODING = Tuple[List[int], float] # a list of tags, and a viterbi score + + +def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: + """ + Given labels and a constraint type, returns the allowed transitions. It will + additionally include transitions for the start and end states, which are used + by the conditional random field. + + # Parameters + + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + labels : `Dict[int, str]`, required + A mapping {label_id -> label}. Most commonly this would be the value from + Vocabulary.get_index_to_token_vocabulary() + + # Returns + + `List[Tuple[int, int]]` + The allowed transitions (from_label_id, to_label_id). + """ + num_labels = len(labels) + start_tag = num_labels + end_tag = num_labels + 1 + labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] + + allowed = [] + for from_label_index, from_label in labels_with_boundaries: + if from_label in ("START", "END"): + from_tag = from_label + from_entity = "" + else: + from_tag = from_label[0] + from_entity = from_label[1:] + for to_label_index, to_label in labels_with_boundaries: + if to_label in ("START", "END"): + to_tag = to_label + to_entity = "" + else: + to_tag = to_label[0] + to_entity = to_label[1:] + if is_transition_allowed(constraint_type, from_tag, from_entity, to_tag, to_entity): + allowed.append((from_label_index, to_label_index)) + return allowed + + +def is_transition_allowed( + constraint_type: str, from_tag: str, from_entity: str, to_tag: str, to_entity: str +): + """ + Given a constraint type and strings `from_tag` and `to_tag` that + represent the origin and destination of the transition, return whether + the transition is allowed under the given constraint type. + + # Parameters + + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + from_tag : `str`, required + The tag that the transition originates from. For example, if the + label is `I-PER`, the `from_tag` is `I`. + from_entity : `str`, required + The entity corresponding to the `from_tag`. For example, if the + label is `I-PER`, the `from_entity` is `PER`. + to_tag : `str`, required + The tag that the transition leads to. For example, if the + label is `I-PER`, the `to_tag` is `I`. + to_entity : `str`, required + The entity corresponding to the `to_tag`. For example, if the + label is `I-PER`, the `to_entity` is `PER`. + + # Returns + + `bool` + Whether the transition is allowed under the given `constraint_type`. + """ + + if to_tag == "START" or from_tag == "END": + # Cannot transition into START or from END + return False + + if constraint_type == "BIOUL": + if from_tag == "START": + return to_tag in ("O", "B", "U") + if to_tag == "END": + return from_tag in ("O", "L", "U") + return any( + [ + # O can transition to O, B-* or U-* + # L-x can transition to O, B-*, or U-* + # U-x can transition to O, B-*, or U-* + from_tag in ("O", "L", "U") and to_tag in ("O", "B", "U"), + # B-x can only transition to I-x or L-x + # I-x can only transition to I-x or L-x + from_tag in ("B", "I") and to_tag in ("I", "L") and from_entity == to_entity, + ] + ) + elif constraint_type == "BIO": + if from_tag == "START": + return to_tag in ("O", "B") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or B-x + to_tag in ("O", "B"), + # Can only transition to I-x from B-x or I-x + to_tag == "I" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "IOB1": + if from_tag == "START": + return to_tag in ("O", "I") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or I-x + to_tag in ("O", "I"), + # Can only transition to B-x from B-x or I-x, where + # x is the same tag. + to_tag == "B" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "BMES": + if from_tag == "START": + return to_tag in ("B", "S") + if to_tag == "END": + return from_tag in ("E", "S") + return any( + [ + # Can only transition to B or S from E or S. + to_tag in ("B", "S") and from_tag in ("E", "S"), + # Can only transition to M-x from B-x, where + # x is the same tag. + to_tag == "M" and from_tag in ("B", "M") and from_entity == to_entity, + # Can only transition to E-x from B-x or M-x, where + # x is the same tag. + to_tag == "E" and from_tag in ("B", "M") and from_entity == to_entity, + ] + ) + else: + raise ConfigurationError(f"Unknown constraint type: {constraint_type}") + + +class ConditionalRandomFieldLannoy(torch.nn.Module): + """ + This module uses the "forward-backward" algorithm to compute + the log-likelihood of its inputs assuming a conditional random field model. + + See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + + # Parameters + + num_tags : `int`, required + The number of tags. + constraints : `List[Tuple[int, int]]`, optional (default = `None`) + An optional list of allowed transitions (from_tag_id, to_tag_id). + These are applied to `viterbi_tags()` but do not affect `forward()`. + These should be derived from `allowed_transitions` so that the + start and end transitions are handled correctly for your tag type. + include_start_end_transitions : `bool`, optional (default = `True`) + Whether to include the start and end transition parameters. + label_weights : `List[float]`, optional (default=`None`) + An optional list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here was based on + the paper *Weighted conditional random fields for supervised interpatient heartbeat + classification* proposed by De Lannoy et. al (2019). + See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf + """ + + def __init__( + self, + num_tags: int, + constraints: List[Tuple[int, int]] = None, + include_start_end_transitions: bool = True, + label_weights: List[float] = None, + ) -> None: + super().__init__() + self.num_tags = num_tags + + # transitions[i, j] is the logit for transitioning from state i to state j. + self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags)) + + # _constraint_mask indicates valid transitions (based on supplied constraints). + # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2) + if constraints is None: + # All transitions are valid. + constraint_mask = torch.full((num_tags + 2, num_tags + 2), 1.0) + else: + constraint_mask = torch.full((num_tags + 2, num_tags + 2), 0.0) + for i, j in constraints: + constraint_mask[i, j] = 1.0 + + self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False) + + # Also need logits for transitioning from "start" state and to "end" state. + self.include_start_end_transitions = include_start_end_transitions + if include_start_end_transitions: + self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) + self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) + + # If label_weights is not given, use 1.0 for all weights. + if label_weights is None: + label_weights = [1.0] * num_tags + self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_normal_(self.transitions) + if self.include_start_end_transitions: + torch.nn.init.normal_(self.start_transitions) + torch.nn.init.normal_(self.end_transitions) + + def _input_likelihood( + self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Computes the (batch_size,) denominator term for the log-likelihood, which is the + sum of the likelihoods across all possible state sequences. + + Compute this value using the scaling trick instead of the log domain trick, since + this is necessary to implement the label-weighting method by Lannoy et al. (2012). + """ + batch_size, sequence_length, num_tags = logits.size() + + # Transpose batch size and sequence dimensions + mask = mask.transpose(0, 1).contiguous() + logits = logits.transpose(0, 1).contiguous() + tags = tags.transpose(0, 1).contiguous() + + # insert an 1-sized second dimension to match z.shape + label_weights = self.label_weights.view(num_tags, 1) + + # emit_scores.shape = (batch_size, num_tags) + emit_scores = logits[0] + + # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the + # transitions to the initial states and the logits for the first timestep. + # alpha.shape = (batch_size, num_tags) + if self.include_start_end_transitions: + alpha = torch.exp(self.start_transitions.view(1, num_tags) + emit_scores) + else: + alpha = torch.exp(emit_scores) + + # z.shape = (batch_size, 1) + z = alpha.sum(dim=1, keepdim=True) + alpha = alpha / z + sum_log_z = torch.log(z) * label_weights[tags[0]] + + # For each i we compute logits for the transitions from timestep i-1 to timestep i. + # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are + # (instance, current_tag, next_tag) + for i in range(1, sequence_length): + # multiply the logits by the label weights + # logits[i].shape: (batch_size, num_tags) + # emit_scores = torch.mul(logits[i], label_weights) + emit_scores = logits[i] + + # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. + emit_scores = emit_scores.view(batch_size, 1, num_tags) + # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. + transition_scores = self.transitions.view(1, num_tags, num_tags) + # Alpha is for the current_tag (i-1), so we broadcast along the next_tag axis. + broadcast_alpha = alpha.view(batch_size, num_tags, 1) + + # Add all the scores together and logexp over the current_tag axis. + inner = broadcast_alpha * torch.exp(emit_scores + transition_scores) + + # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension + # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. + alpha = inner.sum(dim=1) * mask[i].view(batch_size, 1) + alpha * (~mask[i]).view( + batch_size, 1 + ) + + # scale alphas to avoid underflow (sum of alphas equal to 1) + z = alpha.sum(dim=1, keepdim=True) + alpha = alpha / z + # weight z (normalization factor) according to the current tag + sum_log_z += torch.log(z) * label_weights[tags[i]] + + # Every sequence needs to end with a transition to the stop_tag. + if self.include_start_end_transitions: + alpha = alpha * torch.exp(self.end_transitions.view(1, num_tags)) + z = alpha.sum(dim=1, keepdim=True) + # alpha = alpha / z # this step is unnecessary since alpha is not used anymore + sum_log_z += torch.log(z) + + return sum_log_z.squeeze(1) + + + def _joint_likelihood( + self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Computes the numerator term for the log-likelihood, which is just score(inputs, tags) + """ + batch_size, sequence_length, _ = logits.data.shape + + # Transpose batch size and sequence dimensions: + logits = logits.transpose(0, 1).contiguous() + mask = mask.transpose(0, 1).contiguous() + tags = tags.transpose(0, 1).contiguous() + + # Start with the transition scores from start_tag to the first tag in each input + if self.include_start_end_transitions: + score = self.start_transitions.index_select(0, tags[0]) + else: + score = 0.0 + + label_weights = self.label_weights + + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), + # where t(i,j) is the score to transition from i to j and w(i) is the weight + # for tag i. + transitions = self.transitions * label_weights.view(-1, 1) + + # Add up the scores for the observed transitions and all the inputs but the last + for i in range(sequence_length - 1): + # Each is shape (batch_size,) + current_tag, next_tag = tags[i], tags[i + 1] + + # The scores for transitioning from current_tag to next_tag + transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] + + # The score for using current_tag + emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) + + # Weight emit scores by label. + emit_score *= label_weights[current_tag.view(-1)] + + # Include transition score if next element is unmasked, + # input_score if this element is unmasked. + score = score + transition_score * mask[i + 1] + emit_score * mask[i] + + # Transition from last state to "stop" state. To start with, we need to find the last tag + # for each instance. + last_tag_index = mask.sum(0).long() - 1 + last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) + + # Compute score of transitioning to `stop_tag` from each "last tag". + if self.include_start_end_transitions: + last_transition_score = self.end_transitions.index_select(0, last_tags) + else: + last_transition_score = 0.0 + + # Add the last input if it's not masked. + last_inputs = logits[-1] # (batch_size, num_tags) + last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) + last_input_score = last_input_score.squeeze() # (batch_size,) + + # Weight last emit scores by label weights. + last_input_score = last_input_score * label_weights[last_tags.view(-1)] + + score = score + last_transition_score + last_input_score * mask[-1] + + return score + + + def forward( + self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None + ) -> torch.Tensor: + """ + Computes the log likelihood. + """ + + if mask is None: + mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) + else: + # The code below fails in weird ways if this isn't a bool tensor, so we make sure. + mask = mask.to(torch.bool) + + # TODO check if weights are being used during test/validation + + log_denominator = self._input_likelihood(inputs, tags, mask) + log_numerator = self._joint_likelihood(inputs, tags, mask) + + return torch.sum(log_numerator - log_denominator) + + + def viterbi_tags( + self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None + ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]: + """ + Uses viterbi algorithm to find most likely tags for the given inputs. + If constraints are applied, disallows all other transitions. + + Returns a list of results, of the same size as the batch (one result per batch member) + Each result is a List of length top_k, containing the top K viterbi decodings + Each decoding is a tuple (tag_sequence, viterbi_score) + + For backwards compatibility, if top_k is None, then instead returns a flat list of + tag sequences (the top tag sequence for each batch item). + """ + if mask is None: + mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device) + + if top_k is None: + top_k = 1 + flatten_output = True + else: + flatten_output = False + + _, max_seq_length, num_tags = logits.size() + + # Get the tensors out of the variables + logits, mask = logits.data, mask.data + + # Augment transitions matrix with start and end transitions + start_tag = num_tags + end_tag = num_tags + 1 + transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device) + + # Apply transition constraints + constrained_transitions = self.transitions * self._constraint_mask[ + :num_tags, :num_tags + ] + -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags]) + transitions[:num_tags, :num_tags] = constrained_transitions.data + + if self.include_start_end_transitions: + transitions[ + start_tag, :num_tags + ] = self.start_transitions.detach() * self._constraint_mask[ + start_tag, :num_tags + ].data + -10000.0 * ( + 1 - self._constraint_mask[start_tag, :num_tags].detach() + ) + transitions[:num_tags, end_tag] = self.end_transitions.detach() * self._constraint_mask[ + :num_tags, end_tag + ].data + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) + else: + transitions[start_tag, :num_tags] = -10000.0 * ( + 1 - self._constraint_mask[start_tag, :num_tags].detach() + ) + transitions[:num_tags, end_tag] = -10000.0 * ( + 1 - self._constraint_mask[:num_tags, end_tag].detach() + ) + + best_paths = [] + # Pad the max sequence length by 2 to account for start_tag + end_tag. + tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device) + + for prediction, prediction_mask in zip(logits, mask): + mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze() + masked_prediction = torch.index_select(prediction, 0, mask_indices) + sequence_length = masked_prediction.shape[0] + + # Start with everything totally unlikely + tag_sequence.fill_(-10000.0) + # At timestep 0 we must have the START_TAG + tag_sequence[0, start_tag] = 0.0 + # At steps 1, ..., sequence_length we just use the incoming prediction + tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction + # And at the last timestep we must have the END_TAG + tag_sequence[sequence_length + 1, end_tag] = 0.0 + + # We pass the tags and the transitions to `viterbi_decode`. + viterbi_paths, viterbi_scores = util.viterbi_decode( + tag_sequence=tag_sequence[: (sequence_length + 2)], + transition_matrix=transitions, + top_k=top_k, + ) + top_k_paths = [] + for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores): + # Get rid of START and END sentinels and append. + viterbi_path = viterbi_path[1:-1] + top_k_paths.append((viterbi_path, viterbi_score.item())) + best_paths.append(top_k_paths) + + if flatten_output: + return [top_k_paths[0] for top_k_paths in best_paths] + + return best_paths diff --git a/allennlp/modules/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field_wtrans.py new file mode 100644 index 00000000000..a03ed2d6cc8 --- /dev/null +++ b/allennlp/modules/conditional_random_field_wtrans.py @@ -0,0 +1,466 @@ +""" +Conditional random field +""" +from typing import List, Tuple, Dict, Union + +import torch + +from allennlp.common.checks import ConfigurationError +import allennlp.nn.util as util + +VITERBI_DECODING = Tuple[List[int], float] # a list of tags, and a viterbi score + + +def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: + """ + Given labels and a constraint type, returns the allowed transitions. It will + additionally include transitions for the start and end states, which are used + by the conditional random field. + + # Parameters + + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + labels : `Dict[int, str]`, required + A mapping {label_id -> label}. Most commonly this would be the value from + Vocabulary.get_index_to_token_vocabulary() + + # Returns + + `List[Tuple[int, int]]` + The allowed transitions (from_label_id, to_label_id). + """ + num_labels = len(labels) + start_tag = num_labels + end_tag = num_labels + 1 + labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] + + allowed = [] + for from_label_index, from_label in labels_with_boundaries: + if from_label in ("START", "END"): + from_tag = from_label + from_entity = "" + else: + from_tag = from_label[0] + from_entity = from_label[1:] + for to_label_index, to_label in labels_with_boundaries: + if to_label in ("START", "END"): + to_tag = to_label + to_entity = "" + else: + to_tag = to_label[0] + to_entity = to_label[1:] + if is_transition_allowed(constraint_type, from_tag, from_entity, to_tag, to_entity): + allowed.append((from_label_index, to_label_index)) + return allowed + + +def is_transition_allowed( + constraint_type: str, from_tag: str, from_entity: str, to_tag: str, to_entity: str +): + """ + Given a constraint type and strings `from_tag` and `to_tag` that + represent the origin and destination of the transition, return whether + the transition is allowed under the given constraint type. + + # Parameters + + constraint_type : `str`, required + Indicates which constraint to apply. Current choices are + "BIO", "IOB1", "BIOUL", and "BMES". + from_tag : `str`, required + The tag that the transition originates from. For example, if the + label is `I-PER`, the `from_tag` is `I`. + from_entity : `str`, required + The entity corresponding to the `from_tag`. For example, if the + label is `I-PER`, the `from_entity` is `PER`. + to_tag : `str`, required + The tag that the transition leads to. For example, if the + label is `I-PER`, the `to_tag` is `I`. + to_entity : `str`, required + The entity corresponding to the `to_tag`. For example, if the + label is `I-PER`, the `to_entity` is `PER`. + + # Returns + + `bool` + Whether the transition is allowed under the given `constraint_type`. + """ + + if to_tag == "START" or from_tag == "END": + # Cannot transition into START or from END + return False + + if constraint_type == "BIOUL": + if from_tag == "START": + return to_tag in ("O", "B", "U") + if to_tag == "END": + return from_tag in ("O", "L", "U") + return any( + [ + # O can transition to O, B-* or U-* + # L-x can transition to O, B-*, or U-* + # U-x can transition to O, B-*, or U-* + from_tag in ("O", "L", "U") and to_tag in ("O", "B", "U"), + # B-x can only transition to I-x or L-x + # I-x can only transition to I-x or L-x + from_tag in ("B", "I") and to_tag in ("I", "L") and from_entity == to_entity, + ] + ) + elif constraint_type == "BIO": + if from_tag == "START": + return to_tag in ("O", "B") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or B-x + to_tag in ("O", "B"), + # Can only transition to I-x from B-x or I-x + to_tag == "I" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "IOB1": + if from_tag == "START": + return to_tag in ("O", "I") + if to_tag == "END": + return from_tag in ("O", "B", "I") + return any( + [ + # Can always transition to O or I-x + to_tag in ("O", "I"), + # Can only transition to B-x from B-x or I-x, where + # x is the same tag. + to_tag == "B" and from_tag in ("B", "I") and from_entity == to_entity, + ] + ) + elif constraint_type == "BMES": + if from_tag == "START": + return to_tag in ("B", "S") + if to_tag == "END": + return from_tag in ("E", "S") + return any( + [ + # Can only transition to B or S from E or S. + to_tag in ("B", "S") and from_tag in ("E", "S"), + # Can only transition to M-x from B-x, where + # x is the same tag. + to_tag == "M" and from_tag in ("B", "M") and from_entity == to_entity, + # Can only transition to E-x from B-x or M-x, where + # x is the same tag. + to_tag == "E" and from_tag in ("B", "M") and from_entity == to_entity, + ] + ) + else: + raise ConfigurationError(f"Unknown constraint type: {constraint_type}") + + +class ConditionalRandomFieldWeightTrans(torch.nn.Module): + """ + This module uses the "forward-backward" algorithm to compute + the log-likelihood of its inputs assuming a conditional random field model. + + See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + + # Parameters + + num_tags : `int`, required + The number of tags. + constraints : `List[Tuple[int, int]]`, optional (default = `None`) + An optional list of allowed transitions (from_tag_id, to_tag_id). + These are applied to `viterbi_tags()` but do not affect `forward()`. + These should be derived from `allowed_transitions` so that the + start and end transitions are handled correctly for your tag type. + include_start_end_transitions : `bool`, optional (default = `True`) + Whether to include the start and end transition parameters. + label_weights : `List[float]`, optional (default=`None`) + An optional list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here is + based on the simple idea of weighting emission and transition scores + using the weight given for the corresponding tag. + """ + + def __init__( + self, + num_tags: int, + constraints: List[Tuple[int, int]] = None, + include_start_end_transitions: bool = True, + label_weights: List[float] = None, + ) -> None: + super().__init__() + self.num_tags = num_tags + + # transitions[i, j] is the logit for transitioning from state i to state j. + self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags)) + + # _constraint_mask indicates valid transitions (based on supplied constraints). + # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2) + if constraints is None: + # All transitions are valid. + constraint_mask = torch.full((num_tags + 2, num_tags + 2), 1.0) + else: + constraint_mask = torch.full((num_tags + 2, num_tags + 2), 0.0) + for i, j in constraints: + constraint_mask[i, j] = 1.0 + + self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False) + + # Also need logits for transitioning from "start" state and to "end" state. + self.include_start_end_transitions = include_start_end_transitions + if include_start_end_transitions: + self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) + self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) + + # If label_weights is not given, use 1.0 for all weights. + if label_weights is None: + label_weights = [1.0] * num_tags + self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_normal_(self.transitions) + if self.include_start_end_transitions: + torch.nn.init.normal_(self.start_transitions) + torch.nn.init.normal_(self.end_transitions) + + def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + """ + Computes the (batch_size,) denominator term for the log-likelihood, which is the + sum of the likelihoods across all possible state sequences. + """ + batch_size, sequence_length, num_tags = logits.size() + + # Transpose batch size and sequence dimensions + mask = mask.transpose(0, 1).contiguous() + logits = logits.transpose(0, 1).contiguous() + + # insert the batch dimesion to be broadcasted + label_weights = self.label_weights.view(1, num_tags) + + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), + # where t(i,j) is the score to transition from i to j and w(i) is the weight + # for tag i. + transitions = self.transitions * label_weights.view(-1, 1) + + emit_scores = logits[0] * label_weights + + # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the + # transitions to the initial states and the logits for the first timestep. + if self.include_start_end_transitions: + log_alpha = self.start_transitions.view(1, num_tags) + emit_scores + else: + log_alpha = emit_scores + + # For each i we compute logits for the transitions from timestep i-1 to timestep i. + # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are + # (instance, current_tag, next_tag) + for i in range(1, sequence_length): + # multiply the logits by the label weights + # logits[i].shape: (batch_size, num_tags) + emit_scores = logits[i] * label_weights + + # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. + emit_scores = emit_scores.view(batch_size, 1, num_tags) + # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. + transition_scores = transitions.view(1, num_tags, num_tags) + # Alpha is for the current_tag, so we broadcast along the next_tag axis. + broadcast_alpha = log_alpha.view(batch_size, num_tags, 1) + + # Add all the scores together and logexp over the current_tag axis. + inner = broadcast_alpha + emit_scores + transition_scores + + # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension + # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. + log_alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + log_alpha * ( + ~mask[i] + ).view(batch_size, 1) + + # Every sequence needs to end with a transition to the stop_tag. + if self.include_start_end_transitions: + log_alpha = log_alpha + self.end_transitions.view(1, num_tags) + + # Finally we log_sum_exp along the num_tags dim, result is (batch_size,) + return util.logsumexp(log_alpha, 1) + + def _joint_likelihood( + self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Computes the numerator term for the log-likelihood, which is just score(inputs, tags) + """ + batch_size, sequence_length, _ = logits.data.shape + + # Transpose batch size and sequence dimensions: + logits = logits.transpose(0, 1).contiguous() + mask = mask.transpose(0, 1).contiguous() + tags = tags.transpose(0, 1).contiguous() + + # Start with the transition scores from start_tag to the first tag in each input + if self.include_start_end_transitions: + score = self.start_transitions.index_select(0, tags[0]) + else: + score = 0.0 + + label_weights = self.label_weights + + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w[i], + # where t(i,j) is the score to transition from i to j and w[i] is the weight + # for tag i. + transitions = self.transitions * label_weights.view(-1, 1) + + # Add up the scores for the observed transitions and all the inputs but the last + for i in range(sequence_length - 1): + # Each is shape (batch_size,) + current_tag, next_tag = tags[i], tags[i + 1] + + # The scores for transitioning from current_tag to next_tag + transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] + + # The score for using current_tag + emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) + + # Weight emit scores by label. + emit_score *= label_weights[current_tag.view(-1)] + + # Include transition score if next element is unmasked, + # input_score if this element is unmasked. + score = score + transition_score * mask[i + 1] + emit_score * mask[i] + + # Transition from last state to "stop" state. To start with, we need to find the last tag + # for each instance. + last_tag_index = mask.sum(0).long() - 1 + last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) + + # Compute score of transitioning to `stop_tag` from each "last tag". + if self.include_start_end_transitions: + last_transition_score = self.end_transitions.index_select(0, last_tags) + else: + last_transition_score = 0.0 + + # Add the last input if it's not masked. + last_inputs = logits[-1] # (batch_size, num_tags) + last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) + last_input_score = last_input_score.squeeze() # (batch_size,) + + # Weight last emit scores by label weights. + last_input_score = last_input_score * label_weights[last_tags.view(-1)] + + score = score + last_transition_score + last_input_score * mask[-1] + + return score + + def forward( + self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None + ) -> torch.Tensor: + """ + Computes the log likelihood. + """ + + if mask is None: + mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) + else: + # The code below fails in weird ways if this isn't a bool tensor, so we make sure. + mask = mask.to(torch.bool) + + log_denominator = self._input_likelihood(inputs, mask) + log_numerator = self._joint_likelihood(inputs, tags, mask) + + return torch.sum(log_numerator - log_denominator) + + def viterbi_tags( + self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None + ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]: + """ + Uses viterbi algorithm to find most likely tags for the given inputs. + If constraints are applied, disallows all other transitions. + + Returns a list of results, of the same size as the batch (one result per batch member) + Each result is a List of length top_k, containing the top K viterbi decodings + Each decoding is a tuple (tag_sequence, viterbi_score) + + For backwards compatibility, if top_k is None, then instead returns a flat list of + tag sequences (the top tag sequence for each batch item). + """ + if mask is None: + mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device) + + if top_k is None: + top_k = 1 + flatten_output = True + else: + flatten_output = False + + _, max_seq_length, num_tags = logits.size() + + # Get the tensors out of the variables + logits, mask = logits.data, mask.data + + # Augment transitions matrix with start and end transitions + start_tag = num_tags + end_tag = num_tags + 1 + transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device) + + # Apply transition constraints + constrained_transitions = self.transitions * self._constraint_mask[ + :num_tags, :num_tags + ] + -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags]) + transitions[:num_tags, :num_tags] = constrained_transitions.data + + if self.include_start_end_transitions: + transitions[ + start_tag, :num_tags + ] = self.start_transitions.detach() * self._constraint_mask[ + start_tag, :num_tags + ].data + -10000.0 * ( + 1 - self._constraint_mask[start_tag, :num_tags].detach() + ) + transitions[:num_tags, end_tag] = self.end_transitions.detach() * self._constraint_mask[ + :num_tags, end_tag + ].data + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) + else: + transitions[start_tag, :num_tags] = -10000.0 * ( + 1 - self._constraint_mask[start_tag, :num_tags].detach() + ) + transitions[:num_tags, end_tag] = -10000.0 * ( + 1 - self._constraint_mask[:num_tags, end_tag].detach() + ) + + best_paths = [] + # Pad the max sequence length by 2 to account for start_tag + end_tag. + tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device) + + for prediction, prediction_mask in zip(logits, mask): + mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze() + masked_prediction = torch.index_select(prediction, 0, mask_indices) + sequence_length = masked_prediction.shape[0] + + # Start with everything totally unlikely + tag_sequence.fill_(-10000.0) + # At timestep 0 we must have the START_TAG + tag_sequence[0, start_tag] = 0.0 + # At steps 1, ..., sequence_length we just use the incoming prediction + tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction + # And at the last timestep we must have the END_TAG + tag_sequence[sequence_length + 1, end_tag] = 0.0 + + # We pass the tags and the transitions to `viterbi_decode`. + viterbi_paths, viterbi_scores = util.viterbi_decode( + tag_sequence=tag_sequence[: (sequence_length + 2)], + transition_matrix=transitions, + top_k=top_k, + ) + top_k_paths = [] + for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores): + # Get rid of START and END sentinels and append. + viterbi_path = viterbi_path[1:-1] + top_k_paths.append((viterbi_path, viterbi_score.item())) + best_paths.append(top_k_paths) + + if flatten_output: + return [top_k_paths[0] for top_k_paths in best_paths] + + return best_paths diff --git a/tests/modules/conditional_random_field_test.py b/tests/modules/conditional_random_field_test.py index 5a71cd12c45..d41c982b271 100644 --- a/tests/modules/conditional_random_field_test.py +++ b/tests/modules/conditional_random_field_test.py @@ -391,7 +391,7 @@ def setup_method(self): self.label_weights = torch.FloatTensor([1.0, 1.0, 0.5, 0.5, 0.5]) # Use the CRF Module with labels weights. - self.crf.label_weights = self.label_weights + self.crf.label_weights = torch.nn.Parameter(self.label_weights, requires_grad=False) def score_with_weights(self, logits, tags): """ @@ -423,7 +423,7 @@ def test_forward_works_without_mask(self): for logits_i, tags_i in zip(self.logits, self.tags): numerator = self.score_with_weights(logits_i.detach(), tags_i.detach()) all_scores = [ - self.score(logits_i.detach(), tags_j) + self.score_with_weights(logits_i.detach(), tags_j) for tags_j in itertools.product(range(5), repeat=3) ] denominator = math.log(sum(math.exp(score) for score in all_scores)) @@ -454,7 +454,7 @@ def test_forward_works_with_mask(self): numerator = self.score_with_weights(logits_i, tags_i) all_scores = [ - self.score(logits_i, tags_j) + self.score_with_weights(logits_i, tags_j) for tags_j in itertools.product(range(5), repeat=sequence_length) ] denominator = math.log(sum(math.exp(score) for score in all_scores)) From 595a4889d84f3f0ac0125ae5d538e762a783437d Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Sat, 18 Jun 2022 00:04:27 +0200 Subject: [PATCH 04/10] Weighted CRF: refactoring of three methods --- allennlp/modules/__init__.py | 1 + allennlp/modules/conditional_random_field.py | 106 ++--- .../conditional_random_field_lannoy.py | 5 +- .../conditional_random_field_wemission.py | 82 ++++ .../conditional_random_field_wtrans.py | 431 ++---------------- .../modules/conditional_random_field_test.py | 43 +- 6 files changed, 204 insertions(+), 464 deletions(-) create mode 100644 allennlp/modules/conditional_random_field_wemission.py diff --git a/allennlp/modules/__init__.py b/allennlp/modules/__init__.py index 741f840cc54..f2256b472bf 100644 --- a/allennlp/modules/__init__.py +++ b/allennlp/modules/__init__.py @@ -8,6 +8,7 @@ from allennlp.modules.backbones import Backbone from allennlp.modules.bimpm_matching import BiMpmMatching from allennlp.modules.conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field_wemission import ConditionalRandomFieldWeightEmission from allennlp.modules.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans from allennlp.modules.conditional_random_field_lannoy import ConditionalRandomFieldLannoy from allennlp.modules.elmo import Elmo diff --git a/allennlp/modules/conditional_random_field.py b/allennlp/modules/conditional_random_field.py index 277eb2162a9..c8da0456b84 100644 --- a/allennlp/modules/conditional_random_field.py +++ b/allennlp/modules/conditional_random_field.py @@ -174,13 +174,6 @@ class ConditionalRandomField(torch.nn.Module): start and end transitions are handled correctly for your tag type. include_start_end_transitions : `bool`, optional (default = `True`) Whether to include the start and end transition parameters. - label_weights : `List[float]`, optional (default=`None`) - An optional list of weights to be used in the loss function in order to - give different weights for each token depending on its label. - `len(label_weights)` must be equal to `num_tags`. This is useful to - deal with highly unbalanced datasets. The method implemented here is - based on the simple idea of weighting emission and transition scores - using the weight given for the corresponding tag. """ def __init__( @@ -188,7 +181,6 @@ def __init__( num_tags: int, constraints: List[Tuple[int, int]] = None, include_start_end_transitions: bool = True, - label_weights: List[float] = None, ) -> None: super().__init__() self.num_tags = num_tags @@ -214,11 +206,6 @@ def __init__( self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) - # If label_weights is not given, use 1.0 for all weights. - if label_weights is None: - label_weights = [1.0] * num_tags - self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) - self.reset_parameters() def reset_parameters(self): @@ -227,10 +214,21 @@ def reset_parameters(self): torch.nn.init.normal_(self.start_transitions) torch.nn.init.normal_(self.end_transitions) - def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: - """ - Computes the (batch_size,) denominator term for the log-likelihood, which is the - sum of the likelihoods across all possible state sequences. + def _input_likelihood( + self, logits: torch.Tensor, transitions: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """Computes the (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood + + This is the sum of the likelihoods across all possible state sequences. + + Args: + logits (torch.Tensor): a (batch_size, sequence_length num_tags) tensor of + unnormalized log-probabilities + transitions (torch.Tensor): a (batch_size, num_tags, num_tags) tensor of transition scores + mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags + + Returns: + torch.Tensor: (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood """ batch_size, sequence_length, num_tags = logits.size() @@ -238,55 +236,60 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> tor mask = mask.transpose(0, 1).contiguous() logits = logits.transpose(0, 1).contiguous() - # insert the batch dimesion to be broadcasted - label_weights = self.label_weights.view(1, num_tags) - - # emit_scores.shape = (batch_size, num_tags) - emit_scores = logits[0] * label_weights - # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the # transitions to the initial states and the logits for the first timestep. if self.include_start_end_transitions: - log_alpha = self.start_transitions.view(1, num_tags) + emit_scores + alpha = self.start_transitions.view(1, num_tags) + logits[0] else: - log_alpha = emit_scores + alpha = logits[0] # For each i we compute logits for the transitions from timestep i-1 to timestep i. # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are # (instance, current_tag, next_tag) for i in range(1, sequence_length): - # multiply the logits by the label weights - # logits[i].shape: (batch_size, num_tags) - emit_scores = logits[i] * label_weights - # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. - emit_scores = emit_scores.view(batch_size, 1, num_tags) + emit_scores = logits[i].view(batch_size, 1, num_tags) # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. - transition_scores = self.transitions.view(1, num_tags, num_tags) + transition_scores = transitions.view(1, num_tags, num_tags) # Alpha is for the current_tag, so we broadcast along the next_tag axis. - broadcast_alpha = log_alpha.view(batch_size, num_tags, 1) + broadcast_alpha = alpha.view(batch_size, num_tags, 1) # Add all the scores together and logexp over the current_tag axis. inner = broadcast_alpha + emit_scores + transition_scores # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. - log_alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + log_alpha * ( + alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + alpha * ( ~mask[i] ).view(batch_size, 1) # Every sequence needs to end with a transition to the stop_tag. if self.include_start_end_transitions: - log_alpha = log_alpha + self.end_transitions.view(1, num_tags) + stops = alpha + self.end_transitions.view(1, num_tags) + else: + stops = alpha # Finally we log_sum_exp along the num_tags dim, result is (batch_size,) - return util.logsumexp(log_alpha, 1) + return util.logsumexp(stops) def _joint_likelihood( - self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + self, + logits: torch.Tensor, + transitions: torch.Tensor, + tags: torch.Tensor, + mask: torch.BoolTensor, ) -> torch.Tensor: - """ - Computes the numerator term for the log-likelihood, which is just score(inputs, tags) + """Computes the numerator term for the log-likelihood, which is just score(inputs, tags) + + Args: + logits (torch.Tensor): a (batch_size, sequence_length num_tags) tensor of unnormalized + log-probabilities + transitions (torch.Tensor): a (batch_size, num_tags, num_tags) tensor of transition scores + tags (torch.Tensor): output tag sequences (batch_size, sequence_length) $y$ for each input sequence + mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags + + Returns: + torch.Tensor: numerator term for the log-likelihood, which is just score(inputs, tags) """ batch_size, sequence_length, _ = logits.data.shape @@ -301,22 +304,17 @@ def _joint_likelihood( else: score = 0.0 - label_weights = self.label_weights - # Add up the scores for the observed transitions and all the inputs but the last for i in range(sequence_length - 1): # Each is shape (batch_size,) current_tag, next_tag = tags[i], tags[i + 1] # The scores for transitioning from current_tag to next_tag - transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)] + transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] # The score for using current_tag emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) - # Weight emit scores by label. - emit_score *= label_weights[current_tag.view(-1)] - # Include transition score if next element is unmasked, # input_score if this element is unmasked. score = score + transition_score * mask[i + 1] + emit_score * mask[i] @@ -337,9 +335,6 @@ def _joint_likelihood( last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) last_input_score = last_input_score.squeeze() # (batch_size,) - # Weight last emit scores by label weights. - last_input_score = last_input_score * label_weights[last_tags.view(-1)] - score = score + last_transition_score + last_input_score * mask[-1] return score @@ -347,18 +342,25 @@ def _joint_likelihood( def forward( self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None ) -> torch.Tensor: - """ - Computes the log likelihood. - """ + """Computes the log likelihood for the given batch of input sequences $(x,y)$ + Args: + inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$ + tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$ + mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags. + Defaults to None. + + Returns: + torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input + """ if mask is None: mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) else: # The code below fails in weird ways if this isn't a bool tensor, so we make sure. mask = mask.to(torch.bool) - log_denominator = self._input_likelihood(inputs, mask) - log_numerator = self._joint_likelihood(inputs, tags, mask) + log_denominator = self._input_likelihood(inputs, self.transitions, mask) + log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask) return torch.sum(log_numerator - log_denominator) diff --git a/allennlp/modules/conditional_random_field_lannoy.py b/allennlp/modules/conditional_random_field_lannoy.py index 5c1c4942cdc..56dbf2e1956 100644 --- a/allennlp/modules/conditional_random_field_lannoy.py +++ b/allennlp/modules/conditional_random_field_lannoy.py @@ -304,7 +304,6 @@ def _input_likelihood( return sum_log_z.squeeze(1) - def _joint_likelihood( self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor ) -> torch.Tensor: @@ -326,7 +325,7 @@ def _joint_likelihood( label_weights = self.label_weights - # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), # where t(i,j) is the score to transition from i to j and w(i) is the weight # for tag i. transitions = self.transitions * label_weights.view(-1, 1) @@ -372,7 +371,6 @@ def _joint_likelihood( return score - def forward( self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None ) -> torch.Tensor: @@ -393,7 +391,6 @@ def forward( return torch.sum(log_numerator - log_denominator) - def viterbi_tags( self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]: diff --git a/allennlp/modules/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field_wemission.py new file mode 100644 index 00000000000..babee7f3ebc --- /dev/null +++ b/allennlp/modules/conditional_random_field_wemission.py @@ -0,0 +1,82 @@ +""" +Conditional random field +""" +from typing import List, Tuple + +import torch + +from allennlp.common.checks import ConfigurationError + +from .conditional_random_field import ConditionalRandomField + + +class ConditionalRandomFieldWeightEmission(ConditionalRandomField): + """ + This module uses the "forward-backward" algorithm to compute + the log-likelihood of its inputs assuming a conditional random field model. + + See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + + # Parameters + + num_tags : `int`, required + The number of tags. + constraints : `List[Tuple[int, int]]`, optional (default = `None`) + An optional list of allowed transitions (from_tag_id, to_tag_id). + These are applied to `viterbi_tags()` but do not affect `forward()`. + These should be derived from `allowed_transitions` so that the + start and end transitions are handled correctly for your tag type. + include_start_end_transitions : `bool`, optional (default = `True`) + Whether to include the start and end transition parameters. + label_weights : `List[float]`, optional (default=`None`) + An optional list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here is + based on the simple idea of weighting emission and transition scores + using the weight given for the corresponding tag. + """ + + def __init__( + self, + num_tags: int, + label_weights: List[float], + constraints: List[Tuple[int, int]] = None, + include_start_end_transitions: bool = True, + ) -> None: + super().__init__(num_tags, constraints, include_start_end_transitions) + + if label_weights is None: + raise ConfigurationError("label_weights must be given") + + self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) + + def forward( + self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None + ) -> torch.Tensor: + """Computes the log likelihood for the given batch of input sequences $(x,y)$ + + Args: + inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$ + tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$ + mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags. + Defaults to None. + + Returns: + torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input + """ + if mask is None: + mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) + else: + # The code below fails in weird ways if this isn't a bool tensor, so we make sure. + mask = mask.to(torch.bool) + + label_weights = self.label_weights + + # scale the logits for all examples and all time steps + inputs = inputs * label_weights.view(1, 1, -1) + + log_denominator = self._input_likelihood(inputs, self.transitions, mask) + log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask) + + return torch.sum(log_numerator - log_denominator) diff --git a/allennlp/modules/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field_wtrans.py index a03ed2d6cc8..d6b514f8a97 100644 --- a/allennlp/modules/conditional_random_field_wtrans.py +++ b/allennlp/modules/conditional_random_field_wtrans.py @@ -1,162 +1,16 @@ """ Conditional random field """ -from typing import List, Tuple, Dict, Union +from typing import List, Tuple import torch from allennlp.common.checks import ConfigurationError -import allennlp.nn.util as util -VITERBI_DECODING = Tuple[List[int], float] # a list of tags, and a viterbi score +from .conditional_random_field import ConditionalRandomField -def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: - """ - Given labels and a constraint type, returns the allowed transitions. It will - additionally include transitions for the start and end states, which are used - by the conditional random field. - - # Parameters - - constraint_type : `str`, required - Indicates which constraint to apply. Current choices are - "BIO", "IOB1", "BIOUL", and "BMES". - labels : `Dict[int, str]`, required - A mapping {label_id -> label}. Most commonly this would be the value from - Vocabulary.get_index_to_token_vocabulary() - - # Returns - - `List[Tuple[int, int]]` - The allowed transitions (from_label_id, to_label_id). - """ - num_labels = len(labels) - start_tag = num_labels - end_tag = num_labels + 1 - labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] - - allowed = [] - for from_label_index, from_label in labels_with_boundaries: - if from_label in ("START", "END"): - from_tag = from_label - from_entity = "" - else: - from_tag = from_label[0] - from_entity = from_label[1:] - for to_label_index, to_label in labels_with_boundaries: - if to_label in ("START", "END"): - to_tag = to_label - to_entity = "" - else: - to_tag = to_label[0] - to_entity = to_label[1:] - if is_transition_allowed(constraint_type, from_tag, from_entity, to_tag, to_entity): - allowed.append((from_label_index, to_label_index)) - return allowed - - -def is_transition_allowed( - constraint_type: str, from_tag: str, from_entity: str, to_tag: str, to_entity: str -): - """ - Given a constraint type and strings `from_tag` and `to_tag` that - represent the origin and destination of the transition, return whether - the transition is allowed under the given constraint type. - - # Parameters - - constraint_type : `str`, required - Indicates which constraint to apply. Current choices are - "BIO", "IOB1", "BIOUL", and "BMES". - from_tag : `str`, required - The tag that the transition originates from. For example, if the - label is `I-PER`, the `from_tag` is `I`. - from_entity : `str`, required - The entity corresponding to the `from_tag`. For example, if the - label is `I-PER`, the `from_entity` is `PER`. - to_tag : `str`, required - The tag that the transition leads to. For example, if the - label is `I-PER`, the `to_tag` is `I`. - to_entity : `str`, required - The entity corresponding to the `to_tag`. For example, if the - label is `I-PER`, the `to_entity` is `PER`. - - # Returns - - `bool` - Whether the transition is allowed under the given `constraint_type`. - """ - - if to_tag == "START" or from_tag == "END": - # Cannot transition into START or from END - return False - - if constraint_type == "BIOUL": - if from_tag == "START": - return to_tag in ("O", "B", "U") - if to_tag == "END": - return from_tag in ("O", "L", "U") - return any( - [ - # O can transition to O, B-* or U-* - # L-x can transition to O, B-*, or U-* - # U-x can transition to O, B-*, or U-* - from_tag in ("O", "L", "U") and to_tag in ("O", "B", "U"), - # B-x can only transition to I-x or L-x - # I-x can only transition to I-x or L-x - from_tag in ("B", "I") and to_tag in ("I", "L") and from_entity == to_entity, - ] - ) - elif constraint_type == "BIO": - if from_tag == "START": - return to_tag in ("O", "B") - if to_tag == "END": - return from_tag in ("O", "B", "I") - return any( - [ - # Can always transition to O or B-x - to_tag in ("O", "B"), - # Can only transition to I-x from B-x or I-x - to_tag == "I" and from_tag in ("B", "I") and from_entity == to_entity, - ] - ) - elif constraint_type == "IOB1": - if from_tag == "START": - return to_tag in ("O", "I") - if to_tag == "END": - return from_tag in ("O", "B", "I") - return any( - [ - # Can always transition to O or I-x - to_tag in ("O", "I"), - # Can only transition to B-x from B-x or I-x, where - # x is the same tag. - to_tag == "B" and from_tag in ("B", "I") and from_entity == to_entity, - ] - ) - elif constraint_type == "BMES": - if from_tag == "START": - return to_tag in ("B", "S") - if to_tag == "END": - return from_tag in ("E", "S") - return any( - [ - # Can only transition to B or S from E or S. - to_tag in ("B", "S") and from_tag in ("E", "S"), - # Can only transition to M-x from B-x, where - # x is the same tag. - to_tag == "M" and from_tag in ("B", "M") and from_entity == to_entity, - # Can only transition to E-x from B-x or M-x, where - # x is the same tag. - to_tag == "E" and from_tag in ("B", "M") and from_entity == to_entity, - ] - ) - else: - raise ConfigurationError(f"Unknown constraint type: {constraint_type}") - - -class ConditionalRandomFieldWeightTrans(torch.nn.Module): +class ConditionalRandomFieldWeightTrans(ConditionalRandomField): """ This module uses the "forward-backward" algorithm to compute the log-likelihood of its inputs assuming a conditional random field model. @@ -190,277 +44,44 @@ def __init__( include_start_end_transitions: bool = True, label_weights: List[float] = None, ) -> None: - super().__init__() - self.num_tags = num_tags - - # transitions[i, j] is the logit for transitioning from state i to state j. - self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags)) - - # _constraint_mask indicates valid transitions (based on supplied constraints). - # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2) - if constraints is None: - # All transitions are valid. - constraint_mask = torch.full((num_tags + 2, num_tags + 2), 1.0) - else: - constraint_mask = torch.full((num_tags + 2, num_tags + 2), 0.0) - for i, j in constraints: - constraint_mask[i, j] = 1.0 + super().__init__(num_tags, constraints, include_start_end_transitions) - self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False) - - # Also need logits for transitioning from "start" state and to "end" state. - self.include_start_end_transitions = include_start_end_transitions - if include_start_end_transitions: - self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) - self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) - - # If label_weights is not given, use 1.0 for all weights. if label_weights is None: - label_weights = [1.0] * num_tags - self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) - - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.xavier_normal_(self.transitions) - if self.include_start_end_transitions: - torch.nn.init.normal_(self.start_transitions) - torch.nn.init.normal_(self.end_transitions) - - def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: - """ - Computes the (batch_size,) denominator term for the log-likelihood, which is the - sum of the likelihoods across all possible state sequences. - """ - batch_size, sequence_length, num_tags = logits.size() - - # Transpose batch size and sequence dimensions - mask = mask.transpose(0, 1).contiguous() - logits = logits.transpose(0, 1).contiguous() - - # insert the batch dimesion to be broadcasted - label_weights = self.label_weights.view(1, num_tags) - - # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), - # where t(i,j) is the score to transition from i to j and w(i) is the weight - # for tag i. - transitions = self.transitions * label_weights.view(-1, 1) - - emit_scores = logits[0] * label_weights - - # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the - # transitions to the initial states and the logits for the first timestep. - if self.include_start_end_transitions: - log_alpha = self.start_transitions.view(1, num_tags) + emit_scores - else: - log_alpha = emit_scores - - # For each i we compute logits for the transitions from timestep i-1 to timestep i. - # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are - # (instance, current_tag, next_tag) - for i in range(1, sequence_length): - # multiply the logits by the label weights - # logits[i].shape: (batch_size, num_tags) - emit_scores = logits[i] * label_weights - - # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. - emit_scores = emit_scores.view(batch_size, 1, num_tags) - # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. - transition_scores = transitions.view(1, num_tags, num_tags) - # Alpha is for the current_tag, so we broadcast along the next_tag axis. - broadcast_alpha = log_alpha.view(batch_size, num_tags, 1) + raise ConfigurationError("label_weights must be given") - # Add all the scores together and logexp over the current_tag axis. - inner = broadcast_alpha + emit_scores + transition_scores - - # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension - # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. - log_alpha = util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) + log_alpha * ( - ~mask[i] - ).view(batch_size, 1) - - # Every sequence needs to end with a transition to the stop_tag. - if self.include_start_end_transitions: - log_alpha = log_alpha + self.end_transitions.view(1, num_tags) - - # Finally we log_sum_exp along the num_tags dim, result is (batch_size,) - return util.logsumexp(log_alpha, 1) - - def _joint_likelihood( - self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor - ) -> torch.Tensor: - """ - Computes the numerator term for the log-likelihood, which is just score(inputs, tags) - """ - batch_size, sequence_length, _ = logits.data.shape - - # Transpose batch size and sequence dimensions: - logits = logits.transpose(0, 1).contiguous() - mask = mask.transpose(0, 1).contiguous() - tags = tags.transpose(0, 1).contiguous() - - # Start with the transition scores from start_tag to the first tag in each input - if self.include_start_end_transitions: - score = self.start_transitions.index_select(0, tags[0]) - else: - score = 0.0 - - label_weights = self.label_weights - - # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w[i], - # where t(i,j) is the score to transition from i to j and w[i] is the weight - # for tag i. - transitions = self.transitions * label_weights.view(-1, 1) - - # Add up the scores for the observed transitions and all the inputs but the last - for i in range(sequence_length - 1): - # Each is shape (batch_size,) - current_tag, next_tag = tags[i], tags[i + 1] - - # The scores for transitioning from current_tag to next_tag - transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] - - # The score for using current_tag - emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) - - # Weight emit scores by label. - emit_score *= label_weights[current_tag.view(-1)] - - # Include transition score if next element is unmasked, - # input_score if this element is unmasked. - score = score + transition_score * mask[i + 1] + emit_score * mask[i] - - # Transition from last state to "stop" state. To start with, we need to find the last tag - # for each instance. - last_tag_index = mask.sum(0).long() - 1 - last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) - - # Compute score of transitioning to `stop_tag` from each "last tag". - if self.include_start_end_transitions: - last_transition_score = self.end_transitions.index_select(0, last_tags) - else: - last_transition_score = 0.0 - - # Add the last input if it's not masked. - last_inputs = logits[-1] # (batch_size, num_tags) - last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) - last_input_score = last_input_score.squeeze() # (batch_size,) - - # Weight last emit scores by label weights. - last_input_score = last_input_score * label_weights[last_tags.view(-1)] - - score = score + last_transition_score + last_input_score * mask[-1] - - return score + self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) def forward( self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None ) -> torch.Tensor: - """ - Computes the log likelihood. - """ + """Computes the log likelihood for the given batch of input sequences $(x,y)$ + Args: + inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$ + tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$ + mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags. + Defaults to None. + + Returns: + torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input + """ if mask is None: mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) else: # The code below fails in weird ways if this isn't a bool tensor, so we make sure. mask = mask.to(torch.bool) - log_denominator = self._input_likelihood(inputs, mask) - log_numerator = self._joint_likelihood(inputs, tags, mask) - - return torch.sum(log_numerator - log_denominator) - - def viterbi_tags( - self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None - ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]: - """ - Uses viterbi algorithm to find most likely tags for the given inputs. - If constraints are applied, disallows all other transitions. - - Returns a list of results, of the same size as the batch (one result per batch member) - Each result is a List of length top_k, containing the top K viterbi decodings - Each decoding is a tuple (tag_sequence, viterbi_score) - - For backwards compatibility, if top_k is None, then instead returns a flat list of - tag sequences (the top tag sequence for each batch item). - """ - if mask is None: - mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device) - - if top_k is None: - top_k = 1 - flatten_output = True - else: - flatten_output = False - - _, max_seq_length, num_tags = logits.size() - - # Get the tensors out of the variables - logits, mask = logits.data, mask.data - - # Augment transitions matrix with start and end transitions - start_tag = num_tags - end_tag = num_tags + 1 - transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device) - - # Apply transition constraints - constrained_transitions = self.transitions * self._constraint_mask[ - :num_tags, :num_tags - ] + -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags]) - transitions[:num_tags, :num_tags] = constrained_transitions.data - - if self.include_start_end_transitions: - transitions[ - start_tag, :num_tags - ] = self.start_transitions.detach() * self._constraint_mask[ - start_tag, :num_tags - ].data + -10000.0 * ( - 1 - self._constraint_mask[start_tag, :num_tags].detach() - ) - transitions[:num_tags, end_tag] = self.end_transitions.detach() * self._constraint_mask[ - :num_tags, end_tag - ].data + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) - else: - transitions[start_tag, :num_tags] = -10000.0 * ( - 1 - self._constraint_mask[start_tag, :num_tags].detach() - ) - transitions[:num_tags, end_tag] = -10000.0 * ( - 1 - self._constraint_mask[:num_tags, end_tag].detach() - ) - - best_paths = [] - # Pad the max sequence length by 2 to account for start_tag + end_tag. - tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device) - - for prediction, prediction_mask in zip(logits, mask): - mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze() - masked_prediction = torch.index_select(prediction, 0, mask_indices) - sequence_length = masked_prediction.shape[0] + label_weights = self.label_weights - # Start with everything totally unlikely - tag_sequence.fill_(-10000.0) - # At timestep 0 we must have the START_TAG - tag_sequence[0, start_tag] = 0.0 - # At steps 1, ..., sequence_length we just use the incoming prediction - tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction - # And at the last timestep we must have the END_TAG - tag_sequence[sequence_length + 1, end_tag] = 0.0 + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), + # where t(i,j) is the score of transitioning from i to j and w(i) is the weight + # for tag i. + transitions = self.transitions * label_weights.view(-1, 1) - # We pass the tags and the transitions to `viterbi_decode`. - viterbi_paths, viterbi_scores = util.viterbi_decode( - tag_sequence=tag_sequence[: (sequence_length + 2)], - transition_matrix=transitions, - top_k=top_k, - ) - top_k_paths = [] - for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores): - # Get rid of START and END sentinels and append. - viterbi_path = viterbi_path[1:-1] - top_k_paths.append((viterbi_path, viterbi_score.item())) - best_paths.append(top_k_paths) + # scale the logits for all examples and all time steps + inputs = inputs * label_weights.view(1, 1, -1) - if flatten_output: - return [top_k_paths[0] for top_k_paths in best_paths] + log_denominator = self._input_likelihood(inputs, transitions, mask) + log_numerator = self._joint_likelihood(inputs, transitions, tags, mask) - return best_paths + return torch.sum(log_numerator - log_denominator) diff --git a/tests/modules/conditional_random_field_test.py b/tests/modules/conditional_random_field_test.py index d41c982b271..658d7973856 100644 --- a/tests/modules/conditional_random_field_test.py +++ b/tests/modules/conditional_random_field_test.py @@ -6,6 +6,8 @@ from numpy.testing import assert_allclose from allennlp.modules import ConditionalRandomField +from allennlp.modules import ConditionalRandomFieldWeightEmission +from allennlp.modules import ConditionalRandomFieldWeightTrans from allennlp.modules.conditional_random_field import allowed_transitions from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase @@ -384,12 +386,17 @@ def test_allowed_transitions(self): } -class TestWeightedConditionalRandomField(TestConditionalRandomField): +class TestConditionalRandomFieldWeightEmission(TestConditionalRandomField): def setup_method(self): super().setup_method() - + self.label_weights = torch.FloatTensor([1.0, 1.0, 0.5, 0.5, 0.5]) + self.crf = ConditionalRandomFieldWeightEmission(5, label_weights=self.label_weights) + self.crf.transitions = torch.nn.Parameter(self.transitions) + self.crf.start_transitions = torch.nn.Parameter(self.transitions_from_start) + self.crf.end_transitions = torch.nn.Parameter(self.transitions_to_end) + # Use the CRF Module with labels weights. self.crf.label_weights = torch.nn.Parameter(self.label_weights, requires_grad=False) @@ -409,7 +416,6 @@ def score_with_weights(self, logits, tags): total += logit[tag] * self.label_weights[tag] return total - def test_forward_works_without_mask(self): log_likelihood = self.crf(self.logits, self.tags).item() @@ -463,3 +469,34 @@ def test_forward_works_with_mask(self): # The manually computed log likelihood should equal the result of crf.forward. assert manual_log_likelihood.item() == approx(log_likelihood) + + +class TestConditionalRandomFieldWeightTrans(TestConditionalRandomFieldWeightEmission): + def setup_method(self): + super().setup_method() + + self.label_weights = torch.FloatTensor([1.0, 1.0, 0.5, 0.5, 0.5]) + + self.crf = ConditionalRandomFieldWeightTrans(5, label_weights=self.label_weights) + self.crf.transitions = torch.nn.Parameter(self.transitions) + self.crf.start_transitions = torch.nn.Parameter(self.transitions_from_start) + self.crf.end_transitions = torch.nn.Parameter(self.transitions_to_end) + + # Use the CRF Module with labels weights. + self.crf.label_weights = torch.nn.Parameter(self.label_weights, requires_grad=False) + + def score_with_weights(self, logits, tags): + """ + Computes the likelihood score for the given sequence of tags, + given the provided logits, the transition weights in the CRF model + and the label weights. + """ + # Start with transitions from START and to END + total = self.transitions_from_start[tags[0]] + self.transitions_to_end[tags[-1]] + # Add in all the intermediate transitions + for tag, next_tag in zip(tags, tags[1:]): + total += self.transitions[tag, next_tag] * self.label_weights[tag] + # Add in the logits for the observed tags + for logit, tag in zip(logits, tags): + total += logit[tag] * self.label_weights[tag] + return total From 030da5c49bc0a0402c9cc843fb790109370bf077 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Sat, 18 Jun 2022 01:42:07 +0200 Subject: [PATCH 05/10] Weighted CRF: refactoring of three methods --- .../modules/conditional_random_field_lannoy.py | 18 +++++++++--------- .../conditional_random_field_wemission.py | 14 +++++++------- .../modules/conditional_random_field_wtrans.py | 16 ++++++++-------- .../training/metrics/fbeta_verbose_measure.py | 4 ++-- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/allennlp/modules/conditional_random_field_lannoy.py b/allennlp/modules/conditional_random_field_lannoy.py index 56dbf2e1956..8132169ba3d 100644 --- a/allennlp/modules/conditional_random_field_lannoy.py +++ b/allennlp/modules/conditional_random_field_lannoy.py @@ -167,29 +167,29 @@ class ConditionalRandomFieldLannoy(torch.nn.Module): num_tags : `int`, required The number of tags. - constraints : `List[Tuple[int, int]]`, optional (default = `None`) - An optional list of allowed transitions (from_tag_id, to_tag_id). - These are applied to `viterbi_tags()` but do not affect `forward()`. - These should be derived from `allowed_transitions` so that the - start and end transitions are handled correctly for your tag type. - include_start_end_transitions : `bool`, optional (default = `True`) - Whether to include the start and end transition parameters. label_weights : `List[float]`, optional (default=`None`) - An optional list of weights to be used in the loss function in order to + A list of weights to be used in the loss function in order to give different weights for each token depending on its label. `len(label_weights)` must be equal to `num_tags`. This is useful to deal with highly unbalanced datasets. The method implemented here was based on the paper *Weighted conditional random fields for supervised interpatient heartbeat classification* proposed by De Lannoy et. al (2019). See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf + constraints : `List[Tuple[int, int]]`, optional (default = `None`) + An optional list of allowed transitions (from_tag_id, to_tag_id). + These are applied to `viterbi_tags()` but do not affect `forward()`. + These should be derived from `allowed_transitions` so that the + start and end transitions are handled correctly for your tag type. + include_start_end_transitions : `bool`, optional (default = `True`) + Whether to include the start and end transition parameters. """ def __init__( self, num_tags: int, + label_weights: List[float], constraints: List[Tuple[int, int]] = None, include_start_end_transitions: bool = True, - label_weights: List[float] = None, ) -> None: super().__init__() self.num_tags = num_tags diff --git a/allennlp/modules/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field_wemission.py index babee7f3ebc..1a7215b7221 100644 --- a/allennlp/modules/conditional_random_field_wemission.py +++ b/allennlp/modules/conditional_random_field_wemission.py @@ -21,6 +21,13 @@ class ConditionalRandomFieldWeightEmission(ConditionalRandomField): num_tags : `int`, required The number of tags. + label_weights : `List[float]`, required + A list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here is + based on the simple idea of weighting emission scores using the weight + given for the corresponding tag. constraints : `List[Tuple[int, int]]`, optional (default = `None`) An optional list of allowed transitions (from_tag_id, to_tag_id). These are applied to `viterbi_tags()` but do not affect `forward()`. @@ -28,13 +35,6 @@ class ConditionalRandomFieldWeightEmission(ConditionalRandomField): start and end transitions are handled correctly for your tag type. include_start_end_transitions : `bool`, optional (default = `True`) Whether to include the start and end transition parameters. - label_weights : `List[float]`, optional (default=`None`) - An optional list of weights to be used in the loss function in order to - give different weights for each token depending on its label. - `len(label_weights)` must be equal to `num_tags`. This is useful to - deal with highly unbalanced datasets. The method implemented here is - based on the simple idea of weighting emission and transition scores - using the weight given for the corresponding tag. """ def __init__( diff --git a/allennlp/modules/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field_wtrans.py index d6b514f8a97..00750b197ee 100644 --- a/allennlp/modules/conditional_random_field_wtrans.py +++ b/allennlp/modules/conditional_random_field_wtrans.py @@ -21,6 +21,13 @@ class ConditionalRandomFieldWeightTrans(ConditionalRandomField): num_tags : `int`, required The number of tags. + label_weights : `List[float]`, required + A list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here is + based on the simple idea of weighting emission and transition scores + using the weight given for the corresponding tag. constraints : `List[Tuple[int, int]]`, optional (default = `None`) An optional list of allowed transitions (from_tag_id, to_tag_id). These are applied to `viterbi_tags()` but do not affect `forward()`. @@ -28,21 +35,14 @@ class ConditionalRandomFieldWeightTrans(ConditionalRandomField): start and end transitions are handled correctly for your tag type. include_start_end_transitions : `bool`, optional (default = `True`) Whether to include the start and end transition parameters. - label_weights : `List[float]`, optional (default=`None`) - An optional list of weights to be used in the loss function in order to - give different weights for each token depending on its label. - `len(label_weights)` must be equal to `num_tags`. This is useful to - deal with highly unbalanced datasets. The method implemented here is - based on the simple idea of weighting emission and transition scores - using the weight given for the corresponding tag. """ def __init__( self, num_tags: int, + label_weights: List[float], constraints: List[Tuple[int, int]] = None, include_start_end_transitions: bool = True, - label_weights: List[float] = None, ) -> None: super().__init__(num_tags, constraints, include_start_end_transitions) diff --git a/allennlp/training/metrics/fbeta_verbose_measure.py b/allennlp/training/metrics/fbeta_verbose_measure.py index e3e2b5edff2..d52d9582ade 100644 --- a/allennlp/training/metrics/fbeta_verbose_measure.py +++ b/allennlp/training/metrics/fbeta_verbose_measure.py @@ -54,13 +54,13 @@ class and also three averaged values for each metric: micro, beta : `float`, optional (default = `1.0`) The strength of recall versus precision in the F-score. - labels : `List[int]`, optional + labels : `List[int]`, optional (default = `None`) The set of labels to include. Labels present in the data can be excluded, for example, to calculate a multi-class average ignoring a majority negative class. Labels not present in the data will result in 0 components in a macro or weighted average. - index_to_label : `Dict[int, str]`, optional + index_to_label : `Dict[int, str]`, optional (default = `None`) A dictionary mapping indices to the corresponding label. If this map is giving, the provided metrics include the label instead of the index for each class. From a7cf34c9008942536885383bef5cd3161049aa59 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 22 Jun 2022 23:16:29 +0200 Subject: [PATCH 06/10] Weighted CRF: black formatting --- allennlp/modules/conditional_random_field_wemission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field_wemission.py index 1a7215b7221..bc80f2208e1 100644 --- a/allennlp/modules/conditional_random_field_wemission.py +++ b/allennlp/modules/conditional_random_field_wemission.py @@ -26,7 +26,7 @@ class ConditionalRandomFieldWeightEmission(ConditionalRandomField): give different weights for each token depending on its label. `len(label_weights)` must be equal to `num_tags`. This is useful to deal with highly unbalanced datasets. The method implemented here is - based on the simple idea of weighting emission scores using the weight + based on the simple idea of weighting emission scores using the weight given for the corresponding tag. constraints : `List[Tuple[int, int]]`, optional (default = `None`) An optional list of allowed transitions (from_tag_id, to_tag_id). From 4602ec028f37deded781e5cceb053a875ee8b279 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Sat, 9 Jul 2022 00:38:23 +0200 Subject: [PATCH 07/10] Weighted CRF: moved classes to new module Simplified Lannoy implementation. --- allennlp/modules/__init__.py | 3 - .../conditional_random_field_lannoy.py | 486 ------------------ .../__init__.py | 3 + .../conditional_random_field_wemission.py | 15 +- .../conditional_random_field_wlannoy.py | 230 +++++++++ .../conditional_random_field_wtrans.py | 16 +- .../modules/conditional_random_field_test.py | 20 +- 7 files changed, 274 insertions(+), 499 deletions(-) delete mode 100644 allennlp/modules/conditional_random_field_lannoy.py create mode 100644 allennlp/modules/conditional_random_field_weighted/__init__.py rename allennlp/modules/{ => conditional_random_field_weighted}/conditional_random_field_wemission.py (81%) create mode 100644 allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py rename allennlp/modules/{ => conditional_random_field_weighted}/conditional_random_field_wtrans.py (81%) diff --git a/allennlp/modules/__init__.py b/allennlp/modules/__init__.py index f2256b472bf..0e47f36d0f6 100644 --- a/allennlp/modules/__init__.py +++ b/allennlp/modules/__init__.py @@ -8,9 +8,6 @@ from allennlp.modules.backbones import Backbone from allennlp.modules.bimpm_matching import BiMpmMatching from allennlp.modules.conditional_random_field import ConditionalRandomField -from allennlp.modules.conditional_random_field_wemission import ConditionalRandomFieldWeightEmission -from allennlp.modules.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans -from allennlp.modules.conditional_random_field_lannoy import ConditionalRandomFieldLannoy from allennlp.modules.elmo import Elmo from allennlp.modules.feedforward import FeedForward from allennlp.modules.gated_sum import GatedSum diff --git a/allennlp/modules/conditional_random_field_lannoy.py b/allennlp/modules/conditional_random_field_lannoy.py deleted file mode 100644 index 8132169ba3d..00000000000 --- a/allennlp/modules/conditional_random_field_lannoy.py +++ /dev/null @@ -1,486 +0,0 @@ -""" -Conditional random field -""" -from typing import List, Tuple, Dict, Union - -import torch - -from allennlp.common.checks import ConfigurationError -import allennlp.nn.util as util - -VITERBI_DECODING = Tuple[List[int], float] # a list of tags, and a viterbi score - - -def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]: - """ - Given labels and a constraint type, returns the allowed transitions. It will - additionally include transitions for the start and end states, which are used - by the conditional random field. - - # Parameters - - constraint_type : `str`, required - Indicates which constraint to apply. Current choices are - "BIO", "IOB1", "BIOUL", and "BMES". - labels : `Dict[int, str]`, required - A mapping {label_id -> label}. Most commonly this would be the value from - Vocabulary.get_index_to_token_vocabulary() - - # Returns - - `List[Tuple[int, int]]` - The allowed transitions (from_label_id, to_label_id). - """ - num_labels = len(labels) - start_tag = num_labels - end_tag = num_labels + 1 - labels_with_boundaries = list(labels.items()) + [(start_tag, "START"), (end_tag, "END")] - - allowed = [] - for from_label_index, from_label in labels_with_boundaries: - if from_label in ("START", "END"): - from_tag = from_label - from_entity = "" - else: - from_tag = from_label[0] - from_entity = from_label[1:] - for to_label_index, to_label in labels_with_boundaries: - if to_label in ("START", "END"): - to_tag = to_label - to_entity = "" - else: - to_tag = to_label[0] - to_entity = to_label[1:] - if is_transition_allowed(constraint_type, from_tag, from_entity, to_tag, to_entity): - allowed.append((from_label_index, to_label_index)) - return allowed - - -def is_transition_allowed( - constraint_type: str, from_tag: str, from_entity: str, to_tag: str, to_entity: str -): - """ - Given a constraint type and strings `from_tag` and `to_tag` that - represent the origin and destination of the transition, return whether - the transition is allowed under the given constraint type. - - # Parameters - - constraint_type : `str`, required - Indicates which constraint to apply. Current choices are - "BIO", "IOB1", "BIOUL", and "BMES". - from_tag : `str`, required - The tag that the transition originates from. For example, if the - label is `I-PER`, the `from_tag` is `I`. - from_entity : `str`, required - The entity corresponding to the `from_tag`. For example, if the - label is `I-PER`, the `from_entity` is `PER`. - to_tag : `str`, required - The tag that the transition leads to. For example, if the - label is `I-PER`, the `to_tag` is `I`. - to_entity : `str`, required - The entity corresponding to the `to_tag`. For example, if the - label is `I-PER`, the `to_entity` is `PER`. - - # Returns - - `bool` - Whether the transition is allowed under the given `constraint_type`. - """ - - if to_tag == "START" or from_tag == "END": - # Cannot transition into START or from END - return False - - if constraint_type == "BIOUL": - if from_tag == "START": - return to_tag in ("O", "B", "U") - if to_tag == "END": - return from_tag in ("O", "L", "U") - return any( - [ - # O can transition to O, B-* or U-* - # L-x can transition to O, B-*, or U-* - # U-x can transition to O, B-*, or U-* - from_tag in ("O", "L", "U") and to_tag in ("O", "B", "U"), - # B-x can only transition to I-x or L-x - # I-x can only transition to I-x or L-x - from_tag in ("B", "I") and to_tag in ("I", "L") and from_entity == to_entity, - ] - ) - elif constraint_type == "BIO": - if from_tag == "START": - return to_tag in ("O", "B") - if to_tag == "END": - return from_tag in ("O", "B", "I") - return any( - [ - # Can always transition to O or B-x - to_tag in ("O", "B"), - # Can only transition to I-x from B-x or I-x - to_tag == "I" and from_tag in ("B", "I") and from_entity == to_entity, - ] - ) - elif constraint_type == "IOB1": - if from_tag == "START": - return to_tag in ("O", "I") - if to_tag == "END": - return from_tag in ("O", "B", "I") - return any( - [ - # Can always transition to O or I-x - to_tag in ("O", "I"), - # Can only transition to B-x from B-x or I-x, where - # x is the same tag. - to_tag == "B" and from_tag in ("B", "I") and from_entity == to_entity, - ] - ) - elif constraint_type == "BMES": - if from_tag == "START": - return to_tag in ("B", "S") - if to_tag == "END": - return from_tag in ("E", "S") - return any( - [ - # Can only transition to B or S from E or S. - to_tag in ("B", "S") and from_tag in ("E", "S"), - # Can only transition to M-x from B-x, where - # x is the same tag. - to_tag == "M" and from_tag in ("B", "M") and from_entity == to_entity, - # Can only transition to E-x from B-x or M-x, where - # x is the same tag. - to_tag == "E" and from_tag in ("B", "M") and from_entity == to_entity, - ] - ) - else: - raise ConfigurationError(f"Unknown constraint type: {constraint_type}") - - -class ConditionalRandomFieldLannoy(torch.nn.Module): - """ - This module uses the "forward-backward" algorithm to compute - the log-likelihood of its inputs assuming a conditional random field model. - - See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf - - # Parameters - - num_tags : `int`, required - The number of tags. - label_weights : `List[float]`, optional (default=`None`) - A list of weights to be used in the loss function in order to - give different weights for each token depending on its label. - `len(label_weights)` must be equal to `num_tags`. This is useful to - deal with highly unbalanced datasets. The method implemented here was based on - the paper *Weighted conditional random fields for supervised interpatient heartbeat - classification* proposed by De Lannoy et. al (2019). - See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf - constraints : `List[Tuple[int, int]]`, optional (default = `None`) - An optional list of allowed transitions (from_tag_id, to_tag_id). - These are applied to `viterbi_tags()` but do not affect `forward()`. - These should be derived from `allowed_transitions` so that the - start and end transitions are handled correctly for your tag type. - include_start_end_transitions : `bool`, optional (default = `True`) - Whether to include the start and end transition parameters. - """ - - def __init__( - self, - num_tags: int, - label_weights: List[float], - constraints: List[Tuple[int, int]] = None, - include_start_end_transitions: bool = True, - ) -> None: - super().__init__() - self.num_tags = num_tags - - # transitions[i, j] is the logit for transitioning from state i to state j. - self.transitions = torch.nn.Parameter(torch.empty(num_tags, num_tags)) - - # _constraint_mask indicates valid transitions (based on supplied constraints). - # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2) - if constraints is None: - # All transitions are valid. - constraint_mask = torch.full((num_tags + 2, num_tags + 2), 1.0) - else: - constraint_mask = torch.full((num_tags + 2, num_tags + 2), 0.0) - for i, j in constraints: - constraint_mask[i, j] = 1.0 - - self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False) - - # Also need logits for transitioning from "start" state and to "end" state. - self.include_start_end_transitions = include_start_end_transitions - if include_start_end_transitions: - self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) - self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags)) - - # If label_weights is not given, use 1.0 for all weights. - if label_weights is None: - label_weights = [1.0] * num_tags - self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) - - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.xavier_normal_(self.transitions) - if self.include_start_end_transitions: - torch.nn.init.normal_(self.start_transitions) - torch.nn.init.normal_(self.end_transitions) - - def _input_likelihood( - self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor - ) -> torch.Tensor: - """ - Computes the (batch_size,) denominator term for the log-likelihood, which is the - sum of the likelihoods across all possible state sequences. - - Compute this value using the scaling trick instead of the log domain trick, since - this is necessary to implement the label-weighting method by Lannoy et al. (2012). - """ - batch_size, sequence_length, num_tags = logits.size() - - # Transpose batch size and sequence dimensions - mask = mask.transpose(0, 1).contiguous() - logits = logits.transpose(0, 1).contiguous() - tags = tags.transpose(0, 1).contiguous() - - # insert an 1-sized second dimension to match z.shape - label_weights = self.label_weights.view(num_tags, 1) - - # emit_scores.shape = (batch_size, num_tags) - emit_scores = logits[0] - - # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the - # transitions to the initial states and the logits for the first timestep. - # alpha.shape = (batch_size, num_tags) - if self.include_start_end_transitions: - alpha = torch.exp(self.start_transitions.view(1, num_tags) + emit_scores) - else: - alpha = torch.exp(emit_scores) - - # z.shape = (batch_size, 1) - z = alpha.sum(dim=1, keepdim=True) - alpha = alpha / z - sum_log_z = torch.log(z) * label_weights[tags[0]] - - # For each i we compute logits for the transitions from timestep i-1 to timestep i. - # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are - # (instance, current_tag, next_tag) - for i in range(1, sequence_length): - # multiply the logits by the label weights - # logits[i].shape: (batch_size, num_tags) - # emit_scores = torch.mul(logits[i], label_weights) - emit_scores = logits[i] - - # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. - emit_scores = emit_scores.view(batch_size, 1, num_tags) - # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. - transition_scores = self.transitions.view(1, num_tags, num_tags) - # Alpha is for the current_tag (i-1), so we broadcast along the next_tag axis. - broadcast_alpha = alpha.view(batch_size, num_tags, 1) - - # Add all the scores together and logexp over the current_tag axis. - inner = broadcast_alpha * torch.exp(emit_scores + transition_scores) - - # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension - # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. - alpha = inner.sum(dim=1) * mask[i].view(batch_size, 1) + alpha * (~mask[i]).view( - batch_size, 1 - ) - - # scale alphas to avoid underflow (sum of alphas equal to 1) - z = alpha.sum(dim=1, keepdim=True) - alpha = alpha / z - # weight z (normalization factor) according to the current tag - sum_log_z += torch.log(z) * label_weights[tags[i]] - - # Every sequence needs to end with a transition to the stop_tag. - if self.include_start_end_transitions: - alpha = alpha * torch.exp(self.end_transitions.view(1, num_tags)) - z = alpha.sum(dim=1, keepdim=True) - # alpha = alpha / z # this step is unnecessary since alpha is not used anymore - sum_log_z += torch.log(z) - - return sum_log_z.squeeze(1) - - def _joint_likelihood( - self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor - ) -> torch.Tensor: - """ - Computes the numerator term for the log-likelihood, which is just score(inputs, tags) - """ - batch_size, sequence_length, _ = logits.data.shape - - # Transpose batch size and sequence dimensions: - logits = logits.transpose(0, 1).contiguous() - mask = mask.transpose(0, 1).contiguous() - tags = tags.transpose(0, 1).contiguous() - - # Start with the transition scores from start_tag to the first tag in each input - if self.include_start_end_transitions: - score = self.start_transitions.index_select(0, tags[0]) - else: - score = 0.0 - - label_weights = self.label_weights - - # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), - # where t(i,j) is the score to transition from i to j and w(i) is the weight - # for tag i. - transitions = self.transitions * label_weights.view(-1, 1) - - # Add up the scores for the observed transitions and all the inputs but the last - for i in range(sequence_length - 1): - # Each is shape (batch_size,) - current_tag, next_tag = tags[i], tags[i + 1] - - # The scores for transitioning from current_tag to next_tag - transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] - - # The score for using current_tag - emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) - - # Weight emit scores by label. - emit_score *= label_weights[current_tag.view(-1)] - - # Include transition score if next element is unmasked, - # input_score if this element is unmasked. - score = score + transition_score * mask[i + 1] + emit_score * mask[i] - - # Transition from last state to "stop" state. To start with, we need to find the last tag - # for each instance. - last_tag_index = mask.sum(0).long() - 1 - last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) - - # Compute score of transitioning to `stop_tag` from each "last tag". - if self.include_start_end_transitions: - last_transition_score = self.end_transitions.index_select(0, last_tags) - else: - last_transition_score = 0.0 - - # Add the last input if it's not masked. - last_inputs = logits[-1] # (batch_size, num_tags) - last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) - last_input_score = last_input_score.squeeze() # (batch_size,) - - # Weight last emit scores by label weights. - last_input_score = last_input_score * label_weights[last_tags.view(-1)] - - score = score + last_transition_score + last_input_score * mask[-1] - - return score - - def forward( - self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None - ) -> torch.Tensor: - """ - Computes the log likelihood. - """ - - if mask is None: - mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) - else: - # The code below fails in weird ways if this isn't a bool tensor, so we make sure. - mask = mask.to(torch.bool) - - # TODO check if weights are being used during test/validation - - log_denominator = self._input_likelihood(inputs, tags, mask) - log_numerator = self._joint_likelihood(inputs, tags, mask) - - return torch.sum(log_numerator - log_denominator) - - def viterbi_tags( - self, logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None - ) -> Union[List[VITERBI_DECODING], List[List[VITERBI_DECODING]]]: - """ - Uses viterbi algorithm to find most likely tags for the given inputs. - If constraints are applied, disallows all other transitions. - - Returns a list of results, of the same size as the batch (one result per batch member) - Each result is a List of length top_k, containing the top K viterbi decodings - Each decoding is a tuple (tag_sequence, viterbi_score) - - For backwards compatibility, if top_k is None, then instead returns a flat list of - tag sequences (the top tag sequence for each batch item). - """ - if mask is None: - mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device) - - if top_k is None: - top_k = 1 - flatten_output = True - else: - flatten_output = False - - _, max_seq_length, num_tags = logits.size() - - # Get the tensors out of the variables - logits, mask = logits.data, mask.data - - # Augment transitions matrix with start and end transitions - start_tag = num_tags - end_tag = num_tags + 1 - transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device) - - # Apply transition constraints - constrained_transitions = self.transitions * self._constraint_mask[ - :num_tags, :num_tags - ] + -10000.0 * (1 - self._constraint_mask[:num_tags, :num_tags]) - transitions[:num_tags, :num_tags] = constrained_transitions.data - - if self.include_start_end_transitions: - transitions[ - start_tag, :num_tags - ] = self.start_transitions.detach() * self._constraint_mask[ - start_tag, :num_tags - ].data + -10000.0 * ( - 1 - self._constraint_mask[start_tag, :num_tags].detach() - ) - transitions[:num_tags, end_tag] = self.end_transitions.detach() * self._constraint_mask[ - :num_tags, end_tag - ].data + -10000.0 * (1 - self._constraint_mask[:num_tags, end_tag].detach()) - else: - transitions[start_tag, :num_tags] = -10000.0 * ( - 1 - self._constraint_mask[start_tag, :num_tags].detach() - ) - transitions[:num_tags, end_tag] = -10000.0 * ( - 1 - self._constraint_mask[:num_tags, end_tag].detach() - ) - - best_paths = [] - # Pad the max sequence length by 2 to account for start_tag + end_tag. - tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device) - - for prediction, prediction_mask in zip(logits, mask): - mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze() - masked_prediction = torch.index_select(prediction, 0, mask_indices) - sequence_length = masked_prediction.shape[0] - - # Start with everything totally unlikely - tag_sequence.fill_(-10000.0) - # At timestep 0 we must have the START_TAG - tag_sequence[0, start_tag] = 0.0 - # At steps 1, ..., sequence_length we just use the incoming prediction - tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction - # And at the last timestep we must have the END_TAG - tag_sequence[sequence_length + 1, end_tag] = 0.0 - - # We pass the tags and the transitions to `viterbi_decode`. - viterbi_paths, viterbi_scores = util.viterbi_decode( - tag_sequence=tag_sequence[: (sequence_length + 2)], - transition_matrix=transitions, - top_k=top_k, - ) - top_k_paths = [] - for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores): - # Get rid of START and END sentinels and append. - viterbi_path = viterbi_path[1:-1] - top_k_paths.append((viterbi_path, viterbi_score.item())) - best_paths.append(top_k_paths) - - if flatten_output: - return [top_k_paths[0] for top_k_paths in best_paths] - - return best_paths diff --git a/allennlp/modules/conditional_random_field_weighted/__init__.py b/allennlp/modules/conditional_random_field_weighted/__init__.py new file mode 100644 index 00000000000..3d7ea3281d2 --- /dev/null +++ b/allennlp/modules/conditional_random_field_weighted/__init__.py @@ -0,0 +1,3 @@ +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wemission import ConditionalRandomFieldWeightEmission +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wlannoy import ConditionalRandomFieldWeightLannoy \ No newline at end of file diff --git a/allennlp/modules/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py similarity index 81% rename from allennlp/modules/conditional_random_field_wemission.py rename to allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py index bc80f2208e1..7c9da52b071 100644 --- a/allennlp/modules/conditional_random_field_wemission.py +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py @@ -1,13 +1,12 @@ """ -Conditional random field +Conditional random field with emission-based weighting """ from typing import List, Tuple import torch from allennlp.common.checks import ConfigurationError - -from .conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field import ConditionalRandomField class ConditionalRandomFieldWeightEmission(ConditionalRandomField): @@ -17,6 +16,14 @@ class ConditionalRandomFieldWeightEmission(ConditionalRandomField): See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the simple idea + of weighting emission scores using the weight given for the corresponding tag. + + There are two other sample weighting methods implemented. You can find more details + about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html + # Parameters num_tags : `int`, required @@ -49,7 +56,7 @@ def __init__( if label_weights is None: raise ConfigurationError("label_weights must be given") - self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) + self.register_buffer("label_weights", torch.Tensor(label_weights)) def forward( self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py new file mode 100644 index 00000000000..ba6a1ec7a16 --- /dev/null +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py @@ -0,0 +1,230 @@ +""" +Conditional random field with weighting based on Lannoy et al. (2019) approach +""" +from typing import List, Tuple + +import torch + +from allennlp.common.checks import ConfigurationError +from allennlp.modules.conditional_random_field import ConditionalRandomField + + +class ConditionalRandomFieldWeightLannoy(ConditionalRandomField): + """ + This module uses the "forward-backward" algorithm to compute + the log-likelihood of its inputs assuming a conditional random field model. + + See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + + This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the paper + *Weighted conditional random fields for supervised interpatient heartbeat + classification* proposed by De Lannoy et. al (2019). + See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf for more details. + + There are two other sample weighting methods implemented. You can find more details + about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html + + # Parameters + + num_tags : `int`, required + The number of tags. + label_weights : `List[float]`, required + A list of weights to be used in the loss function in order to + give different weights for each token depending on its label. + `len(label_weights)` must be equal to `num_tags`. This is useful to + deal with highly unbalanced datasets. The method implemented here was based on + the paper *Weighted conditional random fields for supervised interpatient heartbeat + classification* proposed by De Lannoy et. al (2019). + See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf + constraints : `List[Tuple[int, int]]`, optional (default = `None`) + An optional list of allowed transitions (from_tag_id, to_tag_id). + These are applied to `viterbi_tags()` but do not affect `forward()`. + These should be derived from `allowed_transitions` so that the + start and end transitions are handled correctly for your tag type. + include_start_end_transitions : `bool`, optional (default = `True`) + Whether to include the start and end transition parameters. + """ + + def __init__( + self, + num_tags: int, + label_weights: List[float], + constraints: List[Tuple[int, int]] = None, + include_start_end_transitions: bool = True, + ) -> None: + super().__init__(num_tags, constraints, include_start_end_transitions) + + if label_weights is None: + raise ConfigurationError("label_weights must be given") + + self.register_buffer("label_weights", torch.Tensor(label_weights)) + + def forward( + self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None + ) -> torch.Tensor: + """Computes the log likelihood for the given batch of input sequences $(x,y)$ + + Args: + inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$ + tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$ + mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags. + Defaults to None. + + Returns: + torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input + """ + if mask is None: + mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device) + else: + # The code below fails in weird ways if this isn't a bool tensor, so we make sure. + mask = mask.to(torch.bool) + + log_denominator = self._input_likelihood(inputs, tags, mask) + log_numerator = self._joint_likelihood(inputs, tags, mask) + + return torch.sum(log_numerator - log_denominator) + + def _input_likelihood( + self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Computes the (batch_size,) denominator term for the log-likelihood, which is the + sum of the likelihoods across all possible state sequences. + + Compute this value using the scaling trick instead of the log domain trick, since + this is necessary to implement the label-weighting method by Lannoy et al. (2012). + """ + batch_size, sequence_length, num_tags = logits.size() + + # Transpose batch size and sequence dimensions + mask = mask.transpose(0, 1).contiguous() + logits = logits.transpose(0, 1).contiguous() + tags = tags.transpose(0, 1).contiguous() + + # insert an 1-sized second dimension to match z.shape + label_weights = self.label_weights.view(num_tags, 1) + + # emit_scores.shape = (batch_size, num_tags) + emit_scores = logits[0] + + # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the + # transitions to the initial states and the logits for the first timestep. + # alpha.shape = (batch_size, num_tags) + if self.include_start_end_transitions: + alpha = torch.exp(self.start_transitions.view(1, num_tags) + emit_scores) + else: + alpha = torch.exp(emit_scores) + + # z.shape = (batch_size, 1) + z = alpha.sum(dim=1, keepdim=True) + alpha = alpha / z + sum_log_z = torch.log(z) * label_weights[tags[0]] + + # For each i we compute logits for the transitions from timestep i-1 to timestep i. + # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are + # (instance, current_tag, next_tag) + for i in range(1, sequence_length): + # multiply the logits by the label weights + # logits[i].shape: (batch_size, num_tags) + # emit_scores = torch.mul(logits[i], label_weights) + emit_scores = logits[i] + + # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis. + emit_scores = emit_scores.view(batch_size, 1, num_tags) + # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis. + transition_scores = self.transitions.view(1, num_tags, num_tags) + # Alpha is for the current_tag (i-1), so we broadcast along the next_tag axis. + broadcast_alpha = alpha.view(batch_size, num_tags, 1) + + # Add all the scores together and logexp over the current_tag axis. + inner = broadcast_alpha * torch.exp(emit_scores + transition_scores) + + # In valid positions (mask == True) we want to take the logsumexp over the current_tag dimension + # of `inner`. Otherwise (mask == False) we want to retain the previous alpha. + alpha = inner.sum(dim=1) * mask[i].view(batch_size, 1) + alpha * (~mask[i]).view( + batch_size, 1 + ) + + # scale alphas to avoid underflow (sum of alphas equal to 1) + z = alpha.sum(dim=1, keepdim=True) + alpha = alpha / z + # weight z (normalization factor) according to the current tag + sum_log_z += torch.log(z) * label_weights[tags[i]] + + # Every sequence needs to end with a transition to the stop_tag. + if self.include_start_end_transitions: + alpha = alpha * torch.exp(self.end_transitions.view(1, num_tags)) + z = alpha.sum(dim=1, keepdim=True) + # alpha = alpha / z # this step is unnecessary since alpha is not used anymore + sum_log_z += torch.log(z) + + return sum_log_z.squeeze(1) + + def _joint_likelihood( + self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Computes the numerator term for the log-likelihood, which is just score(inputs, tags) + """ + batch_size, sequence_length, _ = logits.data.shape + + # Transpose batch size and sequence dimensions: + logits = logits.transpose(0, 1).contiguous() + mask = mask.transpose(0, 1).contiguous() + tags = tags.transpose(0, 1).contiguous() + + # Start with the transition scores from start_tag to the first tag in each input + if self.include_start_end_transitions: + score = self.start_transitions.index_select(0, tags[0]) + else: + score = 0.0 + + label_weights = self.label_weights + + # weight transition score using current_tag, i.e., t(i,j) will be t(i,j)*w(i), + # where t(i,j) is the score to transition from i to j and w(i) is the weight + # for tag i. + transitions = self.transitions * label_weights.view(-1, 1) + + # Add up the scores for the observed transitions and all the inputs but the last + for i in range(sequence_length - 1): + # Each is shape (batch_size,) + current_tag, next_tag = tags[i], tags[i + 1] + + # The scores for transitioning from current_tag to next_tag + transition_score = transitions[current_tag.view(-1), next_tag.view(-1)] + + # The score for using current_tag + emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1) + + # Weight emit scores by label. + emit_score *= label_weights[current_tag.view(-1)] + + # Include transition score if next element is unmasked, + # input_score if this element is unmasked. + score = score + transition_score * mask[i + 1] + emit_score * mask[i] + + # Transition from last state to "stop" state. To start with, we need to find the last tag + # for each instance. + last_tag_index = mask.sum(0).long() - 1 + last_tags = tags.gather(0, last_tag_index.view(1, batch_size)).squeeze(0) + + # Compute score of transitioning to `stop_tag` from each "last tag". + if self.include_start_end_transitions: + last_transition_score = self.end_transitions.index_select(0, last_tags) + else: + last_transition_score = 0.0 + + # Add the last input if it's not masked. + last_inputs = logits[-1] # (batch_size, num_tags) + last_input_score = last_inputs.gather(1, last_tags.view(-1, 1)) # (batch_size, 1) + last_input_score = last_input_score.squeeze() # (batch_size,) + + # Weight last emit scores by label weights. + last_input_score = last_input_score * label_weights[last_tags.view(-1)] + + score = score + last_transition_score + last_input_score * mask[-1] + + return score diff --git a/allennlp/modules/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py similarity index 81% rename from allennlp/modules/conditional_random_field_wtrans.py rename to allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py index 00750b197ee..5978f750e2d 100644 --- a/allennlp/modules/conditional_random_field_wtrans.py +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py @@ -1,13 +1,12 @@ """ -Conditional random field +Conditional random field with emission- and transition-based weighting """ from typing import List, Tuple import torch from allennlp.common.checks import ConfigurationError - -from .conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field import ConditionalRandomField class ConditionalRandomFieldWeightTrans(ConditionalRandomField): @@ -17,6 +16,15 @@ class ConditionalRandomFieldWeightTrans(ConditionalRandomField): See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf + This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the simple idea + of weighting emission and transition scores using the weight given for the + corresponding tag. + + There are two other sample weighting methods implemented. You can find more details + about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html + # Parameters num_tags : `int`, required @@ -49,7 +57,7 @@ def __init__( if label_weights is None: raise ConfigurationError("label_weights must be given") - self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False) + self.register_buffer("label_weights", torch.Tensor(label_weights)) def forward( self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None diff --git a/tests/modules/conditional_random_field_test.py b/tests/modules/conditional_random_field_test.py index 658d7973856..3ec5829e94f 100644 --- a/tests/modules/conditional_random_field_test.py +++ b/tests/modules/conditional_random_field_test.py @@ -6,8 +6,9 @@ from numpy.testing import assert_allclose from allennlp.modules import ConditionalRandomField -from allennlp.modules import ConditionalRandomFieldWeightEmission -from allennlp.modules import ConditionalRandomFieldWeightTrans +from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightEmission +from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightTrans +from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightLannoy from allennlp.modules.conditional_random_field import allowed_transitions from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase @@ -500,3 +501,18 @@ def score_with_weights(self, logits, tags): for logit, tag in zip(logits, tags): total += logit[tag] * self.label_weights[tag] return total + + +class TestConditionalRandomFieldWeightLannoy(TestConditionalRandomFieldWeightEmission): + def setup_method(self): + super().setup_method() + + self.label_weights = torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0]) + + self.crf = ConditionalRandomFieldWeightLannoy(5, label_weights=self.label_weights) + self.crf.transitions = torch.nn.Parameter(self.transitions) + self.crf.start_transitions = torch.nn.Parameter(self.transitions_from_start) + self.crf.end_transitions = torch.nn.Parameter(self.transitions_to_end) + + # Use the CRF Module with labels weights. + self.crf.label_weights = torch.nn.Parameter(self.label_weights, requires_grad=False) From a01d29f45e3d47775f0765c9b82fc29a3f960b69 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Sat, 9 Jul 2022 00:45:58 +0200 Subject: [PATCH 08/10] formatting and type checking --- .../conditional_random_field_weighted/__init__.py | 12 +++++++++--- .../conditional_random_field_wemission.py | 6 +++--- .../conditional_random_field_wlannoy.py | 14 +++++++------- .../conditional_random_field_wtrans.py | 8 ++++---- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/allennlp/modules/conditional_random_field_weighted/__init__.py b/allennlp/modules/conditional_random_field_weighted/__init__.py index 3d7ea3281d2..1e4d4daeead 100644 --- a/allennlp/modules/conditional_random_field_weighted/__init__.py +++ b/allennlp/modules/conditional_random_field_weighted/__init__.py @@ -1,3 +1,9 @@ -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wemission import ConditionalRandomFieldWeightEmission -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wlannoy import ConditionalRandomFieldWeightLannoy \ No newline at end of file +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wemission import ( + ConditionalRandomFieldWeightEmission, +) +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wtrans import ( + ConditionalRandomFieldWeightTrans, +) +from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wlannoy import ( + ConditionalRandomFieldWeightLannoy, +) diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py index 7c9da52b071..5198b45624a 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py @@ -17,11 +17,11 @@ class ConditionalRandomFieldWeightEmission(ConditionalRandomField): See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` - parameter to be used in the loss function in order to give different weights for each - token depending on its label. The method implemented here is based on the simple idea + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the simple idea of weighting emission scores using the weight given for the corresponding tag. - There are two other sample weighting methods implemented. You can find more details + There are two other sample weighting methods implemented. You can find more details about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html # Parameters diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py index ba6a1ec7a16..7b62c63a3fd 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py @@ -17,13 +17,13 @@ class ConditionalRandomFieldWeightLannoy(ConditionalRandomField): See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` - parameter to be used in the loss function in order to give different weights for each - token depending on its label. The method implemented here is based on the paper + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the paper *Weighted conditional random fields for supervised interpatient heartbeat classification* proposed by De Lannoy et. al (2019). See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf for more details. - There are two other sample weighting methods implemented. You can find more details + There are two other sample weighting methods implemented. You can find more details about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html # Parameters @@ -81,12 +81,12 @@ def forward( # The code below fails in weird ways if this isn't a bool tensor, so we make sure. mask = mask.to(torch.bool) - log_denominator = self._input_likelihood(inputs, tags, mask) - log_numerator = self._joint_likelihood(inputs, tags, mask) + log_denominator = self._input_likelihood_lannoy(inputs, tags, mask) + log_numerator = self._joint_likelihood_lannoy(inputs, tags, mask) return torch.sum(log_numerator - log_denominator) - def _input_likelihood( + def _input_likelihood_lannoy( self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor ) -> torch.Tensor: """ @@ -162,7 +162,7 @@ def _input_likelihood( return sum_log_z.squeeze(1) - def _joint_likelihood( + def _joint_likelihood_lannoy( self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor ) -> torch.Tensor: """ diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py index 5978f750e2d..bc514542ea0 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py +++ b/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py @@ -17,12 +17,12 @@ class ConditionalRandomFieldWeightTrans(ConditionalRandomField): See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf This is a weighted version of `ConditionalRandomField` which accepts a `label_weights` - parameter to be used in the loss function in order to give different weights for each - token depending on its label. The method implemented here is based on the simple idea - of weighting emission and transition scores using the weight given for the + parameter to be used in the loss function in order to give different weights for each + token depending on its label. The method implemented here is based on the simple idea + of weighting emission and transition scores using the weight given for the corresponding tag. - There are two other sample weighting methods implemented. You can find more details + There are two other sample weighting methods implemented. You can find more details about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html # Parameters From 236f654bd5b93ccb08a97518689b62c694eef89a Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 13 Jul 2022 23:13:23 +0200 Subject: [PATCH 09/10] Moved ConditionalRandomField to new module Renamed module allennlp.modules.conditional_random_field_weight to ...conditional_random_files --- .../modules/conditional_random_field/__init__.py | 12 ++++++++++++ .../conditional_random_field.py | 0 .../conditional_random_field_wemission.py | 4 +++- .../conditional_random_field_wlannoy.py | 4 +++- .../conditional_random_field_wtrans.py | 4 +++- .../conditional_random_field_weighted/__init__.py | 9 --------- tests/modules/conditional_random_field_test.py | 10 ++++++---- 7 files changed, 27 insertions(+), 16 deletions(-) create mode 100644 allennlp/modules/conditional_random_field/__init__.py rename allennlp/modules/{ => conditional_random_field}/conditional_random_field.py (100%) rename allennlp/modules/{conditional_random_field_weighted => conditional_random_field}/conditional_random_field_wemission.py (97%) rename allennlp/modules/{conditional_random_field_weighted => conditional_random_field}/conditional_random_field_wlannoy.py (98%) rename allennlp/modules/{conditional_random_field_weighted => conditional_random_field}/conditional_random_field_wtrans.py (97%) delete mode 100644 allennlp/modules/conditional_random_field_weighted/__init__.py diff --git a/allennlp/modules/conditional_random_field/__init__.py b/allennlp/modules/conditional_random_field/__init__.py new file mode 100644 index 00000000000..609aac89e95 --- /dev/null +++ b/allennlp/modules/conditional_random_field/__init__.py @@ -0,0 +1,12 @@ +from allennlp.modules.conditional_random_field.conditional_random_field import ( + ConditionalRandomField, +) +from allennlp.modules.conditional_random_field.conditional_random_field_wemission import ( + ConditionalRandomFieldWeightEmission, +) +from allennlp.modules.conditional_random_field.conditional_random_field_wtrans import ( + ConditionalRandomFieldWeightTrans, +) +from allennlp.modules.conditional_random_field.conditional_random_field_wlannoy import ( + ConditionalRandomFieldWeightLannoy, +) diff --git a/allennlp/modules/conditional_random_field.py b/allennlp/modules/conditional_random_field/conditional_random_field.py similarity index 100% rename from allennlp/modules/conditional_random_field.py rename to allennlp/modules/conditional_random_field/conditional_random_field.py diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py b/allennlp/modules/conditional_random_field/conditional_random_field_wemission.py similarity index 97% rename from allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py rename to allennlp/modules/conditional_random_field/conditional_random_field_wemission.py index 5198b45624a..516a354e1a8 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wemission.py +++ b/allennlp/modules/conditional_random_field/conditional_random_field_wemission.py @@ -6,7 +6,9 @@ import torch from allennlp.common.checks import ConfigurationError -from allennlp.modules.conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field.conditional_random_field import ( + ConditionalRandomField, +) class ConditionalRandomFieldWeightEmission(ConditionalRandomField): diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py b/allennlp/modules/conditional_random_field/conditional_random_field_wlannoy.py similarity index 98% rename from allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py rename to allennlp/modules/conditional_random_field/conditional_random_field_wlannoy.py index 7b62c63a3fd..2ee84d4d3da 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wlannoy.py +++ b/allennlp/modules/conditional_random_field/conditional_random_field_wlannoy.py @@ -6,7 +6,9 @@ import torch from allennlp.common.checks import ConfigurationError -from allennlp.modules.conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field.conditional_random_field import ( + ConditionalRandomField, +) class ConditionalRandomFieldWeightLannoy(ConditionalRandomField): diff --git a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py b/allennlp/modules/conditional_random_field/conditional_random_field_wtrans.py similarity index 97% rename from allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py rename to allennlp/modules/conditional_random_field/conditional_random_field_wtrans.py index bc514542ea0..065810e4f72 100644 --- a/allennlp/modules/conditional_random_field_weighted/conditional_random_field_wtrans.py +++ b/allennlp/modules/conditional_random_field/conditional_random_field_wtrans.py @@ -6,7 +6,9 @@ import torch from allennlp.common.checks import ConfigurationError -from allennlp.modules.conditional_random_field import ConditionalRandomField +from allennlp.modules.conditional_random_field.conditional_random_field import ( + ConditionalRandomField, +) class ConditionalRandomFieldWeightTrans(ConditionalRandomField): diff --git a/allennlp/modules/conditional_random_field_weighted/__init__.py b/allennlp/modules/conditional_random_field_weighted/__init__.py deleted file mode 100644 index 1e4d4daeead..00000000000 --- a/allennlp/modules/conditional_random_field_weighted/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wemission import ( - ConditionalRandomFieldWeightEmission, -) -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wtrans import ( - ConditionalRandomFieldWeightTrans, -) -from allennlp.modules.conditional_random_field_weighted.conditional_random_field_wlannoy import ( - ConditionalRandomFieldWeightLannoy, -) diff --git a/tests/modules/conditional_random_field_test.py b/tests/modules/conditional_random_field_test.py index 3ec5829e94f..3e22ac54d82 100644 --- a/tests/modules/conditional_random_field_test.py +++ b/tests/modules/conditional_random_field_test.py @@ -6,10 +6,12 @@ from numpy.testing import assert_allclose from allennlp.modules import ConditionalRandomField -from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightEmission -from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightTrans -from allennlp.modules.conditional_random_field_weighted import ConditionalRandomFieldWeightLannoy -from allennlp.modules.conditional_random_field import allowed_transitions +from allennlp.modules.conditional_random_field import ( + ConditionalRandomFieldWeightEmission, + ConditionalRandomFieldWeightTrans, + ConditionalRandomFieldWeightLannoy, +) +from allennlp.modules.conditional_random_field.conditional_random_field import allowed_transitions from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase From 748bac6be6c45b899bcc415ece000903de7b1ecc Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 13 Jul 2022 23:20:13 +0200 Subject: [PATCH 10/10] Updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a272530663..70a22f8c927 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added metric `FBetaVerboseMeasure` which extends `FBetaMeasure` to ensure compatibility with logging plugins and add some options. +- Added three sample weighting techniques to `ConditionalRandomField` by supplying three new subclasses: `ConditionalRandomFieldWeightEmission`, `ConditionalRandomFieldWeightTrans`, and `ConditionalRandomFieldWeightLannoy`. ### Fixed