diff --git a/CHANGELOG.md b/CHANGELOG.md index 94aff3f61d..89d200eeff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] - YYYY-DD-MM ### Added +- Added `TextEmbedder` task ([#996](https://github.com/PyTorchLightning/lightning-flash/pull/996)) - Added predict_kwargs in `ObjectDetector`, `InstanceSegmentation`, `KeypointDetector` ([#990](https://github.com/PyTorchLightning/lightning-flash/pull/990)) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 581c5cd719..082731f8d0 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -107,6 +107,7 @@ def _compare_version(package: str, op, version) -> bool: _ALBUMENTATIONS_AVAILABLE = _module_available("albumentations") _BAAL_AVAILABLE = _module_available("baal") _TORCH_OPTIMIZER_AVAILABLE = _module_available("torch_optimizer") +_SENTENCE_TRANSFORMERS_AVAILABLE = _module_available("sentence_transformers") if _PIL_AVAILABLE: @@ -130,6 +131,7 @@ class Image: _SENTENCEPIECE_AVAILABLE, _DATASETS_AVAILABLE, _TM_TEXT_AVAILABLE, + _SENTENCE_TRANSFORMERS_AVAILABLE, ] ) _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 4c2af721a9..422b019992 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -42,6 +42,7 @@ def __str__(self): _LEARN2LEARN = Provider("learnables/learn2learn", "https://github.com/learnables/learn2learn") _PYSTICHE = Provider("pystiche/pystiche", "https://github.com/pystiche/pystiche") _HUGGINGFACE = Provider("Hugging Face/transformers", "https://github.com/huggingface/transformers") +_SENTENCE_TRANSFORMERS = Provider("UKPLab/sentence-transformers", "https://github.com/UKPLab/sentence-transformers") _FAIRSEQ = Provider("PyTorch/fairseq", "https://github.com/pytorch/fairseq") _OPEN3D_ML = 
"""Backbone registry used by the text embedding task."""
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

if not _TEXT_AVAILABLE:
    # Text extras are missing: expose a plain registry so that importing
    # this module still succeeds without `transformers` installed.
    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
else:
    # Imported lazily so `transformers` is only required when the text
    # extras are actually available.
    from transformers import AutoModel

    # The registry is built around `AutoModel.from_pretrained` and tagged
    # with the Hugging Face provider.
    # NOTE(review): presumably any model name is resolved through that
    # getter at lookup time -- confirm against `ExternalRegistry`.
    HUGGINGFACE_BACKBONES = ExternalRegistry(AutoModel.from_pretrained, "backbones", _HUGGINGFACE)
import logging
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import Callback

from flash.core.integrations.transformers.states import TransformersBackboneState
from flash.core.model import Task
from flash.core.registry import FlashRegistry, print_provider_info
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS
from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES
from flash.text.ort_callback import ORTCallback

if _TEXT_AVAILABLE:
    from sentence_transformers.models import Pooling

    Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling)

logger = logging.getLogger(__name__)


class TextEmbedder(Task):
    """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings.

    The task is prediction-only: ``training_step``, ``validation_step`` and ``test_step`` all raise
    ``NotImplementedError``. For more details, see `embeddings`.

    You can change the backbone to any sentence embedding model from `UKPLab/sentence-transformers
    <https://github.com/UKPLab/sentence-transformers>`_ using the ``backbone`` argument.

    Args:
        backbone: name of the backbone model, looked up in ``TextEmbedder.backbones``.
        tokenizer_backbone: name handed to ``TransformersBackboneState`` for the tokenizer. When ``None``, the
            value of ``backbone`` is used.
        tokenizer_kwargs: optional keyword arguments forwarded to ``TransformersBackboneState`` as
            ``tokenizer_kwargs``.
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
    """

    required_extras: str = "text"

    backbones: FlashRegistry = HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sentence-transformers/all-MiniLM-L6-v2",
        tokenizer_backbone: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        enable_ort: bool = False,
    ):
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
        # NOTE: this filter is process-wide, not limited to this task.
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"
        super().__init__()

        if tokenizer_backbone is None:
            # The tokenizer defaults to the one matching the backbone.
            tokenizer_backbone = backbone
        self.set_state(TransformersBackboneState(tokenizer_backbone, tokenizer_kwargs=tokenizer_kwargs))
        # The registry returns a callable; calling it builds the model.
        self.model = self.backbones.get(backbone)()
        self.pooling = Pooling(self.model.config.hidden_size)
        self.enable_ort = enable_ort

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: training is not supported for this task."""
        raise NotImplementedError("Training a `TextEmbedder` is not supported. Use a different text task instead.")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: validation is not supported for this task."""
        raise NotImplementedError("Validating a `TextEmbedder` is not supported. Use a different text task instead.")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """Always raises: testing is not supported for this task."""
        raise NotImplementedError("Testing a `TextEmbedder` is not supported. Use a different text task instead.")

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute one pooled sentence embedding per input sequence.

        Adapted from sentence-transformers:

        https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Transformer.py#L45

        Args:
            batch: mapping holding ``"input_ids"`` and ``"attention_mask"`` and, optionally,
                ``"token_type_ids"``. The mapping is not modified.

        Returns:
            The ``"sentence_embedding"`` entry produced by the pooling layer.
        """
        trans_features = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        if "token_type_ids" in batch:
            trans_features["token_type_ids"] = batch["token_type_ids"]

        # `return_dict=False` yields a tuple; its first element is used as
        # the token-level output.
        output_states = self.model(**trans_features, return_dict=False)
        output_tokens = output_states[0]

        # Work on a shallow copy: the pooling layer writes its result back
        # into the dict it receives, and the caller's batch must stay intact.
        pooling_features = dict(batch)
        pooling_features["token_embeddings"] = output_tokens

        return self.pooling(pooling_features)["sentence_embedding"]

    def configure_callbacks(self) -> List[Callback]:
        """Return the parent callbacks, plus an ``ORTCallback`` when ``enable_ort`` is set."""
        callbacks = super().configure_callbacks() or []
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule from the raw sentences to embed
sentences = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
datamodule = TextClassificationData.from_lists(predict_data=sentences)

# 2. Build a TextEmbedder on the chosen sentence-transformers backbone
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the 3 sentences
available_gpus = torch.cuda.device_count()
trainer = flash.Trainer(gpus=available_gpus)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
import os

import pytest
import torch

import flash
from flash.text import TextClassificationData, TextEmbedder
from tests.helpers.utils import _TEXT_TESTING

# ======== Mock data ========

predict_data = [
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
]
# ============================== 

TEST_BACKBONE = "sentence-transformers/all-MiniLM-L6-v2"  # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_predict(tmpdir):
    """Predict on the three mock sentences and check every embedding has size 384."""
    datamodule = TextClassificationData.from_lists(predict_data=predict_data)
    model = TextEmbedder(backbone=TEST_BACKBONE)

    gpu_count = torch.cuda.device_count()
    trainer = flash.Trainer(gpus=gpu_count)
    predictions = trainer.predict(model, datamodule=datamodule)

    # Only the first prediction batch is inspected: one embedding per sentence.
    embedding_sizes = [embedding.size() for embedding in predictions[0]]
    assert embedding_sizes == [torch.Size([384])] * 3