Enable compilation #6
When direct assignment is used on inputs, the compiled model can apparently give wrong results, as explained in this issue: pytorch/xla#6796. The workaround seems to be to use the `index_put_` method instead.
This improves inference time a lot.
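For context, here is a minimal sketch of the pattern being changed. The tensor names, shapes, and index are illustrative assumptions, not the PR's actual code:

```python
import torch

batch_size, seq_len = 4, 16  # illustrative sizes
attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.int64)
slot_attention_mask = torch.ones((seq_len,), dtype=torch.int64)
i = 2  # hypothetical index of the batch row being updated

# Direct assignment -- reported to give wrong results once the model is
# compiled with torch_xla (see pytorch/xla#6796):
# attention_mask[i, :] = slot_attention_mask

# Workaround: write the row through index_put_ with an explicit index tensor.
attention_mask.index_put_([torch.tensor([i])], slot_attention_mask)
```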
LGTM! Niiiiiicezzz
```diff
@@ -493,8 +492,8 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
                 dtype=torch.int64,
                 device=self.model.device,
             )
-            attention_mask[i, :] = slot.attention_mask
-            position_ids[i, 0] = slot.cur_position
+            attention_mask.index_put_([torch.tensor([i])], slot.attention_mask)
```
Can't we just put `1` here? I.e., the new token in the `attention_mask` that we want to attend to.
you are right
arf, nope, it was correct, and I just broke it: I need to put the slot attention mask in the i-th row (corresponding to batch i). I am going to fix it.
It's useless to put another value, and the index was put there by mistake.
What does this PR do?
This PR implements a workaround for a problem that appears when compilation is used with torch_xla on TPU and direct assignment is used on inputs.
After fixing that, it is possible to enable compilation by default on models that support static cache, such as gemma, leading to improved decoding performance after the first token is decoded.
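As an illustration only (the helper name and the condition are assumptions about the setup described here, not code from this repository), enabling compilation for such models could look roughly like this, using the `openxla` dynamo backend provided by torch_xla:

```python
import torch

def maybe_compile_for_decoding(model, supports_static_cache: bool):
    # With a static KV cache, input shapes stay fixed across decoding steps,
    # so the XLA graph traced by torch.compile can be reused for every token
    # after the first one instead of triggering recompilation.
    if supports_static_cache:
        return torch.compile(model, backend="openxla")
    return model
```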