Adds CLIP to models exportable with ONNX #18515
Changes from 10 commits
`src/transformers/models/clip/configuration_clip.py`:

```diff
@@ -16,8 +16,14 @@
 import copy
 import os
-from typing import Union
+from collections import OrderedDict
+from typing import Any, Mapping, Union, Optional
 
+from transformers import TensorType
+from transformers.processing_utils import ProcessorMixin
+
+from ...onnx import OnnxConfig
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
```
```diff
@@ -317,3 +323,44 @@ def to_dict(self):
         output["vision_config"] = self.vision_config.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class CLIPOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    def generate_dummy_inputs(
+        self,
+        processor: ProcessorMixin,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
```

Review thread (on the `processor: ProcessorMixin` and `framework: Optional[TensorType]` type hints):

Reviewer: You should use a forward reference here.

Author: By forward reference, do you mean making it like …?

Reviewer: Yes, exactly.
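For readers skimming the diff, the `inputs`/`outputs` mappings above follow the exporter's dynamic-axes format: each entry maps an axis index to a symbolic axis name, and those axes are allowed to vary at inference time. A minimal stand-alone sketch of that format (a toy class for illustration, not the real `OnnxConfig`):

```python
from collections import OrderedDict

# Toy stand-in for illustration only (not the real transformers OnnxConfig):
# the exporter reads each mapping as "axis index -> symbolic axis name",
# marking those axes as dynamic in the exported ONNX graph.
class ToyClipOnnxConfig:
    @property
    def inputs(self):
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("pixel_values", {0: "batch"}),  # only the batch axis is dynamic
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

cfg = ToyClipOnnxConfig()
input_names = list(cfg.inputs)   # ordered input names for the exporter
dynamic_axes = dict(cfg.inputs)  # per-input dynamic-axis specification
```

Note that `pixel_values` keeps its channel/height/width axes fixed at export time; only its `batch` axis is symbolic.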
```diff
@@ -630,6 +630,7 @@ def forward(
         if input_ids is None:
             raise ValueError("You have to specify either input_ids")
 
+        input_ids = input_ids.to(torch.int)  # for onnx compatibility, since onnx doesn't support int64
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
```

Review comment (on the cast): The cast is only needed for the …
```diff
@@ -1044,8 +1045,8 @@ def forward(
         text_embeds = self.text_projection(text_embeds)
 
         # normalized features
-        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
-        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(p=2, dim=1, keepdim=True)
 
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
```

Review thread (on the `dim` change):

Reviewer: I prefer not modifying the source code like this. Why did you need to do that?

Author: I can revert this; this was from the original repo.

Reviewer: Ah okay, I see. Yes, I would prefer to keep it as it is.
Reviewer: `TensorType` and `ProcessorMixin` are only needed for type hinting. The way we manage imports in Transformers will return an error with this implementation; that is why some tests failed :) To solve this, you need to put these two imports in a `TYPE_CHECKING` conditional statement, here is an example. Also, it's better to use relative imports, because absolute imports can lead to weird errors. I made the exact same mistakes in the PR of LayoutLMv3, haha.

Author: My black version is `22.6.0`. I think I did `make fix-copies`, but tests were still failing, so I did `black .` from inside the CLIP folder. What is the correct way to fix this?

Reviewer: Try with black 22.3, and use the command `make style` from the root of the repo to format the code.
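The two review suggestions (forward references and a `TYPE_CHECKING` guard) combine into the following pattern; this is a hedged sketch with a dummy `describe_processor` function standing in for `generate_dummy_inputs`:

```python
from typing import TYPE_CHECKING

# TYPE_CHECKING is False at runtime, so this import only executes under
# static type checkers (mypy, pyright); the dependency never loads when
# the module is actually imported.
if TYPE_CHECKING:
    from transformers.processing_utils import ProcessorMixin

# The quoted annotation is a forward reference: it is stored as a plain
# string and never evaluated at runtime, so the guarded import suffices.
def describe_processor(processor: "ProcessorMixin") -> str:
    # describe_processor is a hypothetical stand-in for generate_dummy_inputs
    return type(processor).__name__
```

Because the annotation stays a string, this file imports cleanly even when `transformers` is not installed at all.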