From 012ec547531ff7ae9802b70f88188a0aa55b9293 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Fri, 9 Oct 2020 13:05:24 -0400 Subject: [PATCH 01/13] (rebase) Weighted CRF: scaled emission scores --- allennlp_models/tagging/models/crf_tagger.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index 5d2277959..137920898 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -72,6 +72,13 @@ class CrfTagger(Model): If True, we compute the loss only for actual spans in `tags`, and not on `O` tokens. This is useful for computing gradients of the loss on a _single span_, for interpretation / attacking. + label_weights : `Dict[str, float]`, optional (default=`None`) + An optional mapping {label -> weight} to be used in the loss function in order to + give different weights for each token depending on its label. This is useful to + deal with highly unbalanced datasets. The method implemented here was based on + the paper *Weighted conditional random fields for supervised interpatient heartbeat + classification* proposed by De Lannoy et. al (2019). + See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf """ def __init__( @@ -90,6 +97,7 @@ def __init__( initializer: InitializerApplicator = InitializerApplicator(), top_k: int = 1, ignore_loss_on_o_tags: bool = False, + label_weights: Optional[Dict[str, float]] = None, **kwargs, ) -> None: super().__init__(vocab, **kwargs) @@ -137,6 +145,18 @@ def __init__( self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions ) + # Label weights are given as a mapping {label -> weight} + # We convert it to a list of weights for each label. + # Weights for ommited labels are set to 1. 
+ if label_weights is not None: + label_to_index = vocab.get_token_to_index_vocabulary(label_namespace) + self.label_weights = [1.0]*len(label_to_index) + for label, weight in label_weights.items(): + try: + self.label_weights[label_to_index[label]] = weight + except KeyError: + raise KeyError(f"'{label}' not found in vocab namespace '{label_namespace}')") + self.metrics = { "accuracy": CategoricalAccuracy(), "accuracy3": CategoricalAccuracy(top_k=3), From bb3e69544b324b593f817a0e24ffdb64ba7671a1 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Mon, 28 Feb 2022 22:58:15 +0100 Subject: [PATCH 02/13] Added FBetaMeasure to CrfTagger just to test class weights --- allennlp_models/tagging/models/crf_tagger.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index 137920898..731a42d2e 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -12,7 +12,7 @@ from allennlp.models.model import Model from allennlp.nn import InitializerApplicator import allennlp.nn.util as util -from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure +from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaMeasure @Model.register("crf_tagger") @@ -160,6 +160,8 @@ def __init__( self.metrics = { "accuracy": CategoricalAccuracy(), "accuracy3": CategoricalAccuracy(top_k=3), + # TODO test for weighted CRF + "_f_beta_measure": FBetaMeasure() } self.calculate_span_f1 = calculate_span_f1 if calculate_span_f1: From b264389c85ca2674821131c816a54252ab1e8304 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Fri, 11 Mar 2022 22:37:53 +0100 Subject: [PATCH 03/13] Added FBetaMeasure2 to CrfTagger. 
--- allennlp_models/tagging/models/crf_tagger.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index 731a42d2e..b5456db07 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -12,7 +12,7 @@ from allennlp.models.model import Model from allennlp.nn import InitializerApplicator import allennlp.nn.util as util -from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaMeasure +from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaMeasure2 @Model.register("crf_tagger") @@ -159,10 +159,9 @@ def __init__( self.metrics = { "accuracy": CategoricalAccuracy(), - "accuracy3": CategoricalAccuracy(top_k=3), - # TODO test for weighted CRF - "_f_beta_measure": FBetaMeasure() + "accuracy3": CategoricalAccuracy(top_k=3) } + self.calculate_span_f1 = calculate_span_f1 if calculate_span_f1: if not label_encoding: @@ -172,6 +171,10 @@ def __init__( self._f1_metric = SpanBasedF1Measure( vocab, tag_namespace=label_namespace, label_encoding=label_encoding ) + elif verbose_metrics: + # verbose metrics for token classification (not span-based) + self._f_beta_measure = FBetaMeasure2(index_to_label=vocab.get_index_to_token_vocabulary(label_namespace)) + check_dimensions_match( text_field_embedder.get_output_dim(), @@ -285,6 +288,8 @@ def forward( metric(class_probabilities, tags, mask) if self.calculate_span_f1: self._f1_metric(class_probabilities, tags, mask) + elif self._verbose_metrics: + self._f_beta_measure(class_probabilities, tags, mask) if metadata is not None: output["words"] = [x["words"] for x in metadata] return output @@ -327,6 +332,11 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]: metrics_to_return.update(f1_dict) else: metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x}) + elif 
self._verbose_metrics: + # verbose metrics for token classification (not span-based) + f_beta_dict = self._f_beta_measure.get_metric(reset=reset) + metrics_to_return.update(f_beta_dict) + return metrics_to_return default_predictor = "sentence_tagger" From 6bdaf2a1ab7e56a21371df949e7d0562824f8878 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Fri, 11 Mar 2022 19:27:28 -0400 Subject: [PATCH 04/13] Fixed bug regarding label_weights in CrfTagger --- allennlp_models/tagging/models/crf_tagger.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index b5456db07..14f325202 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -140,11 +140,6 @@ def __init__( else: constraints = None - self.include_start_end_transitions = include_start_end_transitions - self.crf = ConditionalRandomField( - self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions - ) - # Label weights are given as a mapping {label -> weight} # We convert it to a list of weights for each label. # Weights for ommited labels are set to 1. @@ -157,6 +152,13 @@ def __init__( except KeyError: raise KeyError(f"'{label}' not found in vocab namespace '{label_namespace}')") + self.include_start_end_transitions = include_start_end_transitions + self.crf = ConditionalRandomField( + self.num_tags, constraints, + include_start_end_transitions=include_start_end_transitions, + label_weights=self.label_weights + ) + self.metrics = { "accuracy": CategoricalAccuracy(), "accuracy3": CategoricalAccuracy(top_k=3) From 63c7fe28540989220652223d3322595a6f91dfde Mon Sep 17 00:00:00 2001 From: "Eraldo R. 
Fernandes" Date: Sat, 12 Mar 2022 07:13:53 -0400 Subject: [PATCH 05/13] CrfTagger: using micro and macro average for FBetaMeasure2 --- allennlp_models/tagging/models/crf_tagger.py | 25 ++++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index 14f325202..de8321ffa 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -142,10 +142,11 @@ def __init__( # Label weights are given as a mapping {label -> weight} # We convert it to a list of weights for each label. - # Weights for ommited labels are set to 1. + # Weights for omitted labels are set to 1. + self.label_weights = None if label_weights is not None: label_to_index = vocab.get_token_to_index_vocabulary(label_namespace) - self.label_weights = [1.0]*len(label_to_index) + self.label_weights = [1.0] * len(label_to_index) for label, weight in label_weights.items(): try: self.label_weights[label_to_index[label]] = weight @@ -154,16 +155,17 @@ def __init__( self.include_start_end_transitions = include_start_end_transitions self.crf = ConditionalRandomField( - self.num_tags, constraints, - include_start_end_transitions=include_start_end_transitions, - label_weights=self.label_weights + self.num_tags, + constraints, + include_start_end_transitions=include_start_end_transitions, + label_weights=self.label_weights, ) self.metrics = { "accuracy": CategoricalAccuracy(), - "accuracy3": CategoricalAccuracy(top_k=3) + "accuracy3": CategoricalAccuracy(top_k=3), } - + self.calculate_span_f1 = calculate_span_f1 if calculate_span_f1: if not label_encoding: @@ -175,8 +177,11 @@ def __init__( ) elif verbose_metrics: # verbose metrics for token classification (not span-based) - self._f_beta_measure = FBetaMeasure2(index_to_label=vocab.get_index_to_token_vocabulary(label_namespace)) - + self._f_beta_measure = FBetaMeasure2( + 
index_to_label=vocab.get_index_to_token_vocabulary(label_namespace), + # TODO included to test weighted CRF but it should be included in CrfTagger options + average=["micro", "macro"], + ) check_dimensions_match( text_field_embedder.get_output_dim(), @@ -218,7 +223,7 @@ def forward( A torch tensor representing the sequence of integer gold class labels of shape `(batch_size, num_tokens)`. metadata : `List[Dict[str, Any]]`, optional, (default = `None`) - metadata containg the original words in the sentence to be tagged under a 'words' key. + metadata containing the original words in the sentence to be tagged under a 'words' key. ignore_loss_on_o_tags : `Optional[bool]`, optional (default = `None`) If True, we compute the loss only for actual spans in `tags`, and not on `O` tokens. This is useful for computing gradients of the loss on a _single span_, for From 0f1032559506340401f01c850c9aa1e92e3b6d39 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 6 Apr 2022 17:22:47 -0400 Subject: [PATCH 06/13] CRF weighting strategies --- allennlp_models/tagging/models/crf_tagger.py | 66 ++++++++++++++++---- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index de8321ffa..42cbdf382 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -7,7 +7,12 @@ from allennlp.common.checks import check_dimensions_match, ConfigurationError from allennlp.data import TextFieldTensors, Vocabulary from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder -from allennlp.modules import ConditionalRandomField, FeedForward +from allennlp.modules import ( + ConditionalRandomField, + ConditionalRandomFieldWeightTrans, + ConditionalRandomFieldLannoy, + FeedForward, +) from allennlp.modules.conditional_random_field import allowed_transitions from allennlp.models.model import Model from allennlp.nn import 
InitializerApplicator @@ -73,12 +78,20 @@ class CrfTagger(Model): This is useful for computing gradients of the loss on a _single span_, for interpretation / attacking. label_weights : `Dict[str, float]`, optional (default=`None`) - An optional mapping {label -> weight} to be used in the loss function in order to + A mapping {label : weight} to be used in the loss function in order to give different weights for each token depending on its label. This is useful to - deal with highly unbalanced datasets. The method implemented here was based on - the paper *Weighted conditional random fields for supervised interpatient heartbeat - classification* proposed by De Lannoy et. al (2019). - See https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf + deal with highly unbalanced datasets. There are three available methods to deal + with weighted labels (see below). + weight_strategy : `str`, optional (default=`None`) + Only allowed when `label_weights` is not `None`. It indicates which strategy is + used to weight each tag. Valid options are: "emission", "emission_transition", + "lannoy". If `None` and `label_weights` is not `None`, then "emission" is assumed. + If "emission" then the emission score of each tag is multiplied by the + corresponding weight. If "emission_transition", both emission and transition + scores of each tag are multiplied by the corresponding weight. In this case, + a transition score `t(i,j)`, between consecutive tokens `i` and `j`, is multiplied + by `w[tags[i]]`, i.e., the weight related to the tag of token `i`. + If `weight_strategy` is "lannoy" then we use the strategy proposed by [Lannoy et al. (2019)](https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf). 
""" def __init__( @@ -98,6 +111,7 @@ def __init__( top_k: int = 1, ignore_loss_on_o_tags: bool = False, label_weights: Optional[Dict[str, float]] = None, + weight_strategy: str = None, **kwargs, ) -> None: super().__init__(vocab, **kwargs) @@ -153,13 +167,41 @@ def __init__( except KeyError: raise KeyError(f"'{label}' not found in vocab namespace '{label_namespace}')") + if weight_strategy is None or weight_strategy == "emission": + self.crf = ConditionalRandomField( + self.num_tags, + constraints, + include_start_end_transitions=include_start_end_transitions, + label_weights=self.label_weights, + ) + elif weight_strategy == "emission_transition": + self.crf = ConditionalRandomFieldWeightTrans( + self.num_tags, + constraints, + include_start_end_transitions=include_start_end_transitions, + label_weights=self.label_weights, + ) + elif weight_strategy == "lannoy": + self.crf = ConditionalRandomFieldLannoy( + self.num_tags, + constraints, + include_start_end_transitions=include_start_end_transitions, + label_weights=self.label_weights, + ) + else: + raise ConfigurationError( + "weight_strategy must be one of 'emission', 'emission_transition' or 'lannoy'" + ) + elif weight_strategy is not None: + raise ConfigurationError("weight_strategy is given but label_weights is not") + else: + self.crf = ConditionalRandomField( + self.num_tags, + constraints, + include_start_end_transitions=include_start_end_transitions, + ) + self.include_start_end_transitions = include_start_end_transitions - self.crf = ConditionalRandomField( - self.num_tags, - constraints, - include_start_end_transitions=include_start_end_transitions, - label_weights=self.label_weights, - ) self.metrics = { "accuracy": CategoricalAccuracy(), From 1d7a97f8477059f5c697a5679f96be5ab70c4d47 Mon Sep 17 00:00:00 2001 From: "Eraldo R. 
Fernandes" Date: Sat, 18 Jun 2022 23:55:18 +0200 Subject: [PATCH 07/13] Weighted CRF: adjustments considering refactoring --- allennlp_models/tagging/models/crf_tagger.py | 73 ++++++++++---------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index 42cbdf382..a9d86f5a3 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -9,6 +9,7 @@ from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder from allennlp.modules import ( ConditionalRandomField, + ConditionalRandomFieldWeightEmission, ConditionalRandomFieldWeightTrans, ConditionalRandomFieldLannoy, FeedForward, @@ -17,7 +18,7 @@ from allennlp.models.model import Model from allennlp.nn import InitializerApplicator import allennlp.nn.util as util -from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaMeasure2 +from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaVerboseMeasure @Model.register("crf_tagger") @@ -80,18 +81,20 @@ class CrfTagger(Model): label_weights : `Dict[str, float]`, optional (default=`None`) A mapping {label : weight} to be used in the loss function in order to give different weights for each token depending on its label. This is useful to - deal with highly unbalanced datasets. There are three available methods to deal - with weighted labels (see below). + deal with highly unbalanced datasets. There are three available strategies to deal + with weighted labels (see below). The default strategy is "emission". weight_strategy : `str`, optional (default=`None`) - Only allowed when `label_weights` is not `None`. It indicates which strategy is - used to weight each tag. Valid options are: "emission", "emission_transition", - "lannoy". If `None` and `label_weights` is not `None`, then "emission" is assumed. 
- If "emission" then the emission score of each tag is multiplied by the - corresponding weight. If "emission_transition", both emission and transition - scores of each tag are multiplied by the corresponding weight. In this case, - a transition score `t(i,j)`, between consecutive tokens `i` and `j`, is multiplied - by `w[tags[i]]`, i.e., the weight related to the tag of token `i`. - If `weight_strategy` is "lannoy" then we use the strategy proposed by [Lannoy et al. (2019)](https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf). + If `label_weights` is given and this is `None`, then it is the same as "emission". + It indicates which strategy is used for sample weighting. Valid options are: + "emission", "emission_transition", "lannoy". If "emission" then the emission score + of each tag is multiplied by the corresponding weight (as given by `label_weights`). + If "emission_transition", both emission and transition scores of each tag are multiplied + by the corresponding weight. In this case, a transition score `t(i,j)`, between consecutive + tokens `i` and `j`, is multiplied by `w[tags[i]]`, i.e., the weight related to the tag of token `i`. + If `weight_strategy` is "lannoy" then we use the strategy proposed by + [Lannoy et al. (2012)](https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf). + You can see an experimental comparison among these three strategies and a brief discussion + of their differences [here](https://eraldoluis.github.io/2022/05/10/weighted-crf.html). """ def __init__( @@ -154,11 +157,21 @@ def __init__( else: constraints = None - # Label weights are given as a mapping {label -> weight} - # We convert it to a list of weights for each label. - # Weights for omitted labels are set to 1. - self.label_weights = None - if label_weights is not None: + # Label weights are given as a dict {label: weight} but we convert it to a list of weights for each label, + # and weights for omitted labels are set to 1. 
+ if label_weights is None: + if weight_strategy is not None: + raise ConfigurationError( + "`weight_strategy` can only be used when `label_weights` is given" + ) + + # ordinary CRF (not weighted) + self.crf = ConditionalRandomField( + self.num_tags, + constraints, + include_start_end_transitions, + ) + else: # label_weights is not None label_to_index = vocab.get_token_to_index_vocabulary(label_namespace) self.label_weights = [1.0] * len(label_to_index) for label, weight in label_weights.items(): @@ -168,38 +181,30 @@ def __init__( raise KeyError(f"'{label}' not found in vocab namespace '{label_namespace}')") if weight_strategy is None or weight_strategy == "emission": - self.crf = ConditionalRandomField( + self.crf = ConditionalRandomFieldWeightEmission( self.num_tags, + self.label_weights, constraints, - include_start_end_transitions=include_start_end_transitions, - label_weights=self.label_weights, + include_start_end_transitions, ) elif weight_strategy == "emission_transition": self.crf = ConditionalRandomFieldWeightTrans( self.num_tags, + self.label_weights, constraints, - include_start_end_transitions=include_start_end_transitions, - label_weights=self.label_weights, + include_start_end_transitions, ) elif weight_strategy == "lannoy": self.crf = ConditionalRandomFieldLannoy( self.num_tags, + self.label_weights, constraints, - include_start_end_transitions=include_start_end_transitions, - label_weights=self.label_weights, + include_start_end_transitions, ) else: raise ConfigurationError( "weight_strategy must be one of 'emission', 'emission_transition' or 'lannoy'" ) - elif weight_strategy is not None: - raise ConfigurationError("weight_strategy is given but label_weights is not") - else: - self.crf = ConditionalRandomField( - self.num_tags, - constraints, - include_start_end_transitions=include_start_end_transitions, - ) self.include_start_end_transitions = include_start_end_transitions @@ -219,10 +224,8 @@ def __init__( ) elif verbose_metrics: # verbose 
metrics for token classification (not span-based) - self._f_beta_measure = FBetaMeasure2( + self._f_beta_measure = FBetaVerboseMeasure( index_to_label=vocab.get_index_to_token_vocabulary(label_namespace), - # TODO included to test weighted CRF but it should be included in CrfTagger options - average=["micro", "macro"], ) check_dimensions_match( From 926d3a877596a3bad5444683b45c4b43f9fafbd1 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 22 Jun 2022 22:30:21 +0200 Subject: [PATCH 08/13] Weighted CRF tests --- allennlp_models/tagging/models/crf_tagger.py | 2 +- .../models/crf_tagger_label_weights_test.py | 137 ++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 tests/tagging/models/crf_tagger_label_weights_test.py diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index a9d86f5a3..c83f4791e 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -178,7 +178,7 @@ def __init__( try: self.label_weights[label_to_index[label]] = weight except KeyError: - raise KeyError(f"'{label}' not found in vocab namespace '{label_namespace}')") + raise ConfigurationError(f"'{label}' not found in vocab namespace '{label_namespace}')") if weight_strategy is None or weight_strategy == "emission": self.crf = ConditionalRandomFieldWeightEmission( diff --git a/tests/tagging/models/crf_tagger_label_weights_test.py b/tests/tagging/models/crf_tagger_label_weights_test.py new file mode 100644 index 000000000..44191b0d7 --- /dev/null +++ b/tests/tagging/models/crf_tagger_label_weights_test.py @@ -0,0 +1,137 @@ +from flaky import flaky +import pytest + +from allennlp.commands.train import train_model_from_file +from allennlp.common.testing import ModelTestCase +from allennlp.common.checks import ConfigurationError + +from tests import FIXTURES_ROOT + + +class CrfTaggerLabelWeightsTest(ModelTestCase): + def setup_method(self): + 
super().setup_method() + self.set_up_model( + FIXTURES_ROOT / "tagging" / "crf_tagger" / "experiment.json", + FIXTURES_ROOT / "tagging" / "conll2003.txt", + ) + + def test_label_weights_effectiveness(self): + training_tensors = self.dataset.as_tensor_dict() + save_dir = self.TEST_DIR / "save_and_load_test" + + # original CRF + model_original = train_model_from_file( + self.param_file, save_dir, force=True, return_model=True + ) + output_dict_original = model_original(**training_tensors) + + # weighted CRF + model_weighted = train_model_from_file( + self.param_file, + save_dir, + overrides={"model.label_weights": {"I-ORG": 10.0}}, + force=True, + return_model=True, + ) + output_dict_weighted = model_weighted(**training_tensors) + + # assert that logits are substantially different + assert ( + output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum() + < output_dict_original["logits"].numel() / 2 + ) + + def test_label_weights_effectiveness_emission_transition(self): + training_tensors = self.dataset.as_tensor_dict() + save_dir = self.TEST_DIR / "save_and_load_test" + + # original CRF + model_original = train_model_from_file( + self.param_file, save_dir, force=True, return_model=True + ) + output_dict_original = model_original(**training_tensors) + + # weighted CRF + model_weighted = train_model_from_file( + self.param_file, + save_dir, + overrides={ + "model.label_weights": {"I-ORG": 10.0}, + "model.weight_strategy": "emission_transition", + }, + force=True, + return_model=True, + ) + output_dict_weighted = model_weighted(**training_tensors) + + # assert that logits are substantially different + assert ( + output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum() + < output_dict_original["logits"].numel() / 2 + ) + + def test_label_weights_effectiveness_lannoy(self): + training_tensors = self.dataset.as_tensor_dict() + save_dir = self.TEST_DIR / "save_and_load_test" + + # original CRF + model_original = train_model_from_file( + 
self.param_file, save_dir, force=True, return_model=True + ) + output_dict_original = model_original(**training_tensors) + + # weighted CRF + model_weighted = train_model_from_file( + self.param_file, + save_dir, + overrides={ + "model.label_weights": {"I-ORG": 10.0}, + "model.weight_strategy": "lannoy", + }, + force=True, + return_model=True, + ) + output_dict_weighted = model_weighted(**training_tensors) + + # assert that logits are substantially different + assert ( + output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum() + < output_dict_original["logits"].numel() / 2 + ) + + def test_config_error_invalid_label(self): + save_dir = self.TEST_DIR / "save_and_load_test" + with pytest.raises(ConfigurationError): + train_model_from_file( + self.param_file, + save_dir, + overrides={"model.label_weights": {"BLA": 10.0}}, + force=True, + return_model=True, + ) + + def test_config_error_strategy_without_weights(self): + save_dir = self.TEST_DIR / "save_and_load_test" + with pytest.raises(ConfigurationError): + train_model_from_file( + self.param_file, + save_dir, + overrides={"model.weight_strategy": "emission"}, + force=True, + return_model=True, + ) + + def test_config_error_invalid_strategy(self): + save_dir = self.TEST_DIR / "save_and_load_test" + with pytest.raises(ConfigurationError): + train_model_from_file( + self.param_file, + save_dir, + overrides={ + "model.label_weights": {"I-ORG": 10.0}, + "model.weight_strategy": "invalid", + }, + force=True, + return_model=True, + ) From c0d9798a4c41bb52e4bc7490c43d8b05db011cc5 Mon Sep 17 00:00:00 2001 From: "Eraldo R. 
Fernandes" Date: Wed, 22 Jun 2022 23:12:51 +0200 Subject: [PATCH 09/13] Weighted CRF: tests minor adjustments --- .../models/crf_tagger_label_weights_test.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/tagging/models/crf_tagger_label_weights_test.py b/tests/tagging/models/crf_tagger_label_weights_test.py index 44191b0d7..442530890 100644 --- a/tests/tagging/models/crf_tagger_label_weights_test.py +++ b/tests/tagging/models/crf_tagger_label_weights_test.py @@ -21,10 +21,7 @@ def test_label_weights_effectiveness(self): save_dir = self.TEST_DIR / "save_and_load_test" # original CRF - model_original = train_model_from_file( - self.param_file, save_dir, force=True, return_model=True - ) - output_dict_original = model_original(**training_tensors) + output_dict_original = self.model(**training_tensors) # weighted CRF model_weighted = train_model_from_file( @@ -47,10 +44,7 @@ def test_label_weights_effectiveness_emission_transition(self): save_dir = self.TEST_DIR / "save_and_load_test" # original CRF - model_original = train_model_from_file( - self.param_file, save_dir, force=True, return_model=True - ) - output_dict_original = model_original(**training_tensors) + output_dict_original = self.model(**training_tensors) # weighted CRF model_weighted = train_model_from_file( @@ -76,10 +70,7 @@ def test_label_weights_effectiveness_lannoy(self): save_dir = self.TEST_DIR / "save_and_load_test" # original CRF - model_original = train_model_from_file( - self.param_file, save_dir, force=True, return_model=True - ) - output_dict_original = model_original(**training_tensors) + output_dict_original = self.model(**training_tensors) # weighted CRF model_weighted = train_model_from_file( From ae401d385236ad08038c8dfc4746f0767eff43c5 Mon Sep 17 00:00:00 2001 From: "Eraldo R. 
Fernandes" Date: Wed, 22 Jun 2022 23:14:04 +0200 Subject: [PATCH 10/13] CrfTagger: added test regarding FBetaVerboseMeasure --- tests/tagging/models/crf_tagger_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/tagging/models/crf_tagger_test.py b/tests/tagging/models/crf_tagger_test.py index dd3e85272..56c25e3e9 100644 --- a/tests/tagging/models/crf_tagger_test.py +++ b/tests/tagging/models/crf_tagger_test.py @@ -1,6 +1,7 @@ from flaky import flaky import pytest +from allennlp.commands.train import train_model_from_file from allennlp.common.testing import ModelTestCase from allennlp.common.checks import ConfigurationError from allennlp.common.params import Params @@ -77,3 +78,25 @@ def test_mismatching_dimensions_throws_configuration_error(self): params["model"]["encoder"]["input_size"] = 10 with pytest.raises(ConfigurationError): Model.from_params(vocab=self.vocab, params=params.pop("model")) + + def test_token_based_verbose_metrics(self): + training_tensors = self.dataset.as_tensor_dict() + save_dir = self.TEST_DIR / "save_and_load_test" + + model = train_model_from_file( + self.param_file, + save_dir, + overrides={ + "model.calculate_span_f1": False, + "model.verbose_metrics": True, + }, + force=True, + return_model=True, + ) + model(**training_tensors) + metrics = model.get_metrics() + + # assert that metrics contain all verbose keys + for tag in ["O", "I-PER", "I-ORG", "I-LOC", "micro", "macro", "weighted"]: + for m in ["precision", "recall", "fscore"]: + assert f"{tag}-{m}" in metrics From 7200ea08435809e0e6a4adc603dbef7d2501b0d6 Mon Sep 17 00:00:00 2001 From: "Eraldo R. 
Fernandes" Date: Wed, 22 Jun 2022 23:17:00 +0200 Subject: [PATCH 11/13] CrfTagger: black formatting --- allennlp_models/tagging/models/crf_tagger.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index c83f4791e..c88840392 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -178,7 +178,9 @@ def __init__( try: self.label_weights[label_to_index[label]] = weight except KeyError: - raise ConfigurationError(f"'{label}' not found in vocab namespace '{label_namespace}')") + raise ConfigurationError( + f"'{label}' not found in vocab namespace '{label_namespace}')" + ) if weight_strategy is None or weight_strategy == "emission": self.crf = ConditionalRandomFieldWeightEmission( From e7daa535d789957f115c578a192e1ca5211b9049 Mon Sep 17 00:00:00 2001 From: "Eraldo R. Fernandes" Date: Wed, 13 Jul 2022 23:40:56 +0200 Subject: [PATCH 12/13] Updated CrfTagger to the new module organization --- CHANGELOG.md | 7 +++++++ allennlp_models/tagging/models/crf_tagger.py | 10 ++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a10e18b3..03b8cdb24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- Changed the token-based verbose metric (when `verbose_metrics` is `True` and `calculate_span_f1` is `False`) to be `FBetaVerboseMeasure` instead of `FBetaMeasure`. +- Added option `weight_strategy` to `CrfTagger` in order to support three sample weighting techniques. 
+ ## [v2.9.3](https://github.com/allenai/allennlp-models/releases/tag/v2.9.3) - 2022-04-14 ### Added diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py index c88840392..9cb554782 100644 --- a/allennlp_models/tagging/models/crf_tagger.py +++ b/allennlp_models/tagging/models/crf_tagger.py @@ -9,12 +9,14 @@ from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder from allennlp.modules import ( ConditionalRandomField, + FeedForward, +) +from allennlp.modules.conditional_random_field import ( ConditionalRandomFieldWeightEmission, ConditionalRandomFieldWeightTrans, - ConditionalRandomFieldLannoy, - FeedForward, + ConditionalRandomFieldWeightLannoy, ) -from allennlp.modules.conditional_random_field import allowed_transitions +from allennlp.modules.conditional_random_field.conditional_random_field import allowed_transitions from allennlp.models.model import Model from allennlp.nn import InitializerApplicator import allennlp.nn.util as util @@ -197,7 +199,7 @@ def __init__( include_start_end_transitions, ) elif weight_strategy == "lannoy": - self.crf = ConditionalRandomFieldLannoy( + self.crf = ConditionalRandomFieldWeightLannoy( self.num_tags, self.label_weights, constraints, From eb4f1704c4ec53c2b0aab2b7d7788fa5f8bab343 Mon Sep 17 00:00:00 2001 From: Pete Date: Wed, 13 Jul 2022 17:28:18 -0700 Subject: [PATCH 13/13] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03b8cdb24..02ea1a2c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Changed the token-based verbose metric (when `verbose_metrics` is `True` and `calculate_span_f1` is `False`) to be `FBetaVerboseMeasure` instead of `FBetaMeasure`. 
+- Changed the token-based verbose metric in the `CrfTagger` model (when `verbose_metrics` is `True` and `calculate_span_f1` is `False`) to be `FBetaVerboseMeasure` instead of `FBetaMeasure`. - Added option `weight_strategy` to `CrfTagger` in order to support three sample weighting techniques. ## [v2.9.3](https://github.com/allenai/allennlp-models/releases/tag/v2.9.3) - 2022-04-14