Add support to export ColPali Model to ONNX #2074

Open
wants to merge 4 commits into base: main

akshayballal95

What does this PR do?

This PR adds support for exporting the merged ColPali model to ONNX. ColPali is based on the "paligemma" model type, so I have added it under the "feature-extraction" task. Please suggest a better way to integrate it if there is one. If this looks fine after a few modifications, I can also add support for the PaliGemma text-generation task.

Before submitting

Who can review?

@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun

@akshayballal95
Author

@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun

Are you open to merging this?

@echarlaix
Collaborator

echarlaix commented Dec 6, 2024

Comment on lines +512 to +527
class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs

        self.patched_forward = patched_forward

Why is this needed?
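To make the question concrete, here is a minimal, self-contained sketch of the pattern in the snippet above. `TinyModel` and `PassThroughPatcher` are hypothetical stand-ins, not optimum classes, and the sketch does not reproduce whatever the real `ModelPatcher` base class does in `__init__`. It only illustrates that a `patched_forward` which forwards the same arguments to `orig_forward` returns exactly what the original forward returns.

```python
class TinyModel:
    """Hypothetical stand-in for a model; it simply echoes its inputs."""

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "extra": kwargs,
        }


class PassThroughPatcher:
    """Hypothetical patcher that mirrors the structure of the snippet above."""

    def __init__(self, model):
        # Keep a handle on the unpatched forward, as the snippet does.
        self.orig_forward = model.forward

        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            # Forward every argument unchanged: no reordering, no filtering.
            return self.orig_forward(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                **kwargs,
            )

        self.patched_forward = patched_forward


model = TinyModel()
patcher = PassThroughPatcher(model)
inputs = {"input_ids": [1, 2, 3], "pixel_values": [0.5], "attention_mask": [1, 1, 1]}

# The wrapper and the original produce identical outputs for identical inputs.
same = patcher.patched_forward(**inputs) == model.forward(**inputs)
```

Under these assumptions the wrapper changes nothing, so the override would only be justified if it needs to differ from the base class's default behavior in some way not shown in the diff.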

Comment on lines +1175 to +1189

class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)
        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs
        self.patched_forward = patched_forward

This class is already added above.

Suggested change
class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)
        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs
        self.patched_forward = patched_forward

        )
        dummy_inputs["input_ids"] = generator.concat_inputs([prefix_tensor, dummy_inputs["input_ids"]], dim=1)
        dummy_inputs["attention_mask"] = generator.random_mask_tensor(
            shape=[generator.batch_size, generator.sequence_length + 1024],

Where does the value 1024 come from? Shouldn't it depend on the model's config?
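A possible direction, sketched under an assumption the PR does not confirm: that 1024 is the number of image tokens prepended to the text, i.e. the number of vision patches, which for a 448x448 input with 14x14 patches is (448 // 14) ** 2 = 1024. `VisionConfig`, `num_image_tokens`, and `attention_mask_shape` below are hypothetical names for illustration, not optimum or transformers APIs; in the real exporter the two sizes would presumably be read from the model's vision config.

```python
from dataclasses import dataclass


@dataclass
class VisionConfig:
    """Hypothetical stand-in for the vision part of a model config."""

    image_size: int  # input image side length in pixels
    patch_size: int  # side length of one square patch in pixels


def num_image_tokens(vision_config: VisionConfig) -> int:
    """Number of patch tokens for a square image split into square patches."""
    patches_per_side = vision_config.image_size // vision_config.patch_size
    return patches_per_side**2


def attention_mask_shape(batch_size: int, sequence_length: int, vision_config: VisionConfig) -> list:
    """Mask shape covering the image-token prefix plus the text tokens."""
    return [batch_size, sequence_length + num_image_tokens(vision_config)]


# Assumed values: 448-pixel images with 14-pixel patches.
tokens_448 = num_image_tokens(VisionConfig(image_size=448, patch_size=14))
shape_448 = attention_mask_shape(2, 16, VisionConfig(image_size=448, patch_size=14))
```

With this approach a checkpoint using a different image resolution would get a matching prefix length instead of a hard-coded constant.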
