This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Implementation of Weighted CRF Tagger (handling unbalanced datasets) #5676

Merged
merged 12 commits on Jul 14, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added metric `FBetaVerboseMeasure` which extends `FBetaMeasure` to ensure compatibility with logging plugins and add some options.
- Added three sample weighting techniques to `ConditionalRandomField` by supplying three new subclasses: `ConditionalRandomFieldWeightEmission`, `ConditionalRandomFieldWeightTrans`, and `ConditionalRandomFieldWeightLannoy`.

### Fixed

12 changes: 12 additions & 0 deletions allennlp/modules/conditional_random_field/__init__.py
@@ -0,0 +1,12 @@
from allennlp.modules.conditional_random_field.conditional_random_field import (
    ConditionalRandomField,
)
from allennlp.modules.conditional_random_field.conditional_random_field_wemission import (
    ConditionalRandomFieldWeightEmission,
)
from allennlp.modules.conditional_random_field.conditional_random_field_wtrans import (
    ConditionalRandomFieldWeightTrans,
)
from allennlp.modules.conditional_random_field.conditional_random_field_wlannoy import (
    ConditionalRandomFieldWeightLannoy,
)
allennlp/modules/conditional_random_field/conditional_random_field.py
@@ -214,10 +214,21 @@ def reset_parameters(self):
            torch.nn.init.normal_(self.start_transitions)
            torch.nn.init.normal_(self.end_transitions)

-    def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
-        """
-        Computes the (batch_size,) denominator term for the log-likelihood, which is the
-        sum of the likelihoods across all possible state sequences.
+    def _input_likelihood(
+        self, logits: torch.Tensor, transitions: torch.Tensor, mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        """Computes the (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood
+
+        This is the sum of the likelihoods across all possible state sequences.
+
+        Args:
+            logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of
+                unnormalized log-probabilities
+            transitions (torch.Tensor): a (batch_size, num_tags, num_tags) tensor of transition scores
+            mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags
+
+        Returns:
+            torch.Tensor: (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood
        """
        batch_size, sequence_length, num_tags = logits.size()

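For readers skimming the diff: this denominator is the standard linear-chain partition function, computed by the forward algorithm. In shorthand not taken from the code (write $e_t(k)$ for the emission score of tag $k$ at step $t$ and $T_{j,k}$ for the transition score, ignoring the optional start/end transitions), $Z(x) = \sum_{y'} \exp\big(\sum_t e_t(y'_t) + \sum_t T_{y'_t, y'_{t+1}}\big)$, accumulated via the recursion $\alpha_t(k) = \log\sum_j \exp\big(\alpha_{t-1}(j) + T_{j,k} + e_t(k)\big)$ that the loop in the next hunk implements.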
@@ -239,7 +250,7 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> tor
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
-            transition_scores = self.transitions.view(1, num_tags, num_tags)
+            transition_scores = transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

@@ -262,10 +273,23 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> tor
        return util.logsumexp(stops)

    def _joint_likelihood(
-        self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
+        self,
+        logits: torch.Tensor,
+        transitions: torch.Tensor,
+        tags: torch.Tensor,
+        mask: torch.BoolTensor,
    ) -> torch.Tensor:
-        """
-        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
+        """Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
+
+        Args:
+            logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of unnormalized
+                log-probabilities
+            transitions (torch.Tensor): a (batch_size, num_tags, num_tags) tensor of transition scores
+            tags (torch.Tensor): output tag sequences (batch_size, sequence_length) $y$ for each input sequence
+            mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags
+
+        Returns:
+            torch.Tensor: numerator term for the log-likelihood, which is just score(inputs, tags)
        """
        batch_size, sequence_length, _ = logits.data.shape

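In the same shorthand, the numerator is just the score of the gold path, $\mathrm{score}(x, y) = \sum_t e_t(y_t) + \sum_t T_{y_t, y_{t+1}}$ (plus the optional start and end transition scores), so the log-likelihood returned by `forward` below is $\mathrm{score}(x, y) - \log Z(x)$.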
@@ -286,7 +310,7 @@ def _joint_likelihood(
            current_tag, next_tag = tags[i], tags[i + 1]

            # The scores for transitioning from current_tag to next_tag
-            transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]
+            transition_score = transitions[current_tag.view(-1), next_tag.view(-1)]

            # The score for using current_tag
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)
@@ -318,18 +342,25 @@ def _joint_likelihood(
    def forward(
        self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
    ) -> torch.Tensor:
-        """
-        Computes the log likelihood.
-        """
+        """Computes the log likelihood for the given batch of input sequences $(x,y)$
+
+        Args:
+            inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$
+            tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$
+            mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input
+        """
        if mask is None:
            mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
        else:
            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
            mask = mask.to(torch.bool)

-        log_denominator = self._input_likelihood(inputs, mask)
-        log_numerator = self._joint_likelihood(inputs, tags, mask)
+        log_denominator = self._input_likelihood(inputs, self.transitions, mask)
+        log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask)

        return torch.sum(log_numerator - log_denominator)

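For orientation, a minimal usage sketch of the refactored base class (not part of the patch; it assumes the base constructor takes `num_tags` as its first argument, as the `super().__init__` call in the new subclass below suggests, and the toy shapes are made up):

    import torch
    from allennlp.modules.conditional_random_field import ConditionalRandomField

    num_tags, batch_size, seq_len = 5, 2, 7
    crf = ConditionalRandomField(num_tags)

    logits = torch.randn(batch_size, seq_len, num_tags)    # emission scores for the inputs x
    tags = torch.randint(num_tags, (batch_size, seq_len))  # gold tag sequences y
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    # forward() now threads self.transitions through the two helpers and
    # returns the log likelihood summed over the batch.
    log_likelihood = crf(logits, tags, mask)
    loss = -log_likelihood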
allennlp/modules/conditional_random_field/conditional_random_field_wemission.py
@@ -0,0 +1,91 @@
"""
Conditional random field with emission-based weighting
"""
from typing import List, Tuple

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.modules.conditional_random_field.conditional_random_field import (
ConditionalRandomField,
)


class ConditionalRandomFieldWeightEmission(ConditionalRandomField):
"""
This module uses the "forward-backward" algorithm to compute
the log-likelihood of its inputs assuming a conditional random field model.

See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf

This is a weighted version of `ConditionalRandomField` which accepts a `label_weights`
parameter to be used in the loss function in order to give different weights for each
token depending on its label. The method implemented here is based on the simple idea
of weighting emission scores using the weight given for the corresponding tag.

There are two other sample weighting methods implemented. You can find more details
about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html

# Parameters

num_tags : `int`, required
The number of tags.
label_weights : `List[float]`, required
A list of weights to be used in the loss function in order to
give different weights for each token depending on its label.
`len(label_weights)` must be equal to `num_tags`. This is useful to
deal with highly unbalanced datasets. The method implemented here is
based on the simple idea of weighting emission scores using the weight
given for the corresponding tag.
constraints : `List[Tuple[int, int]]`, optional (default = `None`)
An optional list of allowed transitions (from_tag_id, to_tag_id).
These are applied to `viterbi_tags()` but do not affect `forward()`.
These should be derived from `allowed_transitions` so that the
start and end transitions are handled correctly for your tag type.
include_start_end_transitions : `bool`, optional (default = `True`)
Whether to include the start and end transition parameters.
"""

def __init__(
self,
num_tags: int,
label_weights: List[float],
constraints: List[Tuple[int, int]] = None,
include_start_end_transitions: bool = True,
) -> None:
super().__init__(num_tags, constraints, include_start_end_transitions)

if label_weights is None:
raise ConfigurationError("label_weights must be given")

self.register_buffer("label_weights", torch.Tensor(label_weights))

def forward(
self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
) -> torch.Tensor:
"""Computes the log likelihood for the given batch of input sequences $(x,y)$

Args:
inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$
tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$
mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags.
Defaults to None.

Returns:
torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input
"""
if mask is None:
mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
else:
# The code below fails in weird ways if this isn't a bool tensor, so we make sure.
mask = mask.to(torch.bool)

label_weights = self.label_weights

# scale the logits for all examples and all time steps
inputs = inputs * label_weights.view(1, 1, -1)

log_denominator = self._input_likelihood(inputs, self.transitions, mask)
log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask)

return torch.sum(log_numerator - log_denominator)
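A corresponding sketch for the new emission-weighted variant (again illustrative: the weight values and shapes are made up, while the constructor and `forward` signatures are the ones shown above). The per-tag weights simply rescale the logits before the usual numerator and denominator are computed:

    import torch
    from allennlp.modules.conditional_random_field import ConditionalRandomFieldWeightEmission

    num_tags, batch_size, seq_len = 3, 2, 6
    # e.g. down-weight a dominant "O"-like tag (index 0) relative to the rare tags
    label_weights = [0.1, 1.0, 1.0]
    crf = ConditionalRandomFieldWeightEmission(num_tags, label_weights)

    logits = torch.randn(batch_size, seq_len, num_tags)
    tags = torch.randint(num_tags, (batch_size, seq_len))
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    # negative weighted log likelihood, summed over the batch
    loss = -crf(logits, tags, mask)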