
Handle batch input inference for MoD Infini-Former more gracefully #10

Open
dingo-actual opened this issue Apr 23, 2024 · 3 comments

@dingo-actual (Owner)

Currently, the token sampling for the MoD Infini-Former at inference time can produce a different sequence length for each observation in the batch. The current workaround is to force the batch size to one and loop over the observations, which is highly inefficient.

There are two main options for handling this efficiently:

  1. Pad the sampled sequences to the longest sequence length in such a way that the additional tokens contribute nothing to downstream calculations.
  2. Wait for PyTorch to implement a ragged tensor type.

I'm likely to pursue the first because there's no telling how long it'll be before the PyTorch devs add ragged tensors.
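For reference, a minimal sketch of what option 1 could look like in PyTorch, assuming each observation's sampled tokens arrive as a separate (seq_len_i, dim) tensor; the function and variable names below are illustrative, not the actual MoD Infini-Former code:

```python
# Illustrative only: pad per-observation sampled sequences to a common length
# and return a mask marking which positions hold real tokens vs. padding.
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_sampled_batch(sampled_seqs):
    """sampled_seqs: list of (seq_len_i, dim) tensors, one per observation."""
    lengths = torch.tensor([s.size(0) for s in sampled_seqs],
                           device=sampled_seqs[0].device)
    # Zero-pad along the token dimension to the longest sampled length.
    padded = pad_sequence(sampled_seqs, batch_first=True)  # (batch, max_len, dim)
    # True where the position is a real sampled token, False where it is padding.
    mask = torch.arange(padded.size(1), device=padded.device)[None, :] < lengths[:, None]
    return padded, mask

# Downstream, the mask can be turned into an additive attention bias, e.g.:
# attn_bias = torch.zeros_like(mask, dtype=padded.dtype).masked_fill(~mask, float("-inf"))
```

With a bias like that applied, padded key positions receive zero attention weight from the real positions.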

@dingo-actual added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Apr 23, 2024
@muditbhargava66 (Contributor)

I worked on this issue in #15.

@muditbhargava66 (Contributor)

Should this issue be closed, or do you need any more changes? Please let me know if you have any further questions.

@dingo-actual (Owner, Author)

Unfortunately, the fix you introduced assumes that calling .forward_() on the original input produces a valid result. What needs to happen during inference is for .forward() to use sample_mask_seg to pad the samples along the token dimension until they all have the same length. The part I haven't gotten around to is working through the math to determine a choice of padding token that doesn't affect downstream calculations.
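A rough sketch of that padding step might look like the following, assuming sample_mask_seg is a (batch, seg_len) boolean mask over the segment's tokens; the function name and padding value here are illustrative, not the actual implementation:

```python
# Illustrative sketch: left-align each observation's sampled tokens and pad the
# remainder, returning a mask of the real (non-padding) positions.
import torch

def gather_and_pad(x_seg, sample_mask_seg, pad_value=0.0):
    """x_seg: (batch, seg_len, dim); sample_mask_seg: (batch, seg_len) bool."""
    counts = sample_mask_seg.sum(dim=-1)            # sampled tokens per observation
    max_len = int(counts.max())
    batch, _, dim = x_seg.shape
    padded = x_seg.new_full((batch, max_len, dim), pad_value)
    keep = torch.arange(max_len, device=x_seg.device)[None, :] < counts[:, None]
    # Boolean indexing is row-major, so each observation's sampled tokens land
    # left-aligned in its own row of the padded tensor.
    padded[keep] = x_seg[sample_mask_seg]
    return padded, keep                             # keep marks real (non-pad) positions
```

Returning keep alongside the padded tensor leaves room to mask the padded positions out of downstream attention explicitly, whatever padding value is ultimately chosen.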

For the moment, I'm going to revert the change, just to maintain functionality (slow as it is). I really appreciate your putting in time on this though!
