From caea8bb58467f9b84b4025d21906c8af9b1fdb64 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Sat, 19 Oct 2024 12:11:28 +0200
Subject: [PATCH 1/6] Initial commit.

---
 flair/embeddings/transformer.py              | 27 ++++++++-----------
 .../test_transformer_document_embeddings.py  | 19 +++++++++++++
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 1e88787deb..3205aeabc6 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -8,18 +8,18 @@
 from abc import abstractmethod
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Type, Union
 
 import torch
 import transformers
 from packaging.version import Version
 from torch.jit import ScriptModule
 from transformers import (
-    CONFIG_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
     AutoModel,
     AutoTokenizer,
+    CONFIG_MAPPING,
     FeatureExtractionMixin,
     LayoutLMTokenizer,
     LayoutLMTokenizerFast,
@@ -32,13 +32,8 @@
 from transformers.utils import PaddingStrategy
 
 import flair
-from flair.data import Sentence, Token, log
-from flair.embeddings.base import (
-    DocumentEmbeddings,
-    Embeddings,
-    TokenEmbeddings,
-    register_embeddings,
-)
+from flair.data import log, Sentence, Token
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, register_embeddings, TokenEmbeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -198,7 +193,7 @@ def fill_mean_token_embeddings(
 
 
 @torch.jit.script_if_tracing
-def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
@@ -206,9 +201,11 @@ def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i] = sentence_hidden_states[i, : sentence_lengths[i]].mean(dim=0)
 
+    return result
+
 
 @torch.jit.script_if_tracing
-def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor):
+def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     result = torch.zeros(
         sentence_hidden_states.shape[0], sentence_hidden_states.shape[2], dtype=sentence_hidden_states.dtype
     )
@@ -216,6 +213,8 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ = sentence_hidden_states[i, : sentence_lengths[i]].max(dim=0)
 
+    return result
+
 
 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: List[List[str]]
@@ -1127,11 +1126,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import (
-                    TaskType,
-                    get_peft_model,
-                    prepare_model_for_kbit_training,
-                )
+                from peft import get_peft_model, prepare_model_for_kbit_training, TaskType
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index 1a65a96fdb..f253d2d8d8 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -1,7 +1,10 @@
+import pytest
+
 from flair.data import Dictionary
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
+
 from tests.embedding_test_utils import BaseEmbeddingsTest
 
 
@@ -37,3 +40,19 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
     # check that context_length and use_context_separator is the same for both
     assert model.embeddings.context_length == loaded_single_task.embeddings.context_length
     assert model.embeddings.use_context_separator == loaded_single_task.embeddings.use_context_separator
+
+
+@pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
+def test_cls_pooling(cls_pooling):
+    from flair.data import Sentence
+    from flair.embeddings import TransformerDocumentEmbeddings
+
+    embeddings = TransformerDocumentEmbeddings(
+        model="xlm-roberta-base",
+        layers="-1",
+        cls_pooling=cls_pooling,
+        allow_long_sentences=True,
+    )
+    sentence = Sentence("Today is a good day.")
+    embeddings.embed(sentence)
+    assert sentence.embedding is not None

From 43fc96a5370fa0260c62018d0bd4f87579a2e987 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Sat, 19 Oct 2024 12:16:52 +0200
Subject: [PATCH 2/6] Remove imports from the test function.

---
 tests/embeddings/test_transformer_document_embeddings.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index f253d2d8d8..bf202110b1 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -1,6 +1,6 @@
 import pytest
 
-from flair.data import Dictionary
+from flair.data import Dictionary, Sentence
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
@@ -44,9 +44,6 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
 
 @pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
 def test_cls_pooling(cls_pooling):
