
Fix error when cls_pooling="mean" or cls_pooling="max" for TransformerDocumentEmbeddings #3558

30 changes: 14 additions & 16 deletions flair/embeddings/transformer.py
@@ -33,12 +33,7 @@
 
 import flair
 from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -198,24 +193,33 @@ def fill_mean_token_embeddings(
 
 
 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
+    return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
+
+@torch.jit.script_if_tracing
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
 
+    return result
+
 
 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
 
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
 
+    return result
+
 
 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
@@ -1127,11 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import (
-                    TaskType,
-                    get_peft_model,
-                    prepare_model_for_kbit_training,
-                )
+                from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
@@ -1446,9 +1446,7 @@ def forward(
             else:
                 assert sub_token_lengths is not None
                 if self.cls_pooling == "cls":
-                    document_embeddings = sentence_hidden_states[
-                        torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
-                    ]
+                    document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "mean":
                     document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "max":
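Taken together, the helpers give TransformerDocumentEmbeddings one pooled vector per sentence: document_cls_pooling picks the hidden state at index length − 1, while document_mean_pooling and document_max_pooling reduce over the first length positions of each sentence and, with this change, actually return the pooled tensor. Below is a standalone sketch of the same reductions on an invented toy batch (illustrative only, not flair code):

```python
import torch

# Toy batch: 2 sentences, padded to 4 sub-tokens each, hidden size 3 (values are arbitrary).
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
lengths = torch.tensor([4, 2])  # sentence 2 has only 2 real sub-tokens; positions 2-3 are padding

# "cls" pooling as in document_cls_pooling: take the hidden state at index length - 1.
cls_pooled = hidden[torch.arange(hidden.shape[0]), lengths - 1]

# "mean"/"max" pooling as in document_mean_pooling / document_max_pooling:
# reduce only over the real (unpadded) positions of each sentence and return the result.
mean_pooled = torch.zeros(hidden.shape[0], hidden.shape[2], dtype=hidden.dtype)
max_pooled = torch.zeros(hidden.shape[0], hidden.shape[2], dtype=hidden.dtype)
for i in torch.arange(hidden.shape[0]):
    mean_pooled[i] = hidden[i, : lengths[i]].mean(dim=0)
    max_pooled[i], _ = hidden[i, : lengths[i]].max(dim=0)

print(cls_pooled.shape, mean_pooled.shape, max_pooled.shape)  # each: torch.Size([2, 3])
# Padding positions of sentence 2 never enter its mean:
assert torch.allclose(mean_pooled[1], hidden[1, :2].mean(dim=0))
```

The per-sentence loop mirrors the patched helpers; a masked vectorized reduction would also work, but the loop keeps the handling of variable sentence lengths explicit.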
28 changes: 22 additions & 6 deletions tests/embedding_test_utils.py
@@ -19,11 +19,11 @@ class BaseEmbeddingsTest:
     name_field: Optional[str] = None
     weired_texts: list[str] = [
         "Hybrid mesons , qq ̄ states with an admixture",
-        "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .",
+        "typical proportionalities of \u223c 1nmV \u2212 1 [ 3,4 ] .",
         "🤟 🤟 🤟 hüllo",
         "🤟hallo 🤟 🤟 🤟 🤟",
         "🤟",
-        "\uF8F9",
+        "\uf8f9",
     ]
 
     def create_embedding_from_name(self, name: str):
@@ -150,9 +150,17 @@ def test_embeddings_stay_the_same_after_saving_and_loading(self, args):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
+                assert torch.allclose(
+                    token_old.get_embedding(names_old),
+                    token_new.get_embedding(names_new),
+                    atol=1e-6,
+                )
         if self.is_document_embedding:
-            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
+            assert torch.allclose(
+                sentence_old.get_embedding(names_old),
+                sentence_new.get_embedding(names_new),
+                atol=1e-6,
+            )
 
     def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
         embeddings = self.create_embedding_with_args(self.default_args)
@@ -176,9 +184,17 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
+                assert torch.allclose(
+                    token_old.get_embedding(names_old),
+                    token_new.get_embedding(names_new),
+                    atol=1e-6,
+                )
         if self.is_document_embedding:
-            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
+            assert torch.allclose(
+                sentence_old.get_embedding(names_old),
+                sentence_new.get_embedding(names_new),
+                atol=1e-6,
+            )
 
     def test_embeddings_load_in_eval_mode(self):
         embeddings = self.create_embedding_with_args(self.default_args)
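The tests now compare embeddings before and after a save/load round trip with torch.allclose and a small absolute tolerance instead of exact equality, since the round trip need not be bit-identical. A minimal illustration of the difference (values are invented; the 1e-7 offset just stands in for harmless floating-point drift):

```python
import torch

a = torch.tensor([0.1, 0.2, 0.3])
# Simulate a tiny numerical discrepancy, e.g. from serialization or a dtype/device round trip.
b = a + 1e-7

print((a == b).all())                    # tensor(False) -- exact equality is too strict
print(torch.allclose(a, b, atol=1e-6))   # True -- differences below the tolerance are accepted
```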
17 changes: 16 additions & 1 deletion tests/embeddings/test_transformer_document_embeddings.py
@@ -1,4 +1,6 @@
-from flair.data import Dictionary
+import pytest
+
+from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
@@ -37,3 +39,16 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    embeddings = TransformerDocumentEmbeddings(
+        model="distilbert-base-uncased",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None
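For end users, the code path exercised by this test corresponds to ordinary document-embedding usage like the sketch below (a usage sketch based on the test above, not part of the PR; the model name and pooling choice are illustrative):

```python
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

# Document embedding that averages all sub-token states instead of using the CLS position.
embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", cls_pooling="mean")

sentence = Sentence("Today is a good day.")
embeddings.embed(sentence)

print(sentence.embedding.shape)  # one vector per sentence, e.g. torch.Size([768]) for distilbert
```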