
[Inference] Support sentence transformers clip #495

Merged 13 commits into main from support-st-trfrs-clip on Feb 28, 2024

Conversation


@JingyaHuang JingyaHuang commented Feb 21, 2024

What does this PR do?

Follow-up to #408: support inference for sentence-transformers CLIP.

[Compilation]

  • With CLI
optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --text_batch_size 3 --image_batch_size 1 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
  • Code
from optimum.neuron import NeuronModelForSentenceTransformers

# [Compile]
model_id = "sentence-transformers/clip-ViT-B-32"

# configs for compiling model
input_shapes = {
    "num_channels": 3,
    "height": 224,
    "width": 224,
    "text_batch_size": 3,
    "image_batch_size": 1,
    "sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
    model_id,
    subfolder="0_CLIPModel",
    export=True,
    library_name="sentence_transformers",
    dynamic_batch_size=False,
    **input_shapes,
)

# Save locally or upload to the HuggingFace Hub
save_directory = "clip_emb/"
emb_model.save_pretrained(save_directory)

[Inference]

from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

from optimum.neuron import NeuronModelForSentenceTransformers

save_directory = "clip_emb"
emb_model = NeuronModelForSentenceTransformers.from_pretrained(save_directory)

processor = CLIPProcessor.from_pretrained(save_directory)
inputs = processor(
    text=["Two dogs in the snow", "A cat on a table", "A picture of London at night"],
    images=Image.open("two_dogs_in_snow.jpg"),
    return_tensors="pt",
    padding=True,
)

outputs = emb_model(**inputs)


# Compute cosine similarities
cos_scores = util.cos_sim(outputs.image_embeds, outputs.text_embeds)
print(cos_scores)

# tensor([[0.3072, 0.1016, 0.1095]])

Caveat

Since compiled models with dynamic batch size enabled only accept tensors with the same batch size, we cannot set dynamic_batch_size=True when the input texts and images have different batch sizes. And since NeuronModelForSentenceTransformers pads the inputs to the batch sizes used during compilation, you can use relatively larger batch sizes during compilation for flexibility, with the trade-off of extra compute.

E.g. if you want to encode 3, 4, or 5 texts and 1 image, you could set text_batch_size = 5 = max(3, 4, 5) and image_batch_size = 1 during compilation, as sketched below.
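A minimal sketch of that compile-once / pad-at-inference pattern, reusing the compilation API from the example above (the batch sizes here are only illustrative):

from optimum.neuron import NeuronModelForSentenceTransformers

# Compile with the largest text batch size expected at inference time.
input_shapes = {
    "num_channels": 3,
    "height": 224,
    "width": 224,
    "text_batch_size": 5,  # max(3, 4, 5)
    "image_batch_size": 1,
    "sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
    "sentence-transformers/clip-ViT-B-32",
    subfolder="0_CLIPModel",
    export=True,
    library_name="sentence_transformers",
    dynamic_batch_size=False,
    **input_shapes,
)

# At inference, encoding e.g. 3 texts and 1 image then works: the text inputs
# are padded up to the compiled text_batch_size of 5 (extra compute, but no
# need to recompile for each batch size).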

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

docs/source/tutorials/sentence_transformers.mdx (outdated)
docs/source/tutorials/sentence_transformers.mdx (outdated)
docs/source/tutorials/sentence_transformers.mdx (outdated)
optimum/commands/export/neuronx.py (outdated)
optimum/commands/export/neuronx.py (outdated)
optimum/exporters/neuron/convert.py
optimum/neuron/modeling.py

# Copied and adapted from https://github.com/huggingface/optimum/blob/d03ab100206cb9f0e62167a36ee6997424bb9bb5/optimum/utils/save_utils.py#L27
# To remove once we can bump to a transformers release including https://github.com/huggingface/transformers/pull/29169
def maybe_load_preprocessors(
Member:

Maybe add a check on the transformers version: if it's above the coming release, then fail. This way we can actually catch it and remove the copy.
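A rough illustration of what such a guard could look like (not the code from this PR; the release number below is a hypothetical placeholder, since the required transformers version was not known at the time):

from packaging import version

import transformers

# Hypothetical placeholder: whichever transformers release first ships
# https://github.com/huggingface/transformers/pull/29169.
_FIXED_IN_TRANSFORMERS = "4.39.0"

if version.parse(transformers.__version__) >= version.parse(_FIXED_IN_TRANSFORMERS):
    raise RuntimeError(
        "maybe_load_preprocessors was copied from optimum as a temporary workaround; "
        "remove the copy now that transformers includes the upstream fix."
    )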

Collaborator Author:

Sounds good!

Collaborator Author:

Just checked the PR I submitted to transformers; it's not merged yet, and I'm not 100% sure it will be accepted by the transformers maintainers, so I don't know which transformers version to put here. Let's keep it this way for now, I will keep an eye on how it goes...


JingyaHuang and others added 7 commits February 27, 2024 17:59
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

@michaelbenayoun michaelbenayoun left a comment


LGTM

@JingyaHuang JingyaHuang merged commit 196f3c7 into main Feb 28, 2024
14 checks passed
@JingyaHuang JingyaHuang deleted the support-st-trfrs-clip branch February 28, 2024 09:01