-    from flair.data import Sentence
-    from flair.embeddings import TransformerDocumentEmbeddings
-
     embeddings = TransformerDocumentEmbeddings(
         model="xlm-roberta-base",
         layers="-1",

From fdb49952c61abe3f671b0a39d1eb2685a8aff9f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Sat, 19 Oct 2024 13:12:25 +0200
Subject: [PATCH 3/6] Refactor cls pooling into a function.

---
 flair/embeddings/transformer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 3205aeabc6..8a1e76cd72 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -191,6 +191,10 @@ def fill_mean_token_embeddings(
 
     return all_token_embeddings
 
+@torch.jit.script_if_tracing
+def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
+    return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
+
 
 @torch.jit.script_if_tracing
 def document_mean_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
@@ -1436,9 +1440,7 @@ def forward(
             else:
                 assert sub_token_lengths is not None
                 if self.cls_pooling == "cls":
-                    document_embeddings = sentence_hidden_states[
-                        torch.arange(sentence_hidden_states.shape[0]), sub_token_lengths - 1
-                    ]
+                    document_embeddings = document_cls_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "mean":
                     document_embeddings = document_mean_pooling(sentence_hidden_states, sub_token_lengths)
                 elif self.cls_pooling == "max":

From d932baf83cad1081d8eb8f2ee9715d947a6c466a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Fri, 25 Oct 2024 23:05:48 +0200
Subject: [PATCH 4/6] Fix formatting errors.

---
 flair/embeddings/transformer.py              | 11 ++++++-----
 .../test_transformer_document_embeddings.py  |  1 -
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index 8a1e76cd72..8ba17b1fec 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -8,18 +8,18 @@
 from abc import abstractmethod
 from io import BytesIO
 from pathlib import Path
-from typing import Any, cast, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
 
 import torch
 import transformers
 from packaging.version import Version
 from torch.jit import ScriptModule
 from transformers import (
+    CONFIG_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
     AutoModel,
     AutoTokenizer,
-    CONFIG_MAPPING,
     FeatureExtractionMixin,
     LayoutLMTokenizer,
     LayoutLMTokenizerFast,
@@ -32,8 +32,8 @@
 from transformers.utils import PaddingStrategy
 
 import flair
-from flair.data import log, Sentence, Token
-from flair.embeddings.base import DocumentEmbeddings, Embeddings, register_embeddings, TokenEmbeddings
+from flair.data import Sentence, Token, log
+from flair.embeddings.base import DocumentEmbeddings, Embeddings, TokenEmbeddings, register_embeddings
 
 SENTENCE_BOUNDARY_TAG: str = "[FLERT]"
 
@@ -191,6 +191,7 @@ def fill_mean_token_embeddings(
 
     return all_token_embeddings
 
+
 @torch.jit.script_if_tracing
 def document_cls_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths: torch.Tensor) -> torch.Tensor:
     return sentence_hidden_states[torch.arange(sentence_hidden_states.shape[0]), sentence_lengths - 1]
@@ -1130,7 +1131,7 @@ def is_supported_t5_model(config: PretrainedConfig) -> bool:
         if peft_config is not None:
             # add adapters for finetuning
             try:
-                from peft import get_peft_model, prepare_model_for_kbit_training, TaskType
+                from peft import TaskType, get_peft_model, prepare_model_for_kbit_training
             except ImportError:
                 log.error("You cannot use the PEFT finetuning without peft being installed")
                 raise
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index bf202110b1..7402b2b467 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -4,7 +4,6 @@
 from flair.embeddings import TransformerDocumentEmbeddings
 from flair.models import TextClassifier
 from flair.nn import Classifier
-
 from tests.embedding_test_utils import BaseEmbeddingsTest
 
 

From 3c3620061c4a3578a536bf9c25ea78c745e2f19b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Tue, 29 Oct 2024 21:01:30 +0100
Subject: [PATCH 5/6] Use torch.allclose for comparing tensors in BaseEmbeddingsTest.

---
 tests/embedding_test_utils.py                | 28 +++++++++++++++----
 .../test_transformer_document_embeddings.py  |  2 +-
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py
index 554ef32777..933dafad73 100644
--- a/tests/embedding_test_utils.py
+++ b/tests/embedding_test_utils.py
@@ -19,11 +19,11 @@ class BaseEmbeddingsTest:
     name_field: Optional[str] = None
     weired_texts: List[str] = [
         "Hybrid mesons , qq ̄ states with an admixture",
-        "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .",
+        "typical proportionalities of \u223c 1nmV \u2212 1 [ 3,4 ] .",
         "🤟 🤟 🤟 hüllo",
         "🤟hallo 🤟 🤟 🤟 🤟",
         "🤟",
-        "\uF8F9",
+        "\uf8f9",
     ]
 
     def create_embedding_from_name(self, name: str):
@@ -150,9 +150,17 @@ def test_embeddings_stay_the_same_after_saving_and_loading(self, args):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
+                assert torch.allclose(
+                    token_old.get_embedding(names_old),
+                    token_new.get_embedding(names_new),
+                    atol=1e-6,
+                )
         if self.is_document_embedding:
-            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
+            assert torch.allclose(
+                sentence_old.get_embedding(names_old),
+                sentence_new.get_embedding(names_new),
+                atol=1e-6,
+            )
 
     def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
         embeddings = self.create_embedding_with_args(self.default_args)
@@ -176,9 +184,17 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
+                assert torch.allclose(
+                    token_old.get_embedding(names_old),
+                    token_new.get_embedding(names_new),
+                    atol=1e-6,
+                )
         if self.is_document_embedding:
-            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
+            assert torch.allclose(
+                sentence_old.get_embedding(names_old),
+                sentence_new.get_embedding(names_new),
+                atol=1e-6,
+            )
 
     def test_embeddings_load_in_eval_mode(self):
         embeddings = self.create_embedding_with_args(self.default_args)
diff --git a/tests/embeddings/test_transformer_document_embeddings.py b/tests/embeddings/test_transformer_document_embeddings.py
index 7402b2b467..f0f6389b7d 100644
--- a/tests/embeddings/test_transformer_document_embeddings.py
+++ b/tests/embeddings/test_transformer_document_embeddings.py
@@ -44,7 +44,7 @@ def test_if_loaded_embeddings_have_all_attributes(tasks_base_path):
 @pytest.mark.parametrize("cls_pooling", ["cls", "mean", "max"])
 def test_cls_pooling(cls_pooling):
     embeddings = TransformerDocumentEmbeddings(
-        model="xlm-roberta-base",
+        model="distilbert-base-uncased",
         layers="-1",
         cls_pooling=cls_pooling,
         allow_long_sentences=True,

From 904052d76ff83c13bdf33f325d6612ee3225ed7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Filip=20Karlo=20Do=C5=A1ilovi=C4=87?=
Date: Fri, 29 Nov 2024 14:24:39 +0100
Subject: [PATCH 6/6] Address reviewer's comments.

---
 tests/embedding_test_utils.py | 28 ++++++----------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/tests/embedding_test_utils.py b/tests/embedding_test_utils.py
index 706a5cb587..c1a0b1a791 100644
--- a/tests/embedding_test_utils.py
+++ b/tests/embedding_test_utils.py
@@ -19,11 +19,11 @@ class BaseEmbeddingsTest:
     name_field: Optional[str] = None
     weired_texts: list[str] = [
         "Hybrid mesons , qq ̄ states with an admixture",
-        "typical proportionalities of \u223c 1nmV \u2212 1 [ 3,4 ] .",
+        "typical proportionalities of \u223C 1nmV \u2212 1 [ 3,4 ] .",
         "🤟 🤟 🤟 hüllo",
         "🤟hallo 🤟 🤟 🤟 🤟",
         "🤟",
-        "\uf8f9",
+        "\uF8F9",
     ]
 
     def create_embedding_from_name(self, name: str):
@@ -150,17 +150,9 @@ def test_embeddings_stay_the_same_after_saving_and_loading(self, args):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert torch.allclose(
-                    token_old.get_embedding(names_old),
-                    token_new.get_embedding(names_new),
-                    atol=1e-6,
-                )
+                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
         if self.is_document_embedding:
-            assert torch.allclose(
-                sentence_old.get_embedding(names_old),
-                sentence_new.get_embedding(names_new),
-                atol=1e-6,
-            )
+            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
 
     def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
         embeddings = self.create_embedding_with_args(self.default_args)
@@ -184,17 +176,9 @@ def test_default_embeddings_stay_the_same_after_saving_and_loading(self):
 
         if self.is_token_embedding:
             for token_old, token_new in zip(sentence_old, sentence_new):
-                assert torch.allclose(
-                    token_old.get_embedding(names_old),
-                    token_new.get_embedding(names_new),
-                    atol=1e-6,
-                )
+                assert (token_old.get_embedding(names_old) == token_new.get_embedding(names_new)).all()
         if self.is_document_embedding:
-            assert torch.allclose(
-                sentence_old.get_embedding(names_old),
-                sentence_new.get_embedding(names_new),
-                atol=1e-6,
-            )
+            assert (sentence_old.get_embedding(names_old) == sentence_new.get_embedding(names_new)).all()
 
     def test_embeddings_load_in_eval_mode(self):
         embeddings = self.create_embedding_with_args(self.default_args)