From d0f1c15d771d19ae286218f77d4eae9d3760d1f8 Mon Sep 17 00:00:00 2001 From: Colman Glagovich Date: Thu, 12 Dec 2024 10:46:04 -0800 Subject: [PATCH] #0: Remove TODOs --- models/demos/t3000/llama2_70b/tt/llama_generation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py index e76ce417dcf2..86beb3027df7 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_generation.py +++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py @@ -273,13 +273,12 @@ def prefill_forward_single_user( to keep it otherwise unaware that it is operating on a chunk. - due to the above point, we must always set user_id to 0 for chunked prefill. """ - # TODO: Ensure that last_token_idx is within the last chunk! assert page_table is not None, "page_table must be provided for chunked prefill" - # TODO: Uncomment assert assert kv_cache is not None, "kv_cache must be provided for chunked prefill" chunk_size = get_max_prefill_chunk_size(seq_len, self.tt_model.model_config["MAX_PREFILL_SEQ_LEN"]) block_size = get_block_size(kv_cache) last_token_idx_in_chunk = last_token_idx % chunk_size if last_token_idx is not None else None + # Calculate which chunk contains the last_token_idx last_chunk_start = (last_token_idx // chunk_size) * chunk_size if last_token_idx is not None else None page_table_user = page_table[user_id : user_id + 1, :] # Pad page table to match number of blocks in seq_len @@ -289,8 +288,6 @@ def prefill_forward_single_user( ) CHUNK_USER_ID = 0 - # Calculate which chunk contains the last_token_idx - logits_list = [] for chunk_start in range(0, seq_len, chunk_size): chunk_end = chunk_start + chunk_size