Add support to export ColPali Model to ONNX #2074
base: main
Conversation
@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun Are you open to merging this?
Apologies for the delay @akshayballal95, could you add a test with a tiny random model like https://huggingface.co/hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration? It can be added here: https://github.com/huggingface/optimum/blob/main/tests/exporters/exporters_utils.py#L37
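For reference, that file maps each exportable model type to a tiny checkpoint used in the export tests. A minimal sketch of what the addition might look like is below; the PYTORCH_EXPORT_MODELS_TINY name matches the dictionary in that file, while the "colpali" key is an assumption that must match whatever model type this PR registers.

# Hypothetical entry in tests/exporters/exporters_utils.py; the key must match the
# model type registered by this PR for the ColPali export.
PYTORCH_EXPORT_MODELS_TINY = {
    # ... existing entries ...
    "colpali": "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration",
}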
class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs

        self.patched_forward = patched_forward
why is it needed?
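For context, a patcher like this is normally wired in by overriding patch_model_for_export on the model's ONNX config, so the exporter traces patched_forward instead of the original forward. The sketch below is only illustrative: the ColPaliOnnxConfig name and import paths are assumptions, and a real config would also need to define its inputs and outputs.

# Hypothetical sketch of how the patcher could be hooked into the ONNX config.
# ColPaliModelPatcher is the class defined in this PR; ColPaliOnnxConfig is an assumed name.
from typing import Any, Dict, Optional

from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import ModelPatcher


class ColPaliOnnxConfig(OnnxConfig):
    def patch_model_for_export(self, model, model_kwargs: Optional[Dict[str, Any]] = None) -> ModelPatcher:
        # Return the custom patcher so its patched_forward is used during ONNX tracing.
        return ColPaliModelPatcher(self, model, model_kwargs=model_kwargs)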
class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)
        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs
        self.patched_forward = patched_forward
already added above
)
dummy_inputs["input_ids"] = generator.concat_inputs([prefix_tensor, dummy_inputs["input_ids"]], dim=1)
dummy_inputs["attention_mask"] = generator.random_mask_tensor(
    shape=[generator.batch_size, generator.sequence_length + 1024],
Where does the value 1024 come from? Shouldn't it depend on the model's config?
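As a possible way to address this (a sketch, not necessarily the fix this PR should adopt): for PaliGemma-style models the number of image tokens is (image_size / patch_size)^2 of the vision backbone, which for ColPali's 448x448 inputs with patch size 14 gives 1024, so the prefix length could be derived from the config. The model_config handle and the surrounding call mirror the PR's snippet above; attribute names follow the transformers PaliGemmaConfig layout and should be treated as assumptions.

# Sketch (assumed attribute names): compute the image-token count from the vision config
# instead of hard-coding 1024. For 448x448 images with patch size 14 this is (448 // 14) ** 2 == 1024.
image_size = model_config.vision_config.image_size
patch_size = model_config.vision_config.patch_size
num_image_tokens = (image_size // patch_size) ** 2

dummy_inputs["attention_mask"] = generator.random_mask_tensor(
    shape=[generator.batch_size, generator.sequence_length + num_image_tokens],
)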
What does this PR do?
This PR adds support for exporting the ColPali merged model to ONNX format. The model is based on the PaliGemma model type, so I have added it under the "feature-extraction" task. Please suggest a better way to integrate this if there is one. If this looks fine with a few modifications, I can add support for the PaliGemma text-generation task as well.
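For reference, a merged ColPali checkpoint could then presumably be exported through the standard Python export entry point as sketched below; the checkpoint id is an example and not something this PR pins down.

# Sketch of exporting a ColPali checkpoint once this PR lands; the model id is an example.
from optimum.exporters.onnx import main_export

main_export(
    "vidore/colpali-v1.2-merged",  # example merged ColPali checkpoint
    output="colpali_onnx",
    task="feature-extraction",
)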
Before submitting
Who can review?
@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun