#0: Apply fix for paged KV fill cache by unpadding page table
cglagovichTT committed Dec 12, 2024
1 parent 718fa3d commit d47b62c
Showing 2 changed files with 53 additions and 24 deletions.
11 changes: 9 additions & 2 deletions models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -474,14 +474,21 @@ def prefill_attn_mqa(
        values = self.layer_past[1]

        if page_table:
            # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values.
            # Assume that the page table does not have padding, so we can use it to get the unpadded page length.
            block_size = keys.shape[2]
            # If chunked prefill, use chunk_page_table if given, otherwise use page_table.
            fill_page_table = chunk_page_table if chunk_page_table is not None else page_table
            page_len = fill_page_table.shape[1] * block_size

            k_fill_sliced = key_layer[:, :, :page_len, :] if page_len < key_layer.shape[2] else key_layer
            v_fill_sliced = value_layer[:, :, :page_len, :] if page_len < value_layer.shape[2] else value_layer

            ttnn.experimental.paged_fill_cache(
-               keys, ttnn.experimental.typecast(key_layer, self.kv_dtype), fill_page_table, batch_idx=user_id
+               keys, ttnn.experimental.typecast(k_fill_sliced, self.kv_dtype), fill_page_table, batch_idx=user_id
            )
            ttnn.experimental.paged_fill_cache(
-               values, ttnn.experimental.typecast(value_layer, self.kv_dtype), fill_page_table, batch_idx=user_id
+               values, ttnn.experimental.typecast(v_fill_sliced, self.kv_dtype), fill_page_table, batch_idx=user_id
            )
        else:
            ttnn.fill_cache(keys, ttnn.experimental.typecast(key_layer, self.kv_dtype), user_id)
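Conceptually, the fix slices the padded K/V activations down to the page table's true length before filling the cache. Below is a minimal pure-PyTorch sketch of that indexing with hypothetical shapes; paged_fill_cache itself is a ttnn device op, so this reference loop only models what it writes:

    import torch

    # Hypothetical sizes, for illustration only.
    num_blocks, n_heads, block_size, head_dim = 8, 1, 32, 64
    cache = torch.zeros(num_blocks, n_heads, block_size, head_dim)

    # One user's page table: 3 allocated blocks -> 96 real tokens.
    page_table = torch.tensor([[4, 1, 6]], dtype=torch.int32)
    page_len = page_table.shape[1] * block_size  # 96

    # Prefill padded the keys out to 128 tokens; only the first 96 are real.
    key_layer = torch.randn(1, n_heads, 128, head_dim)

    # The fix: drop the padding so we never write tokens the page table has no block for.
    k_fill = key_layer[:, :, :page_len, :] if page_len < key_layer.shape[2] else key_layer

    # Reference fill: copy each block_size-sized slice of tokens into its physical block.
    for i, block in enumerate(page_table[0].tolist()):
        cache[block] = k_fill[0, :, i * block_size : (i + 1) * block_size, :]

Without the slice, the padded tail of key_layer would spill past the blocks the page table owns, which appears to be the failure mode this commit guards against.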
66 changes: 44 additions & 22 deletions models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -268,34 +268,38 @@ def prefill_forward_single_user(
            Chunked prefill requires paged attention. There are some strange constraints which we must meet:
            - page_table, which is used in SDPA, must match the batch size of the inputs, which is 1, because SDPA
              checks that the page table batch dim matches the input batch dim. We therefore slice the page table to the current user.
            - page_table must also have enough entries for each chunk, so it is padded with zeros if necessary.
            - chunk_page_table is the slice of the page table for the current chunk. It is passed to paged_fill_cache,
              which fills the correct blocks while remaining unaware that it is operating on a chunk.
            - because the page table is sliced down to a single user, user_id must always be 0 for chunked prefill.
            """
            # TODO: Ensure that last_token_idx is within the last chunk!
            assert page_table is not None, "page_table must be provided for chunked prefill"
            assert kv_cache is not None, "kv_cache must be provided for chunked prefill"
            chunk_size = get_max_prefill_chunk_size(seq_len, self.tt_model.model_config["MAX_PREFILL_SEQ_LEN"])
            block_size = get_block_size(kv_cache)
-           last_token_idx_in_chunk = last_token_idx % chunk_size
+           last_token_idx_in_chunk = last_token_idx % chunk_size if last_token_idx is not None else None
            last_chunk_start = (last_token_idx // chunk_size) * chunk_size if last_token_idx is not None else None
            page_table_user = page_table[user_id : user_id + 1, :]
            # Pad the page table to match the number of blocks in seq_len
            num_padding_blocks = num_blocks_in_seq(seq_len, block_size) - page_table_user.shape[1]
            page_table_user_padded = torch.cat(
                [page_table_user, torch.zeros(1, num_padding_blocks, dtype=torch.int32)], dim=-1
            )
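            # Worked example (hypothetical numbers): seq_len padded to 96 with block_size=16
            # needs num_blocks_in_seq(96, 16) = 6 blocks; a user whose table holds 4 allocated
            # blocks gets 2 zero-blocks appended here so SDPA sees a fully sized table.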
            CHUNK_USER_ID = 0
            logger.info(f"Using chunked prefill with chunk_size={chunk_size}")
            logger.info(f"Page table shape: {page_table.shape}")
            logger.info(f"Page table user shape: {page_table_user.shape}")
            logger.info(f"Block size: {block_size}")
            logger.info(f"Last token idx in chunk: {last_token_idx_in_chunk}")
            logger.info(f"Sequence length: {seq_len}")

            logits_list = []
            for chunk_start in range(0, seq_len, chunk_size):
                chunk_end = chunk_start + chunk_size
                assert (
                    chunk_end <= seq_len
                ), f"Chunk end should be at most seq_len, got chunk_end={chunk_end} and seq_len={seq_len}"
                chunk_tokens = tokens[:, chunk_start:chunk_end]
                chunk_page_table = page_table_user[:, chunk_start // block_size : chunk_end // block_size]
                logger.info(f"Chunk start: {chunk_start}")
                logger.info(f"Chunk end: {chunk_end}")
                logger.info(f"Chunk tokens shape: {chunk_tokens.shape}")
                logger.info(f"Chunk page table shape: {chunk_page_table.shape}")

                (
                    tt_inp_emb,
                    start_pos,
@@ -305,14 +305,15 @@ def prefill_forward_single_user(
                    page_table_tt,
                    chunk_page_table_tt,
                ) = self.tt_model.prepare_inputs(
-                   tokens,
+                   chunk_tokens,
                    start_pos=chunk_start,
                    mode="prefill",
-                   page_table=page_table_user,
+                   page_table=page_table_user_padded,
                    chunk_page_table=chunk_page_table,
                )
-               tt_logits = self.tt_model.prefill_forward_single_user(
-                   chunk_tokens,
+               tt_logits = self.tt_model(
                    tt_inp_emb,
                    rot_mat,
                    start_pos,
                    user_id=CHUNK_USER_ID,
                    last_token_idx=last_token_idx_in_chunk,
@@ -323,6 +328,22 @@
                    chunk_start_idx=chunk_start,
                )
                logger.info(f"TT logits shape: {tt_logits.shape}")

                logits = self._process_logits(tt_logits)
                logits = logits.squeeze(1)
                ttnn.deallocate(tt_logits)

                if last_token_idx is not None:
                    # If this was the chunk containing last_token_idx, we are done
                    if chunk_start == last_chunk_start:
                        return logits
                else:
                    logits_list.append(logits)

            # Concatenate the per-chunk logits along the sequence dimension
            logits = torch.cat(logits_list, dim=-2)
            return logits

        else:
            (
                tt_inp_emb,
@@ -375,13 +396,7 @@ def prefill_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None,
                [tokens[user_id : user_id + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1
            )
            if page_table is not None:
-               block_size = get_block_size(kv_cache)
-               num_padding_blocks = num_blocks_in_seq(prefill_seq_len, block_size) - num_blocks_in_seq(
-                   seq_len, block_size
-               )
-               page_table_user = torch.cat(
-                   [page_table, torch.zeros(batch, num_padding_blocks, dtype=torch.int32)], dim=-1
-               )
+               page_table_user = _get_prefill_user_page_table(page_table, kv_cache, seq_len)

            logger.info(f"Filling kv cache for user {user_id + 1}")

@@ -408,6 +423,13 @@ def _process_logits(self, tt_logits):
        return logits[..., : self.params.vocab_size].float()


def _get_prefill_user_page_table(page_table, kv_cache, prefill_len):
    # Trim padding blocks from the page table so that paged_fill_cache fills only the
    # blocks that actually belong to this sequence.
    block_size = get_block_size(kv_cache)
    num_blocks = num_blocks_in_seq(prefill_len, block_size)
    return page_table[:, :num_blocks]
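For illustration, a standalone sketch of the trimming this helper performs, with hypothetical values (it assumes num_blocks_in_seq is a ceiling division of sequence length by block size):

    import torch

    # A page table padded out to 6 blocks, while the real prefill spans only 3.
    page_table = torch.tensor([[4, 1, 6, 0, 0, 0]], dtype=torch.int32)
    block_size, prefill_len = 16, 48

    num_blocks = (prefill_len + block_size - 1) // block_size  # ceil(48 / 16) = 3
    trimmed = page_table[:, :num_blocks]  # keeps only the 3 blocks the sequence owns

Handing the trimmed table to paged_fill_cache keeps it from touching the zero-padded block IDs, which could otherwise alias a block owned by another sequence.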


def get_padded_prefill_len(seq_len):
    """
    If seq_len is less than 32, pad to 32
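To make the chunk bookkeeping in prefill_forward_single_user concrete, here is a standalone sketch of the arithmetic with hypothetical sizes:

    seq_len, chunk_size, block_size = 96, 32, 16
    last_token_idx = 70

    # Mirrors the chunked-prefill setup above.
    last_token_idx_in_chunk = last_token_idx % chunk_size           # 6
    last_chunk_start = (last_token_idx // chunk_size) * chunk_size  # 64

    for chunk_start in range(0, seq_len, chunk_size):  # 0, 32, 64
        chunk_end = chunk_start + chunk_size
        # Each chunk reads page-table columns [chunk_start // block_size, chunk_end // block_size).
        first_block, last_block = chunk_start // block_size, chunk_end // block_size
        if chunk_start == last_chunk_start:
            # The chunk containing last_token_idx is the last one that must run.
            print(f"stop at tokens {chunk_start}:{chunk_end}, blocks {first_block}:{last_block}")
            break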
