Fix Bug: Gemma2 — the past_key_value.update() function gets a new "sliding_window" parameter to support the _sliding_update function. #31786
Conversation
Could you take a look when you have a minute, @sanchit-gandhi?
Hey @kkk935208447! Thanks for opening this PR 🤗 I believe the bug you're experiencing is actually unrelated to the sliding window mechanism, which, as I've explained below, is correctly updated by the current code on main. Looking at your code snippet, it looks like CUDA device errors are being thrown, which might be related to how you're moving the tensors across devices. Could you please paste the full traceback of your error? That would help massively in pinpointing where you're hitting the error. Thanks!
if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
    attention_mask = attention_mask * torch.tril(
        torch.ones_like(attention_mask), diagonal=(self.sliding_window - cache_position[-1])
    )
Thanks for propagating the update here! It's not essential to keep the diff file sync'd with modeling_gemma2.py, since it's only used during the integration process and we don't actually import anything from this file, but it's helpful for motivating the other changes in this PR!
@@ -338,7 +339,8 @@ def forward(
        "sliding_window": self.sliding_window,
Note that we set sliding_window in the cache_kwargs here, and then pass the dict of cache_kwargs to the cache's .update function two lines later. Hence, the sliding window should already be handled by the current code.
@@ -338,7 +339,8 @@ def forward(
        "sliding_window": self.sliding_window,
        "cache_position": cache_position,
    }
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
See here how we pass the cache_kwargs to the update function of the hybrid cache, where cache_kwargs is defined as {..., "sliding_window": self.sliding_window} a few lines above. Hence, there should be no need to pass sliding_window explicitly!
#31775 describes a similar problem to mine. I don't think it's an issue with my personal device, as I've debugged it multiple times and the behaviour still occurs:
"The cache was created with alternating max seq lengths of 4k and 8k, but all layers were being updated as if they were 8k, causing out-of-bounds errors and CUDA exceptions."
The problem is with the function definition:
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
sliding_window: Optional[int] = None,
)
cache_kwargs is not actually a kwargs argument (it's missing the **); it's just a dict. Adding a sliding_window key to it doesn't affect the separate sliding_window argument of the update function, which just keeps its default value.
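A minimal, self-contained sketch of that Python behaviour (the function below is hypothetical and only mirrors the shape of the cache API; it is not the transformers implementation):

from typing import Any, Dict, Optional

def update(cache_kwargs: Optional[Dict[str, Any]] = None,
           sliding_window: Optional[int] = None) -> Optional[int]:
    # cache_kwargs is an ordinary dict parameter, not **kwargs, so putting a
    # "sliding_window" key inside it does not bind the separate keyword argument.
    return sliding_window

print(update(cache_kwargs={"sliding_window": 4096}))                       # None: the default is kept
print(update(cache_kwargs={"sliding_window": 4096}, sliding_window=4096))  # 4096: only explicit passing binds it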
Yeah, that's a good catch: none of the cache kwargs have been real kwargs up to this point, and we never did **cache_kwargs. I'll update this in another PR if you don't have time to fix it. TBH we should remove sliding window from the update kwargs!
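For illustration, a toy sketch of that direction: let update read the flag out of cache_kwargs and dispatch internally, so callers never pass sliding_window explicitly. The class below is purely hypothetical; only the method names mirror the real HybridCache.

from typing import Any, Dict, Optional

class ToyHybridCache:
    def _sliding_update(self, layer_idx: int) -> str:
        return f"layer {layer_idx}: sliding-window update"

    def _static_update(self, layer_idx: int) -> str:
        return f"layer {layer_idx}: static (global) update"

    def update(self, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None) -> str:
        cache_kwargs = cache_kwargs or {}
        # The flag travels inside the dict; no separate sliding_window parameter is needed.
        if cache_kwargs.get("sliding_window"):
            return self._sliding_update(layer_idx)
        return self._static_update(layer_idx)

cache = ToyHybridCache()
print(cache.update(0, {"sliding_window": 4096}))  # sliding-window update
print(cache.update(1, {}))                        # static (global) update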
Thank you. Is there a problem with the order in which the sliding-window attention (SWA) layers alternate?
Yeah, it's flipped. SWA should be on the even layers according to gemma_pytorch.
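As a rough illustration of the expected alternation (following gemma_pytorch, where even-indexed layers use local sliding-window attention; the is_sliding name mirrors the transformers Gemma2 code):

# Even-indexed layers (0, 2, 4, ...) should use sliding-window attention,
# odd-indexed layers should use global attention.
for layer_idx in range(6):
    is_sliding = not bool(layer_idx % 2)
    print(layer_idx, "sliding-window" if is_sliding else "global")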
Yep, the PR is here: #31775. We'll do a patch for it.
What does this PR do?
System Info
transformers 4.42.3
When the Gemma2 model generates long text that exceeds the sliding window size (>4096 tokens), it raises a CUDA error, which seems to be caused by the _sliding_update function in HybridCache failing. The error is as follows:
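For reference, a minimal reproduction sketch of the setup described above; the checkpoint and prompt are placeholders, and running it assumes access to the Gemma 2 weights:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-9b-it"  # placeholder: any Gemma 2 checkpoint with a 4096-token sliding window
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Write a very long, detailed story.", return_tensors="pt").to(model.device)
# Generating more than 4096 new tokens pushes decoding past the sliding-window size,
# which is where the CUDA error described above appeared.
outputs = model.generate(**inputs, max_new_tokens=5000, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))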
Who can review?
@ArthurZucker