From 5f7620684b341cb664c54b0989dc5611bdbb7750 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Thu, 12 Dec 2024 10:39:13 -0800
Subject: [PATCH] #0: Address PR comments, minor test changes

---
 .../llama2_70b/tests/test_llama_attention.py  | 139 +++++++-----------
 .../tests/test_llama_device_perf.py           |   1 +
 .../llama2_70b/tests/test_llama_model.py      |  60 +++++---
 .../tt/llama_attention_optimized.py           |   1 +
 4 files changed, 94 insertions(+), 107 deletions(-)

diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
index 7b2222caf0a8..72bd9b7091f6 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py
@@ -327,105 +327,74 @@ def run_test_LlamaAttention_inference(
         attn_mask,
     )

-    if mode == "prefill":
+    if is_chunked_prefill:
+        assert mode == "prefill", "Chunked prefill should only be run in prefill mode"
         assert start_pos == 0, "Start pos should be 0 for chunked prefill"
         assert batch == 1, "Batch should be 1 for chunked prefill"
-        if is_chunked_prefill:
-            """
-            In chunked prefill mode, we need to split the prefill input into chunks.
-            Each chunk will be processed sequentially. Each chunk must be given the appropriate
-            sin/cos values. Also, each chunk must be given a partial page table for paged_fill_cache
-            so that paged_fill_cache fills the current chunk properly.
-            Be vary careful that we don't pick up cached sin/cos values since they will be incorrect.
-            """
-            for chunk_start in range(0, seq_len, chunk_size):
-                chunk_end = chunk_start + chunk_size
-                assert chunk_end <= seq_len, "Chunk end should be less than seq_len"
-                chunk_page_table = page_table[
-                    :,
-                    chunk_start
-                    // paged_attention_config.block_size : chunk_end
-                    // paged_attention_config.block_size,
-                ]
-                chunk_page_table_tt = ttnn.as_tensor(
-                    chunk_page_table,
-                    dtype=ttnn.int32,
-                    layout=ttnn.ROW_MAJOR_LAYOUT,
-                    device=t3k_mesh_device,
-                    memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                    mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
-                )
-                # SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill
-                prefill_page_table = page_table[0:1, :]
-                prefill_page_table_tt = ttnn.as_tensor(
-                    prefill_page_table,
-                    dtype=ttnn.int32,
-                    layout=ttnn.ROW_MAJOR_LAYOUT,
-                    device=t3k_mesh_device,
-                    memory_config=ttnn.DRAM_MEMORY_CONFIG,
-                    mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
-                )
-
-                chunk_tt_input = tt_input[:, chunk_start:chunk_end]
-                # TT hardware execution -------------------------------------------------------------
-                attention_input, _, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
-                    tt_LlamaAttention_model,
-                    chunk_tt_input,
-                    chunk_start,
-                    mode,
-                    configuration.rope_theta,
-                    rope_setup=None,
-                    use_scaled_rope=configuration.use_scaled_rope,
-                )
-                tt_chunk_out = tt_LlamaAttention_model(
-                    attention_input,
-                    rot_mat,
-                    None,
-                    cache_idxs=None,
-                    page_table=prefill_page_table_tt,
-                    mode=mode,
-                    chunk_page_table=chunk_page_table_tt,
-                    chunk_start_idx=chunk_start,
-                )
-
-                tt_chunk_out = ttnn.from_device(tt_chunk_out)
-                tt_chunk_out = ttnn.to_torch(tt_chunk_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
-                tt_chunk_out = tt_chunk_out.permute(2, 1, 0, 3).squeeze(1)  # [batch, seq_len, hidden_dim]
-
-                # check outputs ----------------------------------------------------------------------
-                pytorch_chunk_out = pytorch_out[:, chunk_start:chunk_end]
-                does_pass, output_pcc = comp_pcc(pytorch_chunk_out, tt_chunk_out, pcc)
-                logger.info(f"Chunk {chunk_start} output: {output_pcc}")
-                all_pccs.append(extract_pcc_from_log(output_pcc))
-
-        else:
+
+        """
+        In chunked prefill mode, we need to split the prefill input into chunks.
+        Each chunk will be processed sequentially. Each chunk must be given the appropriate
+        sin/cos values. Also, each chunk must be given a partial page table for paged_fill_cache
+        so that paged_fill_cache fills the current chunk properly.
+        Be very careful that we don't pick up cached sin/cos values since they will be incorrect.
+        """
+        for chunk_start in range(0, seq_len, chunk_size):
+            chunk_end = chunk_start + chunk_size
+            assert chunk_end <= seq_len, "Chunk end should be less than seq_len"
+            chunk_page_table = page_table[
+                :,
+                chunk_start // paged_attention_config.block_size : chunk_end // paged_attention_config.block_size,
+            ]
+            chunk_page_table_tt = ttnn.as_tensor(
+                chunk_page_table,
+                dtype=ttnn.int32,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                device=t3k_mesh_device,
+                memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
+            )
+            # SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill
+            prefill_page_table = page_table[0:1, :]
+            prefill_page_table_tt = ttnn.as_tensor(
+                prefill_page_table,
+                dtype=ttnn.int32,
+                layout=ttnn.ROW_MAJOR_LAYOUT,
+                device=t3k_mesh_device,
+                memory_config=ttnn.DRAM_MEMORY_CONFIG,
+                mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
+            )
+
+            chunk_tt_input = tt_input[:, chunk_start:chunk_end]
             # TT hardware execution -------------------------------------------------------------
-            attention_input, start_pos, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
+            attention_input, _, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
                 tt_LlamaAttention_model,
-                tt_input,
-                start_pos,
+                chunk_tt_input,
+                chunk_start,
                 mode,
                 configuration.rope_theta,
                 rope_setup=None,
                 use_scaled_rope=configuration.use_scaled_rope,
             )
-
-            tt_out = tt_LlamaAttention_model(
+            tt_chunk_out = tt_LlamaAttention_model(
                 attention_input,
                 rot_mat,
-                start_pos,
+                None,
                 cache_idxs=None,
-                page_table=page_table_tt,
+                page_table=prefill_page_table_tt,
                 mode=mode,
+                chunk_page_table=chunk_page_table_tt,
+                chunk_start_idx=chunk_start,
             )
-            tt_out = ttnn.from_device(tt_out)
-            tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
-            tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1)  # [batch, seq_len, hidden_dim]
+            tt_chunk_out = ttnn.from_device(tt_chunk_out)
+            tt_chunk_out = ttnn.to_torch(tt_chunk_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
+            tt_chunk_out = tt_chunk_out.permute(2, 1, 0, 3).squeeze(1)  # [batch, seq_len, hidden_dim]

             # check outputs ----------------------------------------------------------------------
-            does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc)
-            logger.info(f"Output: {output_pcc}")
+            pytorch_chunk_out = pytorch_out[:, chunk_start:chunk_end]
+            does_pass, output_pcc = comp_pcc(pytorch_chunk_out, tt_chunk_out, pcc)
+            logger.info(f"Chunk {chunk_start} output: {output_pcc}")
             all_pccs.append(extract_pcc_from_log(output_pcc))

     else:
@@ -436,7 +405,7 @@ def run_test_LlamaAttention_inference(
             start_pos,
             mode,
             configuration.rope_theta,
-            rope_setup=rope_setup,
+            rope_setup=rope_setup if mode == "decode" else None,
             use_scaled_rope=configuration.use_scaled_rope,
         )

@@ -453,7 +422,8 @@ def run_test_LlamaAttention_inference(
         tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
         tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1)  # [batch, seq_len, hidden_dim]
-        tt_out = tt_out[:batch]
+        if mode == "decode":
+            tt_out = tt_out[:batch]

         # check outputs ----------------------------------------------------------------------
         does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc)
@@ -485,7 +455,6 @@ def run_test_LlamaAttention_inference(
     # concat the pasts by heads
     tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past]
     if paged_attention:
-        tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past]
         tt_layer_present_all = [
             (
                 ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[reverse_permutation]
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py
index e53711014339..a2eee1391703 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_device_perf.py
@@ -76,6 +76,7 @@ def test_run_device_perf_llama(
         t3k_mesh_device,
         batch,
         seq_len,
+        max_batch_size,
         max_context_len,
         N_LAYERS_TO_PCC[n_layers],
         model_config,
diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_model.py b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
index a16032d4327d..ef41fbe6d892 100644
--- a/models/demos/t3000/llama2_70b/tests/test_llama_model.py
+++ b/models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -298,34 +298,50 @@ def run_test_LlamaModel_inference(
     logger.info(f"Average Top-5 over {len(all_top5)} tokens: {sum(all_top5) / len(all_top5)}")
     # Check kv cache
     # PyTorch output --------------------------------------------------------------------
-    if chunk_size is None:
-        pytorch_layer_present = [
-            pytorch_model.model.layers[0]
-            .attention.cache_k.clone()
-            .permute(0, 2, 1, 3)[:batch, ...],  # [batch, n_kv_heads, seq, head_dim]
-            pytorch_model.model.layers[0]
-            .attention.cache_v.clone()
-            .permute(0, 2, 1, 3)[:batch, ...],  # [batch, n_kv_heads, seq, head_dim]
+    tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_model.layers[0].attention.layer_past]
+    if paged_attention:
+        tt_layer_present_all = [
+            (
+                ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[reverse_permutation]
+                .reshape(
+                    max_batch_size,
+                    paged_attention_config.max_num_blocks // max_batch_size,
+                    configuration.n_kv_heads,
+                    paged_attention_config.block_size,
+                    tt_model.head_dim,
+                )
+                .transpose(1, 2)
+                .reshape(max_batch_size, configuration.n_kv_heads, -1, tt_model.head_dim)[:batch, ...]
+            )
+            for lp in tt_layer_present_all
         ]
-
-        tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_model.layers[0].attention.layer_past]
+    else:
         tt_layer_present_all = [
             ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[:batch, ...]
             for lp in tt_layer_present_all
         ]

-        cache_test_pass = check_kv_cache(
-            pytorch_layer_present,
-            tt_layer_present_all,
-            generation_start_pos,
-            generation_length,
-            seq_len,
-            mode == "prefill",
-            pcc,
-        )
-        all_tests_pass = all_tests_pass and cache_test_pass
-        if all_tests_pass:
-            logger.info(f"{llama_version} output Passed!")
+    pytorch_layer_present = [
+        pytorch_model.model.layers[0]
+        .attention.cache_k.clone()
+        .permute(0, 2, 1, 3)[:batch, ...],  # [batch, n_kv_heads, seq, head_dim]
+        pytorch_model.model.layers[0]
+        .attention.cache_v.clone()
+        .permute(0, 2, 1, 3)[:batch, ...],  # [batch, n_kv_heads, seq, head_dim]
+    ]
+
+    cache_test_pass = check_kv_cache(
+        pytorch_layer_present,
+        tt_layer_present_all,
+        generation_start_pos,
+        generation_length,
+        seq_len,
+        mode == "prefill",
+        pcc,
+    )
+    all_tests_pass = all_tests_pass and cache_test_pass
+    if all_tests_pass:
+        logger.info(f"{llama_version} output Passed!")

     assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
index 108a97661867..547ed7e64abf 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py
@@ -525,6 +525,7 @@ def prefill_attn_mqa(
             is_causal=True,
             scale=self.scale,
             program_config=pc_sdpa,
+            compute_kernel_config=self.model_config["SDPA_COMPUTE_KERNEL_CONFIG"],
         )

         # deallocate keys and values
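Illustrative sketch (not part of the patch): the chunked-prefill flow that the reworked attention test exercises can be summarized in plain PyTorch. Here run_attention_chunk is a hypothetical stand-in for the device-side attention call, and the shapes follow the test's assumptions (batch of 1, chunk_size a multiple of the paged-attention block size).

import torch

def chunked_prefill(tokens, page_table, chunk_size, block_size, run_attention_chunk):
    """Process a long prompt chunk by chunk, as in the chunked-prefill test path.

    tokens:     [1, seq_len] -- chunked prefill assumes batch == 1
    page_table: [1, num_blocks] mapping logical KV-cache blocks to physical blocks
    run_attention_chunk: hypothetical stand-in for the device attention call
    """
    seq_len = tokens.shape[1]
    assert seq_len % chunk_size == 0 and chunk_size % block_size == 0
    outputs = []
    for chunk_start in range(0, seq_len, chunk_size):
        chunk_end = chunk_start + chunk_size
        # Only the page-table entries covering [chunk_start, chunk_end) are needed
        # for paged_fill_cache to fill this chunk of the KV cache.
        chunk_page_table = page_table[:, chunk_start // block_size : chunk_end // block_size]
        # chunk_start is passed through so rotary sin/cos are generated for absolute
        # positions rather than positions relative to the chunk.
        outputs.append(
            run_attention_chunk(
                tokens[:, chunk_start:chunk_end],
                page_table=page_table,  # full table; its batch dim must be 1 in prefill
                chunk_page_table=chunk_page_table,
                chunk_start_idx=chunk_start,
            )
        )
    return torch.cat(outputs, dim=1)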
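Illustrative sketch (not part of the patch): the KV-cache check added to test_llama_model.py reads the paged cache back as a flat stack of blocks and folds it into a contiguous [batch, n_kv_heads, seq, head_dim] tensor before comparing against the reference cache. The round trip below uses made-up sizes and a stand-in permutation; the direction of the permutation bookkeeping in the real test may differ.

import torch

max_batch_size, n_kv_heads, head_dim = 4, 2, 8
block_size, blocks_per_user = 16, 3
max_num_blocks = max_batch_size * blocks_per_user
seq_len = blocks_per_user * block_size

# Reference contiguous cache: [batch, n_kv_heads, seq, head_dim]
contiguous = torch.randn(max_batch_size, n_kv_heads, seq_len, head_dim)

# Scatter it into a paged layout: [max_num_blocks, n_kv_heads, block_size, head_dim]
blocks = (
    contiguous.reshape(max_batch_size, n_kv_heads, blocks_per_user, block_size, head_dim)
    .permute(0, 2, 1, 3, 4)  # [batch, blocks_per_user, heads, block_size, head_dim]
    .reshape(max_num_blocks, n_kv_heads, block_size, head_dim)
)
permutation = torch.randperm(max_num_blocks)      # shuffled physical block order
reverse_permutation = torch.argsort(permutation)  # undoes the shuffle
paged_cache = blocks[permutation]

# Reconstruction mirroring the test: unshuffle the blocks, group them per user,
# swap the block and head dims, then merge the blocks back into the sequence dim.
rebuilt = (
    paged_cache[reverse_permutation]
    .reshape(max_batch_size, blocks_per_user, n_kv_heads, block_size, head_dim)
    .transpose(1, 2)
    .reshape(max_batch_size, n_kv_heads, seq_len, head_dim)
)
assert torch.equal(rebuilt, contiguous)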