[Inference] Add SentenceTransformers support to pipeline for feature-extraction #583

Merged
merged 5 commits into main from st-pipeline on May 6, 2024

Conversation

philschmid
Member

What does this PR do?

This PR adds a new, slightly modified FeatureExtractionPipeline from Transformers that allows us to use it with sentence-transformers models. When using the pipeline object from optimum, the library checks whether the requested model for feature-extraction is a sentence-transformers model and, if so, returns the sentence embeddings instead of the first hidden state.

This is done by adding a new is_sentence_transformer_model check that determines whether the requested model is a transformers or sentence-transformers model. If it is a sentence-transformers model, the pipeline uses NeuronModelForSentenceTransformers and the FeatureExtractionPipeline returns model_outputs.sentence_embedding[0] instead of model_outputs[0].
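As a rough sketch of the postprocessing difference (illustrative only; the is_sentence_transformer flag below is a hypothetical stand-in for the result of the new check, not the exact code added in this PR):

# Illustrative sketch, not the exact code from this PR.
def postprocess(self, model_outputs, return_tensors=False):
    if return_tensors:
        return model_outputs
    if self.is_sentence_transformer:
        # NeuronModelForSentenceTransformers exposes pooled sentence embeddings.
        return model_outputs.sentence_embedding[0].tolist()
    # Default Transformers behaviour: the first output is the last hidden state.
    return model_outputs[0].tolist()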

Example:

from optimum.neuron import pipeline

input_shapes = {"batch_size": 1, "sequence_length": 64}
p = pipeline("feature-extraction", "sentence-transformers/all-MiniLM-L6-v2", export=True, input_shapes=input_shapes)
# Using Sentence Transformers compatible Feature extraction pipeline

p("test")
# [0.06765521317720413,
#  0.06349243223667145,
#  0.04871273413300514,
#  0.0793028473854065,
#  ...]
Validated with torch.allclose
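A hedged sketch of what such a check could look like, comparing the pipeline output against the same checkpoint loaded with the sentence-transformers library (the tolerance value is illustrative):

import torch
from sentence_transformers import SentenceTransformer
from optimum.neuron import pipeline

model_id = "sentence-transformers/all-MiniLM-L6-v2"
text = "test"

input_shapes = {"batch_size": 1, "sequence_length": 64}
p = pipeline("feature-extraction", model_id, export=True, input_shapes=input_shapes)
neuron_embedding = torch.tensor(p(text))

reference = SentenceTransformer(model_id)
reference_embedding = reference.encode(text, convert_to_tensor=True)

# Compiled Neuron outputs can differ slightly from the CPU/GPU reference,
# hence the small absolute tolerance.
assert torch.allclose(neuron_embedding, reference_embedding.cpu(), atol=1e-3)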

Implications:

  • sentence-transformers models will now always return the sentence_embeddings when used with the feature-extraction pipeline.

Alternative options:

  • Instead of modifying the feature-extraction pipeline, we could also introduce a new task sentence-embeddings to optimum, but that might hinder more general adoption since it is unique to optimum-neuron.

@philschmid
Member Author

@tomaarsen can you also do a review?

@philschmid philschmid changed the title [Infernece] Add SentenceTransformers support to pipeline for feature-extration [Inference] Add SentenceTransformers support to pipeline for feature-extration Apr 30, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@dacorvo dacorvo left a comment

Several pipeline tests are failing. It is unclear to me if they should be triggered or not.


return preprocess_params, {}, postprocess_params

def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:

Member

My primary concern at this time is that Sentence Transformers uses its max_seq_length as the "correct" maximum length, as opposed to the value defined in tokenizer_config.json.

Here, we are relying on the tokenizer defined in the Pipeline, which won't use the max_seq_length. As a result, I think this ST pipeline component will perform differently (worse, to be precise) for longer input texts. A solution is to use model_inputs = self.model.tokenize(inputs) instead.
Do note that the ST tokenize method does not allow for extra tokenize kwargs such as truncation, return_tensors, or padding. These are unfortunately hardcoded at the moment.
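To make the concern concrete, here is a rough illustration of how the two limits can diverge for a Sentence Transformers checkpoint (the printed values are examples and depend on the model):

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

model_id = "sentence-transformers/all-MiniLM-L6-v2"
st_model = SentenceTransformer(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Limit that Sentence Transformers applies (from sentence_bert_config.json) ...
print(st_model.max_seq_length)     # e.g. 256
# ... versus the tokenizer's own limit from tokenizer_config.json.
print(tokenizer.model_max_length)  # e.g. 512

# st_model.tokenize(texts) truncates to max_seq_length, while a pipeline that
# only uses the tokenizer would truncate to model_max_length.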

Member

Yet another solution is to rely exclusively on self.model.encode(...) in def _forward, but I recognize that this might clash with some requirements of the Pipeline.
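A minimal sketch of that alternative, assuming the pipeline forwards the raw texts and the model object exposes a SentenceTransformer-style encode method (all names here are illustrative):

# Hypothetical _forward relying entirely on encode.
def _forward(self, model_inputs):
    embeddings = self.model.encode(model_inputs["texts"], convert_to_tensor=True)
    return {"sentence_embedding": embeddings}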

Member Author

@tomaarsen on Inferentia, models need to be traced to a fixed sequence length before running inference since we have static shapes. You always need to specify a sequence_length and batch_size before you can compile a model, which is then used for inference.
This is abstracted away from the user in the NeuronModelForSentenceTransformers class.
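For illustration, a hedged sketch of what that looks like when using the model class directly, assuming the shape arguments are accepted as keyword arguments the way they are for other Neuron model classes:

from optimum.neuron import NeuronModelForSentenceTransformers

# The model is traced/compiled for these static shapes at export time;
# inference then always runs with batch_size=1 and sequence_length=64.
model = NeuronModelForSentenceTransformers.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",
    export=True,
    batch_size=1,
    sequence_length=64,
)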

Member Author

In the end, "sentence-transformers" is not used at all, since the model is traced and that happens with transformers. We should be good here.

Member Author

See here:

class NeuronModelForSentenceTransformers(NeuronBaseModel):

Member

Ahh, I see! Thanks for the heads up. I figured it was more like Intel Gaudi, which does just work with regular Sentence Transformers (as long as the padding is "max_length" to also get static shapes & the device is "hpu").

Then my concern still stands: I think the max_seq_length might not be taken into account correctly.

Collaborator

The sentence_bert_config.json is not taken into consideration during the export, and maybe it should be. @tomaarsen where can we usually find the max_seq_length? Is there a specific name / path where it's stored? If so, could we add it to config.json?

Collaborator

FYI, max_seq_length is not taken into account at all by the Neuron model export. Shall we prevent users from setting a static sequence length higher than this value?

Member

@tomaarsen tomaarsen May 3, 2024

For Sentence Transformer models, max_seq_length should have priority. In 99% of cases, this is stored in the sentence_bert_config.json file in the root of the model directory/repository. You might indeed want to store it in a config.json when exporting for Neuron, or override model_max_length in tokenizer_config.json as that should work in a more "expected" fashion.
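As an illustration of how an export path could honour that limit (the file name comes from the comment above; the clamping helper itself is a sketch, not something this PR adds):

import json
from huggingface_hub import hf_hub_download

def resolve_sequence_length(model_id: str, requested: int) -> int:
    # Read the Sentence Transformers limit and make sure the static shape
    # used for Neuron compilation does not exceed it.
    config_path = hf_hub_download(model_id, "sentence_bert_config.json")
    with open(config_path) as f:
        max_seq_length = json.load(f).get("max_seq_length", requested)
    return min(requested, max_seq_length)

# e.g. resolve_sequence_length("sentence-transformers/all-MiniLM-L6-v2", 512)
# would return the model's smaller max_seq_length rather than 512.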

Collaborator

@dacorvo dacorvo left a comment

Disclaimer: I don't know sentence transformers, so this review is on the general outlook of the code and tests, which looks good to me. @JingyaHuang's and @tomaarsen's reviews will be more relevant.

Collaborator

@JingyaHuang JingyaHuang left a comment

Thanks @philschmid, left some small nits.


Collaborator

@JingyaHuang JingyaHuang left a comment

LGTM, thanks @philschmid for adding it!

And thanks @tomaarsen for raising the concern about max_seq_length; it's on the backlog and will be improved in a coming PR.

@JingyaHuang JingyaHuang merged commit 9361b55 into main May 6, 2024
11 of 12 checks passed
@JingyaHuang JingyaHuang deleted the st-pipeline branch May 6, 2024 22:36