diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a10e18b3..02ea1a2c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,16 @@
 All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+
+### Changed
+
+- The token-based verbose metric in the `CrfTagger` model (when `verbose_metrics` is `True` and `calculate_span_f1` is `False`) is now `FBetaVerboseMeasure` instead of `FBetaMeasure`.
+
+### Added
+
+- Added options `label_weights` and `weight_strategy` to `CrfTagger` in order to support three sample weighting strategies.
+
 ## [v2.9.3](https://github.com/allenai/allennlp-models/releases/tag/v2.9.3) - 2022-04-14
 
 ### Added
diff --git a/allennlp_models/tagging/models/crf_tagger.py b/allennlp_models/tagging/models/crf_tagger.py
index 5d2277959..9cb554782 100644
--- a/allennlp_models/tagging/models/crf_tagger.py
+++ b/allennlp_models/tagging/models/crf_tagger.py
@@ -7,12 +7,20 @@
 from allennlp.common.checks import check_dimensions_match, ConfigurationError
 from allennlp.data import TextFieldTensors, Vocabulary
 from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder
-from allennlp.modules import ConditionalRandomField, FeedForward
-from allennlp.modules.conditional_random_field import allowed_transitions
+from allennlp.modules import (
+    ConditionalRandomField,
+    FeedForward,
+)
+from allennlp.modules.conditional_random_field import (
+    ConditionalRandomFieldWeightEmission,
+    ConditionalRandomFieldWeightTrans,
+    ConditionalRandomFieldWeightLannoy,
+)
+from allennlp.modules.conditional_random_field.conditional_random_field import allowed_transitions
 from allennlp.models.model import Model
 from allennlp.nn import InitializerApplicator
 import allennlp.nn.util as util
-from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure
+from allennlp.training.metrics import CategoricalAccuracy, SpanBasedF1Measure, FBetaVerboseMeasure
 
 
 @Model.register("crf_tagger")
@@ -72,6 +80,23 @@ class CrfTagger(Model):
         If True, we compute the loss only for actual spans in `tags`, and not on `O` tokens.
         This is useful for computing gradients of the loss on a _single span_, for
         interpretation / attacking.
+    label_weights : `Dict[str, float]`, optional (default=`None`)
+        A mapping {label : weight} to be used in the loss function in order to
+        give different weights to each token depending on its label. This is useful
+        when dealing with highly imbalanced datasets. There are three available
+        strategies to apply the label weights (see below). The default strategy is "emission".
+    weight_strategy : `str`, optional (default=`None`)
+        Indicates which strategy is used for sample weighting. Valid options are:
+        "emission", "emission_transition", and "lannoy". If `label_weights` is given
+        and this is `None`, it defaults to "emission". With "emission", the emission score
+        of each tag is multiplied by the corresponding weight (as given by `label_weights`).
+        With "emission_transition", both emission and transition scores of each tag are
+        multiplied by the corresponding weight. In this case, the transition score `t(i,j)`,
+        between consecutive tokens `i` and `j`, is multiplied by `w[tags[i]]`, i.e., the weight
+        related to the tag of token `i`. If `weight_strategy` is "lannoy", we use the strategy
+        proposed by [Lannoy et al. (2012)](https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf).
+        You can see an experimental comparison among these three strategies and a brief discussion
+        of their differences [here](https://eraldoluis.github.io/2022/05/10/weighted-crf.html).
     """
 
     def __init__(
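To make the two multiplicative strategies in the docstring concrete, here is a small illustrative sketch (not part of the patch; `emissions`, `transitions`, `w`, and `tags` are made-up stand-ins for the CRF internals). It shows how the score of a single tag sequence would be reweighted under "emission" versus "emission_transition":

```python
import torch

num_tags, seq_len = 4, 5
emissions = torch.randn(seq_len, num_tags)     # per-token tag scores
transitions = torch.randn(num_tags, num_tags)  # tag-to-tag transition scores
w = torch.tensor([1.0, 1.0, 10.0, 1.0])        # per-tag weights; omitted labels default to 1.0
tags = [0, 2, 2, 1, 3]                         # a gold tag sequence

def sequence_score(strategy: str) -> torch.Tensor:
    score = torch.tensor(0.0)
    for i, tag in enumerate(tags):
        # both strategies scale the emission score of token i by w[tags[i]]
        score = score + w[tag] * emissions[i, tag]
        if i + 1 < len(tags):
            t = transitions[tag, tags[i + 1]]
            if strategy == "emission_transition":
                # the transition score t(i, j) is additionally scaled
                # by the weight of token i's tag
                t = w[tag] * t
            score = score + t
    return score

print(sequence_score("emission"), sequence_score("emission_transition"))
```

The "lannoy" strategy is not sketched here; see the linked paper and blog post for how it differs.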
@@ -90,6 +115,8 @@ def __init__(
         initializer: InitializerApplicator = InitializerApplicator(),
         top_k: int = 1,
         ignore_loss_on_o_tags: bool = False,
+        label_weights: Optional[Dict[str, float]] = None,
+        weight_strategy: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(vocab, **kwargs)
@@ -132,15 +159,64 @@
         else:
             constraints = None
 
+        # Label weights are given as a dict {label: weight}. We convert it to a list of
+        # weights indexed by label id; weights for omitted labels are set to 1.
+        if label_weights is None:
+            if weight_strategy is not None:
+                raise ConfigurationError(
+                    "`weight_strategy` can only be used when `label_weights` is given"
+                )
+
+            # ordinary CRF (not weighted)
+            self.crf = ConditionalRandomField(
+                self.num_tags,
+                constraints,
+                include_start_end_transitions,
+            )
+        else:  # label_weights is not None
+            label_to_index = vocab.get_token_to_index_vocabulary(label_namespace)
+            self.label_weights = [1.0] * len(label_to_index)
+            for label, weight in label_weights.items():
+                try:
+                    self.label_weights[label_to_index[label]] = weight
+                except KeyError:
+                    raise ConfigurationError(
+                        f"'{label}' not found in vocab namespace '{label_namespace}'"
+                    )
+
+            if weight_strategy is None or weight_strategy == "emission":
+                self.crf = ConditionalRandomFieldWeightEmission(
+                    self.num_tags,
+                    self.label_weights,
+                    constraints,
+                    include_start_end_transitions,
+                )
+            elif weight_strategy == "emission_transition":
+                self.crf = ConditionalRandomFieldWeightTrans(
+                    self.num_tags,
+                    self.label_weights,
+                    constraints,
+                    include_start_end_transitions,
+                )
+            elif weight_strategy == "lannoy":
+                self.crf = ConditionalRandomFieldWeightLannoy(
+                    self.num_tags,
+                    self.label_weights,
+                    constraints,
+                    include_start_end_transitions,
+                )
+            else:
+                raise ConfigurationError(
+                    "weight_strategy must be one of 'emission', 'emission_transition' or 'lannoy'"
+                )
+
         self.include_start_end_transitions = include_start_end_transitions
-        self.crf = ConditionalRandomField(
-            self.num_tags, constraints, include_start_end_transitions=include_start_end_transitions
-        )
 
         self.metrics = {
             "accuracy": CategoricalAccuracy(),
             "accuracy3": CategoricalAccuracy(top_k=3),
         }
 
+        self.calculate_span_f1 = calculate_span_f1
         if calculate_span_f1:
             if not label_encoding:
@@ -150,6 +226,11 @@
             self._f1_metric = SpanBasedF1Measure(
                 vocab, tag_namespace=label_namespace, label_encoding=label_encoding
             )
+        elif verbose_metrics:
+            # verbose metrics for token classification (not span-based)
+            self._f_beta_measure = FBetaVerboseMeasure(
+                index_to_label=vocab.get_index_to_token_vocabulary(label_namespace),
+            )
 
         check_dimensions_match(
             text_field_embedder.get_output_dim(),
@@ -191,7 +272,7 @@
             A torch tensor representing the sequence of integer gold class labels of shape
             `(batch_size, num_tokens)`.
         metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
-            metadata containg the original words in the sentence to be tagged under a 'words' key.
+            metadata containing the original words in the sentence to be tagged under a 'words' key.
         ignore_loss_on_o_tags : `Optional[bool]`, optional (default = `None`)
             If True, we compute the loss only for actual spans in `tags`, and not on `O` tokens.
             This is useful for computing gradients of the loss on a _single span_, for
@@ -263,6 +344,8 @@ def forward(
                 metric(class_probabilities, tags, mask)
             if self.calculate_span_f1:
                 self._f1_metric(class_probabilities, tags, mask)
+            elif self._verbose_metrics:
+                self._f_beta_measure(class_probabilities, tags, mask)
         if metadata is not None:
             output["words"] = [x["words"] for x in metadata]
         return output
@@ -305,6 +388,11 @@ def get_metrics(self, reset: bool = False) -> Dict[str, float]:
                 metrics_to_return.update(f1_dict)
             else:
                 metrics_to_return.update({x: y for x, y in f1_dict.items() if "overall" in x})
+        elif self._verbose_metrics:
+            # verbose metrics for token classification (not span-based)
+            f_beta_dict = self._f_beta_measure.get_metric(reset=reset)
+            metrics_to_return.update(f_beta_dict)
+
         return metrics_to_return
 
     default_predictor = "sentence_tagger"
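As a usage reference (not part of the patch), the new options can be set in a training config or via overrides, exactly as the tests below do. A minimal sketch, assuming an existing `CrfTagger` training config at a hypothetical path:

```python
from allennlp.commands.train import train_model_from_file

# In a JSON/Jsonnet config these would be "label_weights" and "weight_strategy"
# keys under "model"; here they are applied as overrides instead.
model = train_model_from_file(
    "tests/fixtures/tagging/crf_tagger/experiment.json",  # hypothetical config path
    "/tmp/weighted_crf_tagger",                           # serialization directory
    overrides={
        "model.label_weights": {"I-ORG": 10.0},          # omitted labels keep weight 1.0
        "model.weight_strategy": "emission_transition",  # or "emission" / "lannoy"
    },
    force=True,
    return_model=True,
)
```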
diff --git a/tests/tagging/models/crf_tagger_label_weights_test.py b/tests/tagging/models/crf_tagger_label_weights_test.py
new file mode 100644
index 000000000..442530890
--- /dev/null
+++ b/tests/tagging/models/crf_tagger_label_weights_test.py
@@ -0,0 +1,128 @@
+from flaky import flaky
+import pytest
+
+from allennlp.commands.train import train_model_from_file
+from allennlp.common.testing import ModelTestCase
+from allennlp.common.checks import ConfigurationError
+
+from tests import FIXTURES_ROOT
+
+
+class CrfTaggerLabelWeightsTest(ModelTestCase):
+    def setup_method(self):
+        super().setup_method()
+        self.set_up_model(
+            FIXTURES_ROOT / "tagging" / "crf_tagger" / "experiment.json",
+            FIXTURES_ROOT / "tagging" / "conll2003.txt",
+        )
+
+    def test_label_weights_effectiveness(self):
+        training_tensors = self.dataset.as_tensor_dict()
+        save_dir = self.TEST_DIR / "save_and_load_test"
+
+        # original CRF
+        output_dict_original = self.model(**training_tensors)
+
+        # weighted CRF
+        model_weighted = train_model_from_file(
+            self.param_file,
+            save_dir,
+            overrides={"model.label_weights": {"I-ORG": 10.0}},
+            force=True,
+            return_model=True,
+        )
+        output_dict_weighted = model_weighted(**training_tensors)
+
+        # assert that logits are substantially different
+        assert (
+            output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum()
+            < output_dict_original["logits"].numel() / 2
+        )
+
+    def test_label_weights_effectiveness_emission_transition(self):
+        training_tensors = self.dataset.as_tensor_dict()
+        save_dir = self.TEST_DIR / "save_and_load_test"
+
+        # original CRF
+        output_dict_original = self.model(**training_tensors)
+
+        # weighted CRF
+        model_weighted = train_model_from_file(
+            self.param_file,
+            save_dir,
+            overrides={
+                "model.label_weights": {"I-ORG": 10.0},
+                "model.weight_strategy": "emission_transition",
+            },
+            force=True,
+            return_model=True,
+        )
+        output_dict_weighted = model_weighted(**training_tensors)
+
+        # assert that logits are substantially different
+        assert (
+            output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum()
+            < output_dict_original["logits"].numel() / 2
+        )
+
+    def test_label_weights_effectiveness_lannoy(self):
+        training_tensors = self.dataset.as_tensor_dict()
+        save_dir = self.TEST_DIR / "save_and_load_test"
+
+        # original CRF
+        output_dict_original = self.model(**training_tensors)
+
+        # weighted CRF
+        model_weighted = train_model_from_file(
+            self.param_file,
+            save_dir,
+            overrides={
+                "model.label_weights": {"I-ORG": 10.0},
+                "model.weight_strategy": "lannoy",
+            },
+            force=True,
+            return_model=True,
+        )
+        output_dict_weighted = model_weighted(**training_tensors)
+
+        # assert that logits are substantially different
+        assert (
+            output_dict_weighted["logits"].isclose(output_dict_original["logits"]).sum()
+            < output_dict_original["logits"].numel() / 2
+        )
+
+    def test_config_error_invalid_label(self):
+        save_dir = self.TEST_DIR / "save_and_load_test"
+        with pytest.raises(ConfigurationError):
+            train_model_from_file(
+                self.param_file,
+                save_dir,
+                overrides={"model.label_weights": {"BLA": 10.0}},
+                force=True,
+                return_model=True,
+            )
+
+    def test_config_error_strategy_without_weights(self):
+        save_dir = self.TEST_DIR / "save_and_load_test"
+        with pytest.raises(ConfigurationError):
+            train_model_from_file(
+                self.param_file,
+                save_dir,
+                overrides={"model.weight_strategy": "emission"},
+                force=True,
+                return_model=True,
+            )
+
+    def test_config_error_invalid_strategy(self):
+        save_dir = self.TEST_DIR / "save_and_load_test"
+        with pytest.raises(ConfigurationError):
+            train_model_from_file(
+                self.param_file,
+                save_dir,
+                overrides={
+                    "model.label_weights": {"I-ORG": 10.0},
+                    "model.weight_strategy": "invalid",
+                },
+                force=True,
+                return_model=True,
+            )
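For completeness, a hedged sketch of calling one of the weighted CRF modules directly, assuming they keep the positional argument order used in `CrfTagger.__init__` above and the usual `ConditionalRandomField` forward interface; shapes and values are made up:

```python
import torch
from allennlp.modules.conditional_random_field import ConditionalRandomFieldWeightEmission

num_tags = 5
label_weights = [1.0, 1.0, 10.0, 1.0, 1.0]  # upweight the tag with index 2
# arguments: num_tags, label_weights, constraints, include_start_end_transitions
crf = ConditionalRandomFieldWeightEmission(num_tags, label_weights, None, True)

logits = torch.randn(2, 7, num_tags)        # (batch_size, seq_len, num_tags)
tags = torch.randint(0, num_tags, (2, 7))
mask = torch.ones(2, 7, dtype=torch.bool)

log_likelihood = crf(logits, tags, mask)    # weighted log-likelihood of the gold tags
```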
diff --git a/tests/tagging/models/crf_tagger_test.py b/tests/tagging/models/crf_tagger_test.py
index dd3e85272..56c25e3e9 100644
--- a/tests/tagging/models/crf_tagger_test.py
+++ b/tests/tagging/models/crf_tagger_test.py
@@ -1,6 +1,7 @@
 from flaky import flaky
 import pytest
 
+from allennlp.commands.train import train_model_from_file
 from allennlp.common.testing import ModelTestCase
 from allennlp.common.checks import ConfigurationError
 from allennlp.common.params import Params
@@ -77,3 +78,25 @@ def test_mismatching_dimensions_throws_configuration_error(self):
         params["model"]["encoder"]["input_size"] = 10
         with pytest.raises(ConfigurationError):
             Model.from_params(vocab=self.vocab, params=params.pop("model"))
+
+    def test_token_based_verbose_metrics(self):
+        training_tensors = self.dataset.as_tensor_dict()
+        save_dir = self.TEST_DIR / "save_and_load_test"
+
+        model = train_model_from_file(
+            self.param_file,
+            save_dir,
+            overrides={
+                "model.calculate_span_f1": False,
+                "model.verbose_metrics": True,
+            },
+            force=True,
+            return_model=True,
+        )
+        model(**training_tensors)
+        metrics = model.get_metrics()
+
+        # assert that metrics contain all verbose keys
+        for tag in ["O", "I-PER", "I-ORG", "I-LOC", "micro", "macro", "weighted"]:
+            for m in ["precision", "recall", "fscore"]:
+                assert f"{tag}-{m}" in metrics
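Finally, an illustrative sketch (assuming `FBetaVerboseMeasure` follows the standard `Metric` call signature of predictions, gold labels, mask) of the token-based verbose metric on its own; the key naming matches what `test_token_based_verbose_metrics` asserts:

```python
import torch
from allennlp.training.metrics import FBetaVerboseMeasure

index_to_label = {0: "O", 1: "I-PER", 2: "I-ORG", 3: "I-LOC"}
measure = FBetaVerboseMeasure(index_to_label=index_to_label)

predictions = torch.softmax(torch.randn(2, 7, len(index_to_label)), dim=-1)
gold = torch.randint(0, len(index_to_label), (2, 7))
mask = torch.ones(2, 7, dtype=torch.bool)

measure(predictions, gold, mask)
metrics = measure.get_metric(reset=True)
# per-label and aggregate keys, e.g. "I-ORG-precision", "I-ORG-recall",
# "I-ORG-fscore", plus "micro-...", "macro-..." and "weighted-..." variants
```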