SinkCache (StreamLLM) implemented over the post-RoPE key cache may result in confused positions during inference #35350

wangguangtao0722 opened this issue Dec 19, 2024
System Info

The current implementation of SinkCache may result in confused positions during attention computation.

  1. HF Key Cache: Post-RoPE

  2. After pre-filling over a sequence of length N
    i) Post-RoPE key cache of length N, with positions: [0, 1, 2, …, N-1]

  3. During generation of the first new token (the current query & key position is N, since the KV cache size is N), StreamLLM updates as follows (sink_size = S, recent_window = R):

    i) Initial: post-RoPE key cache: [0, 1, …, S-1] + [N - R + 1, …, N-1] + [N]; len([N - R + 1, …, N-1]) = R - 1
    ii) Rotate:
    HF applies a re-rotation to the R - 1 keys at positions (N - R + 1, …, N-1) so that their positions become (S, S + 1, …, S + R - 2), and keeps the result in the StreamLLM KV cache.
    iii) Updated StreamLLM key cache positions: [0, 1, …, S-1] + [S, …, S + R - 2] + [N], i.e. the last, (S+R)-th element keeps its actual position N, since len([S, …, S + R - 2]) = R - 1

  4. Continue with the next token prediction.
    i) The current query and key position is determined by the StreamLLM KV cache size, i.e. S + R.
    ii) StreamLLM key cache position update:
    a) Initial: [0, 1, …, S-1] + [S + 1, …, S + R - 2, N] + [S + R] (note: N is the position of the (S + R - 1)-th element)
    b) Rotate (all kept old keys are shifted back by 1):
    [0, 1, …, S-1] + [S, …, S + R - 3, N-1] + [S + R] (the key at position S + R is not involved in the rotation; see SinkCache in https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py, around line 1038)
    len([S, …, S + R - 3, N-1]) = R - 1

Now we are left with S + R cached keys at positions [0, 1, …, S-1] + [S, …, S + R - 3, N-1] + [S + R], while the current query position is S + R. For long-context inference, N - 1 >> S + R, so the query interacts with a key that carries a future position (a small simulation of this bookkeeping is sketched after the note below).

Note: all numbers in [] denote the positions used for token generation.
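
To make the walkthrough above concrete, here is a small, purely illustrative Python simulation of this position bookkeeping (it tracks integer positions only, not the real SinkCache tensors; S, R and N below are hypothetical values chosen so that N >> S + R):

    # Illustrative only: simulate the positions the cached keys are rotated to.
    S, R = 4, 8                          # hypothetical sink size and recent window
    N = 1000                             # hypothetical prompt length, N >> S + R

    key_pos = list(range(N))             # post-RoPE key positions after pre-filling

    for step in range(2):                # the two decoding steps from points 3 and 4
        query_pos = len(key_pos)         # next position = current cache length
        sinks, old = key_pos[:S], key_pos[S:]
        evicted = len(old) - (R - 1)     # number of non-sink keys that get dropped
        # Re-rotation shifts every kept old key back by the number of evicted keys;
        # the key appended in this step is not re-rotated.
        kept = [p - evicted for p in old[-(R - 1):]]
        key_pos = sinks + kept + [query_pos]
        print(f"step {step}: query position {query_pos}, key positions {key_pos}")

With these numbers the second step prints query position 12 (= S + R) and key positions [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 999, 12], matching point 4 above: one cached key carries position N - 1 = 999, far ahead of the query.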

Who can help?

@gante @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Post-RoPE key cache:

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py (lines 277-283)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
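
For reference, a minimal generation setup that exercises this path (a sketch only: the model id, prompt and sizes are placeholders, and it assumes the SinkCache(window_length=..., num_sink_tokens=...) constructor and the past_key_values argument of generate(), as in the transformers version this issue was filed against):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

    model_id = "meta-llama/Llama-2-7b-hf"   # placeholder; any Llama-style RoPE model
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )

    # Long prompt so that the pre-filled length N is much larger than the cache window.
    prompt = "hello " * 2000
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    past_key_values = SinkCache(window_length=8, num_sink_tokens=4)
    out = model.generate(
        **inputs, past_key_values=past_key_values, max_new_tokens=4, do_sample=False
    )

Stepping through the decode steps and comparing position_ids with the positions the kept keys were last rotated to shows the mismatch described in the System Info section.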

Inference stage:
next token position = Key cache length

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L489 (lines 556-563)
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
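
Putting the two pieces together (illustrative numbers only, continuing the hypothetical S = 4, R = 8, N = 1000 from above): because the cache reports its compressed length, the next query position stays at S + R while the key cached on the previous step was rotated with a near-N position:

    # Illustrative numbers only, reusing the hypothetical S, R, N from above.
    S, R, N = 4, 8, 1000

    past_seen_tokens = S + R           # length reported by the SinkCache after the first decode step
    query_position = past_seen_tokens  # cache_position = arange(past_seen_tokens, ...)

    stale_key_position = N             # the key cached on the previous step kept position N
    assert query_position < stale_key_position   # 12 < 1000: the query attends to a "future" key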

Expected behavior

During inference, the query position should be greater than or equal to every cached key position.
