From d47b62c9c8a3fe58cfa0c0d6a784380dc1e7f0ba Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Thu, 12 Dec 2024 10:06:42 -0800
Subject: [PATCH] #0: Apply fix for paged KV fill cache by unpadding page table

---
 .../tt/llama_attention_optimized.py         | 11 +++-
 .../t3000/llama2_70b/tt/llama_generation.py | 66 ++++++++++++-------
 2 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
index ffa45245f73f..108a97661867 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -474,14 +474,21 @@ def prefill_attn_mqa(
         values = self.layer_past[1]
 
         if page_table:
+            # In the case that the tokens have been padded along the seq len dimension, we need to fill the cache with the unpadded k/v values.
+            # Assume that the page table does not have padding, so we can use it to get the unpadded page len.
+            block_size = keys.shape[2]
             # If chunked prefill, use chunk_page_table if given, otherwise use page_table.
             fill_page_table = chunk_page_table if chunk_page_table is not None else page_table
+            page_len = fill_page_table.shape[1] * block_size
+
+            k_fill_sliced = key_layer[:, :, :page_len, :] if page_len < key_layer.shape[2] else key_layer
+            v_fill_sliced = value_layer[:, :, :page_len, :] if page_len < value_layer.shape[2] else value_layer
 
             ttnn.experimental.paged_fill_cache(
-                keys, ttnn.experimental.typecast(key_layer, self.kv_dtype), fill_page_table, batch_idx=user_id
+                keys, ttnn.experimental.typecast(k_fill_sliced, self.kv_dtype), fill_page_table, batch_idx=user_id
             )
             ttnn.experimental.paged_fill_cache(
-                values, ttnn.experimental.typecast(value_layer, self.kv_dtype), fill_page_table, batch_idx=user_id
+                values, ttnn.experimental.typecast(v_fill_sliced, self.kv_dtype), fill_page_table, batch_idx=user_id
             )
         else:
             ttnn.fill_cache(keys, ttnn.experimental.typecast(key_layer, self.kv_dtype), user_id)
diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index f9b1321e8e4a..e76ce417dcf2 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -268,23 +268,30 @@ def prefill_forward_single_user(
             Chunked prefill requires paged attention. There are some strange constraints which we must meet:
             - page_table, which is used in SDPA, must match batch size of inputs, which is 1. This is because SDPA
             checks that page table batch dim matches input batch dim. Therefore we must slice the page table for the current user.
+            - page_table must also have enough entries in each chunk, so it will be padded with zeros if necessary.
             - chunked_page_table is the slice of the page table for the current chunk. This is used by paged_fill_cache
             to keep it otherwise unaware that it is operating on a chunk.
             - due to the above point, we must always set user_id to 0 for chunked prefill.
             """
+            # TODO: Ensure that last_token_idx is within the last chunk!
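+            # Worked example with hypothetical numbers: for seq_len=4096, chunk_size=2048, block_size=64,
+            # each chunk below spans 2048 / 64 = 32 page-table entries, and last_token_idx=3000 falls in
+            # the chunk starting at 2048, so prefill can return early once that chunk is processed.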
             assert page_table is not None, "page_table must be provided for chunked prefill"
+            # TODO: Uncomment assert
             assert kv_cache is not None, "kv_cache must be provided for chunked prefill"
             chunk_size = get_max_prefill_chunk_size(seq_len, self.tt_model.model_config["MAX_PREFILL_SEQ_LEN"])
             block_size = get_block_size(kv_cache)
-            last_token_idx_in_chunk = last_token_idx % chunk_size
+            last_token_idx_in_chunk = last_token_idx % chunk_size if last_token_idx is not None else None
+            last_chunk_start = (last_token_idx // chunk_size) * chunk_size if last_token_idx is not None else None
             page_table_user = page_table[user_id : user_id + 1, :]
+            # Pad page table to match number of blocks in seq_len
+            num_padding_blocks = num_blocks_in_seq(seq_len, block_size) - page_table_user.shape[1]
+            page_table_user_padded = torch.cat(
+                [page_table_user, torch.zeros(1, num_padding_blocks, dtype=torch.int32)], dim=-1
+            )
             CHUNK_USER_ID = 0
-            logger.info(f"Using chunked prefill with chunk_size={chunk_size}")
-            logger.info(f"Page table shape: {page_table.shape}")
-            logger.info(f"Page table user shape: {page_table_user.shape}")
-            logger.info(f"Block size: {block_size}")
-            logger.info(f"Last token idx in chunk: {last_token_idx_in_chunk}")
-            logger.info(f"Sequence length: {seq_len}")
+
+            # Collect logits from each chunk of the prefill
+            logits_list = []
+
             for chunk_start in range(0, seq_len, chunk_size):
                 chunk_end = chunk_start + chunk_size
                 assert (
@@ -292,10 +299,7 @@
                 ), f"Chunk end should be less than seq_len, got chunk_end={chunk_end} and seq_len={seq_len}"
                 chunk_tokens = tokens[:, chunk_start:chunk_end]
                 chunk_page_table = page_table_user[:, chunk_start // block_size : chunk_end // block_size]
-                logger.info(f"Chunk start: {chunk_start}")
-                logger.info(f"Chunk end: {chunk_end}")
-                logger.info(f"Chunk tokens shape: {chunk_tokens.shape}")
-                logger.info(f"Chunk page table shape: {chunk_page_table.shape}")
+
                 (
                     tt_inp_emb,
                     start_pos,
@@ -305,14 +309,15 @@
                     page_table_tt,
                     chunk_page_table_tt,
                 ) = self.tt_model.prepare_inputs(
-                    tokens,
+                    chunk_tokens,
                     start_pos=chunk_start,
                     mode="prefill",
-                    page_table=page_table_user,
+                    page_table=page_table_user_padded,
                     chunk_page_table=chunk_page_table,
                 )
-                tt_logits = self.tt_model.prefill_forward_single_user(
-                    chunk_tokens,
+                tt_logits = self.tt_model(
+                    tt_inp_emb,
+                    rot_mat,
                     start_pos,
                     user_id=CHUNK_USER_ID,
                     last_token_idx=last_token_idx_in_chunk,
@@ -323,6 +328,22 @@
                     chunk_start_idx=chunk_start,
                 )
                 logger.info(f"TT logits shape: {tt_logits.shape}")
+
+                logits = self._process_logits(tt_logits)
+                logits = logits.squeeze(1)
+                ttnn.deallocate(tt_logits)
+
+                if last_token_idx is not None:
+                    # If this was the chunk containing last_token_idx, we're done
+                    if chunk_start == last_chunk_start:
+                        return logits
+                else:
+                    logits_list.append(logits)
+
+            # Concatenate all logits
+            logits = torch.cat(logits_list, dim=-2)
+            return logits
+
         else:
             (
                 tt_inp_emb,
@@ -375,13 +396,7 @@ def prefill_forward(self, tokens: torch.Tensor, start_pos: int, page_table=None,
                 [tokens[user_id : user_id + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1
             )
             if page_table is not None:
-                block_size = get_block_size(kv_cache)
-                num_padding_blocks = num_blocks_in_seq(prefill_seq_len, block_size) - num_blocks_in_seq(
-                    seq_len, block_size
-                )
-                page_table_user = torch.cat(
-                    [page_table, torch.zeros(batch, num_padding_blocks, dtype=torch.int32)], dim=-1
-                )
+                page_table_user = _get_prefill_user_page_table(page_table, kv_cache, seq_len)
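+                # e.g. (hypothetical numbers, assuming num_blocks_in_seq rounds up): with block_size=64 and
+                # seq_len=500, the page table is trimmed to 8 block entries, so paged_fill_cache only writes
+                # blocks that hold real tokens.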
 
             logger.info(f"Filling kv cache for user {user_id + 1}")
 
@@ -408,6 +423,13 @@ def _process_logits(self, tt_logits):
         return logits[..., : self.params.vocab_size].float()
 
 
+def _get_prefill_user_page_table(page_table, kv_cache, prefill_len):
+    # Ensure page_table is not padded with extra blocks for paged_fill_cache to work properly
+    block_size = get_block_size(kv_cache)
+    num_blocks = num_blocks_in_seq(prefill_len, block_size)
+    return page_table[:, :num_blocks]
+
+
 def get_padded_prefill_len(seq_len):
     """
     If seq_len is less than 32, pad to 32
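
For reference, a minimal standalone sketch of the page-table trimming that _get_prefill_user_page_table
performs, assuming num_blocks_in_seq rounds up as in llama_generation.py; the helper name and the shapes
below are hypothetical, illustration only:

import torch

def num_blocks_in_seq(seq_len: int, block_size: int) -> int:
    # Round up: a partially filled block still needs a page-table entry.
    return (seq_len + block_size - 1) // block_size

def get_prefill_user_page_table_sketch(page_table: torch.Tensor, seq_len: int, block_size: int) -> torch.Tensor:
    # Keep only the blocks that hold real tokens; any extra padded blocks are dropped.
    return page_table[:, : num_blocks_in_seq(seq_len, block_size)]

# Hypothetical example: one user, a page table padded to 16 blocks, 500 real tokens, 64-token blocks.
page_table = torch.arange(16, dtype=torch.int32).unsqueeze(0)
print(get_prefill_user_page_table_sketch(page_table, seq_len=500, block_size=64).shape)  # torch.Size([1, 8])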