Optimize RelationClassifier by filtering long sentences
alanakbik committed Jan 2, 2025
1 parent 8bc9c28 commit fc786b3
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions flair/models/relation_classifier_model.py
@@ -256,6 +256,8 @@ def __init__(
        encoding_strategy: EncodingStrategy = TypedEntityMarker(),
        zero_tag_value: str = "O",
        allow_unk_tag: bool = True,
        max_allowed_tokens_between_entities: int = 50,
        max_encoded_sentence_length: int = 100,
        **classifierargs,
    ) -> None:
"""Initializes a `RelationClassifier`.
@@ -271,13 +273,27 @@ def __init__(
            encoding_strategy: An instance of a class conforming to the :class:`EncodingStrategy` protocol
            zero_tag_value: The label to use for out-of-class relations
            allow_unk_tag: If `False`, removes `<unk>` from the passed label dictionary, otherwise do nothing.
            max_allowed_tokens_between_entities: The maximum allowed number of tokens between entities; entity pairs that are farther apart are filtered from consideration.
            max_encoded_sentence_length: The maximum token length of encoded sentences. Smaller values speed up processing but may remove important context.
            classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier`
        """
        # Set label type and prepare label dictionary
        self._label_type = label_type
        self._zero_tag_value = zero_tag_value
        self._allow_unk_tag = allow_unk_tag

        if max_encoded_sentence_length - 2 < max_allowed_tokens_between_entities:
            logger.warning(
                "'max_encoded_sentence_length' is too small relative to 'max_allowed_tokens_between_entities'. "
                "To ensure that each encoded sentence can at least contain both entities of a relation, "
                "'max_encoded_sentence_length' should be at least 2 tokens larger than "
                "'max_allowed_tokens_between_entities'. Automatically setting 'max_encoded_sentence_length' to "
                "'max_allowed_tokens_between_entities' + 2."
            )
            max_encoded_sentence_length = max_allowed_tokens_between_entities + 2

        self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
        self._max_encoded_sentence_length = max_encoded_sentence_length

        modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag)
        modified_label_dictionary.add_item(self._zero_tag_value)
        for label in label_dictionary.get_items():
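
For intuition about the guard above, here is a minimal standalone sketch with assumed example values (plain Python, independent of flair):

# Minimal sketch of the sanity check above, assuming example values.
max_allowed_tokens_between_entities = 50
max_encoded_sentence_length = 40  # too small: cannot hold both entity markers plus the allowed gap

if max_encoded_sentence_length - 2 < max_allowed_tokens_between_entities:
    # Widen the window so one head marker, one tail marker, and up to
    # max_allowed_tokens_between_entities tokens between them always fit.
    max_encoded_sentence_length = max_allowed_tokens_between_entities + 2

print(max_encoded_sentence_length)  # 52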
@@ -398,7 +414,7 @@ def _encode_sentence(
        head: _Entity,
        tail: _Entity,
        gold_label: Optional[str] = None,
    ) -> EncodedSentence:
    ) -> Optional[EncodedSentence]:
"""Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.
If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
@@ -422,11 +438,15 @@
        # since there may be multiple occurrences of the same entity mentioned in the sentence.
        # Therefore, we use the span's position in the sentence.
        encoded_sentence_tokens: list[str] = []
        head_idx = None
        tail_idx = None
        for token in original_sentence:
            if token is head.span[0]:
                head_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

            elif token is tail.span[0]:
                tail_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

            elif all(
@@ -435,6 +455,15 @@
            ):
                encoded_sentence_tokens.append(token.text)

        # filter cases in which the distance between the two entities is too large
        if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities:
            return None

        # remove excess tokens left and right of entity pair to make encoded sentence shorter
        encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length(
            encoded_sentence_tokens, head_idx, tail_idx
        )

        # Create masked sentence
        encoded_sentence: EncodedSentence = EncodedSentence(
            " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -448,6 +477,23 @@
        encoded_sentence.copy_context_from_sentence(original_sentence)
        return encoded_sentence
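
To make the masking step concrete, a toy illustration follows; the marker strings are hypothetical stand-ins, since the actual surface form depends on the configured encoding strategy:

# Hypothetical illustration only: real markers depend on the EncodingStrategy.
original_tokens = ["Larry", "Page", "founded", "Google", "."]
# head = "Larry Page" (PER), tail = "Google" (ORG); each entity span is
# collapsed into a single encoded token, so indices refer to encoded positions.
encoded_tokens = ["[HEAD-PER]", "founded", "[TAIL-ORG]", "."]
head_idx, tail_idx = 0, 2
# abs(head_idx - tail_idx) == 2 is far below the default limit of 50,
# so this pair is encoded rather than filtered to None.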

    def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx):
        if len(encoded_sentence_tokens) > self._max_encoded_sentence_length:
            begin_slice = head_idx if head_idx < tail_idx else tail_idx
            end_slice = tail_idx if head_idx < tail_idx else head_idx
            distance = end_slice - begin_slice
            padding_amount = self._max_encoded_sentence_length - distance
            padding_per_side = padding_amount // 2
            begin_slice = begin_slice - padding_per_side if begin_slice - padding_per_side > 0 else 0
            end_slice = (
                end_slice + padding_per_side
                if end_slice + padding_per_side < len(encoded_sentence_tokens)
                else len(encoded_sentence_tokens)
            )

            encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice]
        return encoded_sentence_tokens

    def _encode_sentence_for_inference(
        self,
        sentence: Sentence,
@@ -520,6 +566,7 @@ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list
            encoded_sentence
            for sentence in sentences
            for encoded_sentence in self._encode_sentence_for_training(sentence)
            if encoded_sentence is not None
        ]
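
Since encoding can now return None for filtered entity pairs, the comprehension above silently drops those; a minimal standalone sketch of the pattern, with a stand-in encoder:

# Stand-in encoder: yields an encoding per entity pair, or None when filtered.
def encode_for_training(sentence):
    yield f"<encoded {sentence}>"
    yield None  # e.g. a pair whose entities are too far apart

sentences = ["sentence one", "sentence two"]
encoded = [enc for s in sentences for enc in encode_for_training(s) if enc is not None]
print(encoded)  # ['<encoded sentence one>', '<encoded sentence two>']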

    def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset[EncodedSentence]:
@@ -643,7 +690,9 @@ def predict(
        # Deal with the case where all sentences are standard (non-encoded) sentences
        Sentence.set_context_for_sentences(cast(list[Sentence], sentences))
        sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list(
            itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences)
            itertools.chain.from_iterable(
                self._encode_sentence_for_inference(sentence) for sentence in sentences if sentence is not None
            )
        )

        encoded_sentences = [x[0] for x in sentences_with_relation_reference]
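
The chaining above flattens the per-sentence lists of (encoded sentence, relation) pairs before unzipping them; a toy sketch with placeholder values:

# Toy sketch: flatten per-sentence (encoded, relation) pairs, then unzip.
import itertools

per_sentence = [[("enc1", "rel1"), ("enc2", "rel2")], [("enc3", "rel3")]]
flat = list(itertools.chain.from_iterable(per_sentence))
encoded_sentences = [x[0] for x in flat]  # ['enc1', 'enc2', 'enc3']
relations = [x[1] for x in flat]          # ['rel1', 'rel2', 'rel3']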
@@ -667,32 +716,6 @@

        return loss if return_loss else None

    def _print_predictions(self, batch, gold_label_type: str) -> list[str]:
        lines = []
        for datapoint in batch:
            # check if there is a label mismatch
            g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)]
            p = [label.labeled_identifier for label in datapoint.get_labels("predicted")]
            g.sort()
            p.sort()

            # if the gold label is O and is correctly predicted as no label, do not print out as this clutters
            # the output file with trivial predictions
            if not (
                len(datapoint.get_labels(gold_label_type)) == 1
                and datapoint.get_label(gold_label_type).value == "O"
                and len(datapoint.get_labels("predicted")) == 0
            ):
                correct_string = " -> MISMATCH!\n" if g != p else ""
                eval_line = (
                    f"{datapoint.text}\n"
                    f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n"
                    f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n"
                    f"{correct_string}\n"
                )
                lines.append(eval_line)
        return lines

    def _get_state_dict(self) -> dict[str, Any]:
        model_state: dict[str, Any] = {
            **super()._get_state_dict(),
