Enable compilation #6

Merged
tengomucho merged 5 commits into main from compilation-on on Mar 26, 2024

Conversation

tengomucho (Collaborator)

What does this PR do?

This PR implements a workaround for a problem that appears when compilation is used with torch_xla on TPU and direct assignment is used on the inputs.
After fixing that, it is possible to enable compilation by default on models that support a static cache, such as gemma, leading to improved decoding performance after the first token is decoded.

When direct assignment is used on the inputs, the compiled model can apparently give wrong results, as explained in this issue:

pytorch/xla#6796

The workaround seems to be to use the `index_put_` method.
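
As a rough illustration of the pattern in question (a minimal sketch with made-up shapes, not the actual code in this PR):

```python
import torch

# Toy batched attention mask; shapes are hypothetical, for illustration only.
batch_size, seq_len = 4, 8
attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.int64)
row = torch.ones(seq_len, dtype=torch.int64)
i = 2

# Pattern that can reportedly produce wrong results when the surrounding code
# is compiled for torch_xla on TPU (see pytorch/xla#6796):
#   attention_mask[i, :] = row

# Workaround: perform the same row update through index_put_ instead.
attention_mask.index_put_([torch.tensor([i])], row)
```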
@tengomucho marked this pull request as ready for review on March 22, 2024 at 09:36
@tengomucho requested a review from mfuntowicz on March 22, 2024 at 09:36
@mfuntowicz (Member) left a comment

LGTM! Niiiiiicezzz

@@ -493,8 +492,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
                 dtype=torch.int64,
                 device=self.model.device,
             )
-            attention_mask[i, :] = slot.attention_mask
-            position_ids[i, 0] = slot.cur_position
+            attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
@mfuntowicz (Member)

Can't we just put 1 here? I.e. the new token in the attention_mask we want to attend to

@tengomucho (Collaborator, Author)

you are right

@tengomucho (Collaborator, Author)

arf, nope, it was correct, and I just broke it: I need to put the slot attention mask in the i-th line (corresponding to batch i). I am going to fix it.

It's useless to put another value, and the index was put there by
mistake.
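
To make the resolution above concrete, here is a small hypothetical example (shapes made up, not code from this repository): row i of the batched attention_mask has to receive the slot's full attention mask, not a single 1 for the newly decoded token.

```python
import torch

# Hypothetical shapes, for illustration only.
batch_size, seq_len = 2, 6
attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.int64)
slot_attention_mask = torch.tensor([1, 1, 1, 1, 0, 0], dtype=torch.int64)
i = 0

# Writing a plain 1 would only mark one position; copying the whole per-slot
# mask keeps every previously decoded token attended to for batch entry i.
attention_mask.index_put_([torch.tensor([i])], slot_attention_mask)
```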
@tengomucho merged commit edf1b9e into main on Mar 26, 2024
1 check failed
@mfuntowicz deleted the compilation-on branch on March 26, 2024 at 10:13