System Info
The current implementation of SinkCache may result in inconsistent positions for the attention computation.
HF Key Cache: Post-RoPE
After pre-filling over a sequence of length N:
i) Post-RoPE key cache of length N, with positions [0, 1, 2, …, N-1].
During the first token generation (the current query & key position is N, since the KV cache size is N), StreamingLLM updates the cache as follows (sink_size = S, recent_window = R):
i) Initial post-RoPE key cache positions: [0, 1, …, S-1] + [N-R+1, …, N-1] + [N]; note len([N-R+1, …, N-1]) = R-1.
ii) Rotate:
HF applies a re-rotation to the R-1 keys at positions (N-R+1, …, N-1) so that their positions become (S, S+1, …, S+R-2), and keeps them in the StreamingLLM KV cache.
iii) Updated StreamingLLM key cache positions: [0, 1, …, S-1] + [S, …, S+R-2] + [N], i.e., the last ((S+R)-th) element keeps its actual position N, since len([S, …, S+R-2]) = R-1. (See the toy sketch below.)
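As a concrete illustration, here is a toy sketch of the bookkeeping above (S = 4, R = 8, N = 100 are made-up values, not taken from a real run):

# Toy sketch of the first-token update described above; S, R, N are illustrative.
S, R, N = 4, 8, 100

# i) cache right after the new key at position N is appended:
initial = list(range(S)) + list(range(N - R + 1, N)) + [N]

# ii)/iii) the R-1 recent keys are re-rotated to positions S .. S+R-2,
# while the freshly appended key (position N) is left untouched:
after_update = list(range(S)) + list(range(S, S + R - 1)) + [N]

print(initial)                     # [0, 1, 2, 3, 93, 94, 95, 96, 97, 98, 99, 100]
print(after_update)                # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100]
print(len(after_update) == S + R)  # True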
Continue with the next token prediction:
i) The current query and key position is determined by the StreamingLLM KV cache size, which is S + R.
ii) StreamingLLM key cache position update:
a) Initial: [0, 1, …, S-1] + [S+1, …, S+R-2, N] + [S+R] (note: N is the position of the (S+R-1)-th element).
b) Rotate (all kept positions are shifted down by 1):
[0, 1, …, S-1] + [S, …, S+R-3, N-1] + [S+R] (the key at position S+R is not involved in the rotation; see SinkCache in https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py, around line 1038);
len([S, …, S+R-3, N-1]) = R-1.
Now we have S + R keys in the cache with positions [0, 1, …, S-1] + [S, …, S+R-3, N-1] + [S+R], while the current query position is S+R. For long-context inference, N-1 >> S+R, which means the query will interact with a key at a future position.
Note: all numbers in [ ] denote the positions used for token generation.
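Putting the two steps together, a toy simulation of the bookkeeping described above (this mimics the description, not the actual library code; S, R, N are illustrative) shows the query attending to a key whose RoPE position lies far in its "future":

# Toy simulation of the position bookkeeping described above (illustrative values).
S, R, N = 4, 8, 100

# Key positions right after the first generated token (see the previous sketch):
key_positions = list(range(S)) + list(range(S, S + R - 1)) + [N]

for step in range(3):
    query_position = len(key_positions)              # next token position = cache length = S + R
    sinks, rest = key_positions[:S], key_positions[S:]
    rest = [p - 1 for p in rest[1:]]                 # evict the oldest non-sink key, shift the rest down by 1
    key_positions = sinks + rest + [query_position]  # the new key is appended un-rotated
    print(f"step {step}: query position {query_position}, max key position {max(key_positions)}")

# step 0: query position 12, max key position 99
# step 1: query position 12, max key position 98
# step 2: query position 12, max key position 97
# The query at position S + R = 12 attends to keys whose positions are near N.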
Who can help?
@gante @ArthurZucker
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
Post-RoPE key cache:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py (lines 277-283)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
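Because the cached keys are stored post-RoPE, reassigning a key to a new position requires re-rotating it by the position delta. A minimal numerical sketch (not the library implementation) of why that is possible: applying the RoPE rotation for the delta to an already-rotated key is equivalent to rotating the original key to the new position.

import torch

def rotate_half(x):
    # Same pairing convention as the HF helper: dimension i is paired with i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, pos, inv_freq):
    angles = pos * inv_freq                                 # (d/2,)
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)   # (d,)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return x * cos + rotate_half(x) * sin

d = 8
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
key = torch.randn(d, dtype=torch.float64)

old_pos, new_pos = 99.0, 10.0                          # e.g. move a cached key from position N-1 back into the window
key_cached = apply_rope(key, old_pos, inv_freq)        # what the post-RoPE cache stores
key_rerotated = apply_rope(key_cached, new_pos - old_pos, inv_freq)
assert torch.allclose(key_rerotated, apply_rope(key, new_pos, inv_freq))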
Inference stage:
Next token position = key cache length:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L489 (lines 556-563)
if cache_position is None:
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )

if position_ids is None:
    position_ids = cache_position.unsqueeze(0)
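With a SinkCache the reported cache length stays at S + R once the window is full, so the logic above keeps assigning the next query position S + R while the cache still holds a key whose post-RoPE position is close to N. A toy sketch mirroring the snippet above with made-up sizes (not a real run):

import torch

S, R, N = 4, 8, 100
past_seen_tokens = S + R     # what past_key_values.get_seq_length() would report once the window is full
num_new_tokens = 1           # decoding one token, i.e. inputs_embeds.shape[1] == 1

cache_position = torch.arange(past_seen_tokens, past_seen_tokens + num_new_tokens)
position_ids = cache_position.unsqueeze(0)

print(position_ids)          # tensor([[12]]) -> the query is placed at position S + R
# ...while the cache still contains a key whose post-RoPE position is N - 1 = 99.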
Expected behavior
During inference, the query position should be greater than or equal to every key position.
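Stated as a check one could run over the toy position lists from the sketches above (illustrative helper, not library code):

def positions_are_causal(query_position, key_positions):
    # Every cached key position must not exceed the current query position.
    return all(k <= query_position for k in key_positions)

# With the toy values S=4, R=8, N=100, after the second generated token the cache
# holds a key at position N - 1 = 99 while the query sits at S + R = 12:
print(positions_are_causal(12, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 99, 12]))  # False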