Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Add SentenceTransformers support to pipeline for feature-extration #583

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion optimum/neuron/pipelines/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from transformers import (
AutoConfig,
FeatureExtractionPipeline,
FillMaskPipeline,
Pipeline,
PreTrainedModel,
Expand All @@ -37,12 +36,17 @@

from optimum.modeling_base import OptimizedModel
from optimum.neuron.modeling_base import NeuronBaseModel
from optimum.neuron.pipelines.transformers.sentence_transformers import (
FeatureExtractionPipeline,
is_sentence_transformer_model,
)

from ...modeling import (
NeuronModelForCausalLM,
NeuronModelForFeatureExtraction,
NeuronModelForMaskedLM,
NeuronModelForQuestionAnswering,
NeuronModelForSentenceTransformers,
NeuronModelForSequenceClassification,
NeuronModelForTokenClassification,
)
Expand Down Expand Up @@ -119,6 +123,13 @@ def load_pipeline(
elif isinstance(model, str):
model_id = model
neuronx_model_class = supported_tasks[targeted_task]["class"][0]
# Try to determine the correct feature extraction class to use.
if targeted_task == "feature-extraction" and is_sentence_transformer_model(
model, token=token, revision=revision
):
logger.info("Using Sentence Transformers compatible Feature extraction pipeline")
neuronx_model_class = NeuronModelForSentenceTransformers

model = neuronx_model_class.from_pretrained(
model, export=export, **compiler_args, **input_shapes, **hub_kwargs, **kwargs
)
Expand Down Expand Up @@ -267,5 +278,6 @@ def pipeline(
feature_extractor=feature_extractor,
use_fast=use_fast,
batch_size=batch_size,
pipeline_class=NEURONX_SUPPORTED_TASKS[task]["impl"],
**kwargs,
)
90 changes: 90 additions & 0 deletions optimum/neuron/pipelines/transformers/sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Dict

from transformers.pipelines.base import GenericTensor, Pipeline

from optimum.utils import is_sentence_transformers_available


def is_sentence_transformer_model(model: str, token: str = None, revision: str = None):
"""Checks if the model is a sentence transformer model based on provided model id"""
if is_sentence_transformers_available():
from optimum.exporters.tasks import TasksManager
philschmid marked this conversation as resolved.
Show resolved Hide resolved

try:
_library_name = TasksManager.infer_library_from_model(model, use_auth_token=token, revision=revision)
return True if _library_name == "sentence_transformers" else False
except ValueError:
return False
philschmid marked this conversation as resolved.
Show resolved Hide resolved
return False


class FeatureExtractionPipeline(Pipeline):
"""
Sentence Transformers compatible Feature extraction pipeline uses no model head.
This pipeline extracts the sentence embeddings from the sentence transformers, which can be used
in embedding-based tasks like clustering and search. The pipeline is based on the `transformers` library.
And automatically used instead of the `transformers` library's pipeline when the model is a sentence transformer model.

Example:

```python
>>> from optimum.neuron import pipeline

>>> extractor = pipeline(model="sentence-transformers/all-MiniLM-L6-v2", task="feature-extraction", export=True, batch_size=2, sequence_length=128)
>>> result = extractor("This is a simple test.", return_tensors=True)
>>> result.shape # This is a tensor of shape [1, dimension] representing the input string.
torch.Size([1, 384])
```
"""

def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
if tokenize_kwargs is None:
tokenize_kwargs = {}

if truncation is not None:
if "truncation" in tokenize_kwargs:
raise ValueError(
"truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
)
tokenize_kwargs["truncation"] = truncation

preprocess_params = tokenize_kwargs

postprocess_params = {}
if return_tensors is not None:
postprocess_params["return_tensors"] = return_tensors

return preprocess_params, {}, postprocess_params

def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My primary concern at this time is that the Sentence Transformer tokenizer uses this max_seq_length as the "correct" maximum length as opposed to the value defined in the tokenizer_config.json.

Here, we are relying on the tokenizer defined in the Pipeline, which won't use the max_seq_length. As a result, I think this ST pipeline component will perform differently (worse, to be precise) for longer input texts. A solution is to use model_inputs = self.model.tokenize(inputs) instead.
Do note that the ST tokenize method does not allow for extra tokenize kwargs such as truncation, return_tensors, or padding. These are unfortunately hardcoded at the moment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yet another solution is to rely exclusively on self.model.encode(...) in def _forward, but I recognize that this might clash with some requirements of the Pipeline.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomaarsen for inferentia models need to be traced to a sequence length before running inference since we have static shapes. You always need to specify a sequence_length and batch_size before you can compile a model which is then used.
This abstracted away by the user in the NeuronModelForSentenceTransformers class.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no "sentence-transformers" used at all at the end. since the model is traced and it happens with transformers. We should be good here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See here:

class NeuronModelForSentenceTransformers(NeuronBaseModel):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, I see! Thanks for the heads up. I figured it was more like Intel Gaudi, which does just work with regular Sentence Transformers (as long as the padding is "max_length" to also get static shapes & the device is "hpu").

Then my concern still stands: I think the max_seq_length might not be taken into account correctly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sentence_bert_config.json is not taken into consideration during the export, and maybe we should. @tomaarsen where can we usually find the max_seq_length? Is there a specific name / path that it's stored, if so could we add it to config.json?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi, max_seq_length is not taken into account at all by the export of Neuron model. Shall we prevent users from setting static seq len higher than this value?

Copy link
Member

@tomaarsen tomaarsen May 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Sentence Transformer models, max_seq_length should have priority. In 99% of cases, this is stored in the sentence_bert_config.json file in the root of the model directory/repository. You might indeed want to store it in a config.json when exporting for Neuron, or override model_max_length in tokenizer_config.json as that should work in a more "expected" fashion.

model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
return model_inputs

def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs

def postprocess(self, _model_outputs, return_tensors=False):
# Needed change for sentence transformers.
# Check if the model outputs sentence embeddings or not.
if hasattr(_model_outputs, "sentence_embedding"):
model_outputs = _model_outputs.sentence_embedding
else:
model_outputs = _model_outputs
# [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt":
return model_outputs[0].tolist()

def __call__(self, *args, **kwargs):
"""
Extract the features of the input(s).

Args:
args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.

Return:
A nested list of `float`: The features computed by the model.
"""
return super().__call__(*args, **kwargs)
18 changes: 18 additions & 0 deletions tests/inference/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,24 @@ def test_sentence_transformers_clip(self, model_arch):

gc.collect()

@parameterized.expand(["transformer"], skip_on_empty=True)
@requires_neuronx
def test_pipeline_model(self, model_arch):
input_shapes = {
"batch_size": 1,
"sequence_length": 16,
}
model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(model_id, export=True, **input_shapes)
tokenizer = get_preprocessor(model_id)
pipe = pipeline(self.TASK, model=neuron_model, tokenizer=tokenizer)
text = "My Name is Philipp."
outputs = pipe(text)

self.assertTrue(all(isinstance(item, float) for item in outputs))

gc.collect()


@is_inferentia_test
class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin):
Expand Down
Loading