diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index c43258a19..a27b57ab9 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -33,12 +33,7 @@
 
 import flair
 from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -198,7 +193,12 @@ def fill_mean_token_embeddings(
 
 
 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
+    return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
+
+@torch.jit.script_if_tracing
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
@@ -206,9 +206,11 @@ def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
 
+    return result
+
 
 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
@@ -216,6 +218,8 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
 
+    return result
+
 
 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
@@ -1127,11 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
     if peft_config is not None:
         # add adapters for finetuning
        try:
-            from peft import (
-                TaskType,
-                get_peft_model,
-                prepare_model_for_kbit_training,
-            )
+            from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
        except ImportError:
            log.error("You cannot use the PEFT finetuning without peft being installed")
            raise
@@ -1446,9 +1446,7 @@ def forward(
        else:
            assert sub_token_lengths is not None
            if self.cls_pooling == "cls":
-                document_embeddings = sentence_hidden_states[
-                    torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
-                ]
+                document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
            elif self.cls_pooling == "mean":
                document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
            elif self.cls_pooling == "max":
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index 1a65a96fd..f0f6389b7 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -1,4 +1,6 @@
-from flair.data import Dictionary
+import pytest
+
+from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
@@ -37,3 +39,16 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    embeddings = TransformerDocumentEmbeddings(
+        model="distilbert-base-uncased",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None
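
For reviewers who want to try the change locally, the following is a minimal usage sketch (not part of the diff) that exercises the three cls_pooling strategies now backed by document_cls_pooling, document_mean_pooling, and document_max_pooling. The model name, layers, and sentence text are taken from the new test_cls_pooling test; everything else is standard flair API.

# Minimal verification sketch -- mirrors the new test_cls_pooling test, not part of the diff.
# Requires flair installed and a one-time download of distilbert-base-uncased.
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

for cls_pooling in ["cls", "mean", "max"]:
    embeddings = TransformerDocumentEmbeddings(
        model="distilbert-base-uncased",
        layers="-1",
        cls_pooling=cls_pooling,
        allow_long_sentences=True,
    )
    sentence = Sentence("Today is a good day.")
    embeddings.embed(sentence)
    # each pooling strategy yields one fixed-size vector per sentence
    print(cls_pooling, tuple(sentence.embedding.shape))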