
Batch inference? #6

Open
pieris98 opened this issue Mar 23, 2024 · 1 comment

Comments

@pieris98

Hey Nazir, thank you for your exciting research and work on RbA!
I'm trying to run the model on multiple images/frames from a directory in batches.
I noticed that here you mention only a batch_size of 1 is supported.
Is that a limitation inherited from Mask2Former?
What are the difficulties in adapting the model/code to accept more images in parallel?

Thanks again for your contributions to the community!

@NazirNayal8
Owner

Hi @pieris98 , thank you for your interest in our work!

In the script that you have linked, we only use a batch size of 1 for simplicity. It can of course be extended to any batch size you need. The only important thing to be careful about is satisfying the format expected by Mask2Former, which requires that a batch is passed as a list, where each image is a dictionary inside that list. For example, a modified version of the function that runs a batch of size 2 would look like this:

def get_logits(model, img1, img2, **kwargs):
    # Mask2Former expects a list of per-image dicts, one dict per image.
    with torch.no_grad():
        out = model([{"image": img1.to(DEVICE)}, {"image": img2.to(DEVICE)}])
    return out

Note that img1 and img2 are each of shape (3, H, W) (that is, the batch dimension is dropped), and you can of course extend this the same way to any batch size you need.
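A sketch of that generalization, assuming the same `model`/`DEVICE` conventions as above: wrap each (3, H, W) tensor in a dict and collect them into one list. The `get_logits_batched` name is hypothetical, and a dummy model stands in for the real network so the example is self-contained.

```python
import torch

DEVICE = "cpu"  # assumption; the repo's DEVICE may be a CUDA device

def get_logits_batched(model, imgs, **kwargs):
    # imgs: iterable of tensors, each of shape (3, H, W) -- no batch dim.
    # Build the list-of-dicts format the model expects.
    batch = [{"image": img.to(DEVICE)} for img in imgs]
    with torch.no_grad():
        out = model(batch)
    return out

# Dummy stand-in for the real model: just reports how many images it received.
def dummy_model(batch):
    return len(batch)

imgs = [torch.zeros(3, 8, 8) for _ in range(4)]
print(get_logits_batched(dummy_model, imgs))  # 4
```

In practice, memory is the only real constraint: each image in the list is processed in parallel, so pick a batch size that fits on your GPU.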

I hope this was helpful. If you have any further questions, please do not hesitate to ask.
