
Commit

#0: Address PR comments, minor test changes
cglagovichTT committed Dec 12, 2024
1 parent bcda94a commit 5f76206
Showing 4 changed files with 94 additions and 107 deletions.
139 changes: 54 additions & 85 deletions models/demos/t3000/llama2_70b/tests/test_llama_attention.py
@@ -327,105 +327,74 @@ def run_test_LlamaAttention_inference(
attn_mask,
)

if mode == "prefill":
if is_chunked_prefill:
assert mode == "prefill", "Chunked prefill should only be run in prefill mode"
assert start_pos == 0, "Start pos should be 0 for chunked prefill"
assert batch == 1, "Batch should be 1 for chunked prefill"
if is_chunked_prefill:
"""
In chunked prefill mode, we need to split the prefill input into chunks.
Each chunk will be processed sequentially. Each chunk must be given the appropriate
sin/cos values. Also, each chunk must be given a partial page table for paged_fill_cache
so that paged_fill_cache fills the current chunk properly.
Be very careful that we don't pick up cached sin/cos values, since they will be incorrect.
"""
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = chunk_start + chunk_size
assert chunk_end <= seq_len, "Chunk end should not exceed seq_len"
chunk_page_table = page_table[
:,
chunk_start
// paged_attention_config.block_size : chunk_end
// paged_attention_config.block_size,
]
chunk_page_table_tt = ttnn.as_tensor(
chunk_page_table,
dtype=ttnn.int32,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=t3k_mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
)
# SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill
prefill_page_table = page_table[0:1, :]
prefill_page_table_tt = ttnn.as_tensor(
prefill_page_table,
dtype=ttnn.int32,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=t3k_mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
)

chunk_tt_input = tt_input[:, chunk_start:chunk_end]
# TT hardware execution -------------------------------------------------------------
attention_input, _, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
tt_LlamaAttention_model,
chunk_tt_input,
chunk_start,
mode,
configuration.rope_theta,
rope_setup=None,
use_scaled_rope=configuration.use_scaled_rope,
)
tt_chunk_out = tt_LlamaAttention_model(
attention_input,
rot_mat,
None,
cache_idxs=None,
page_table=prefill_page_table_tt,
mode=mode,
chunk_page_table=chunk_page_table_tt,
chunk_start_idx=chunk_start,
)

tt_chunk_out = ttnn.from_device(tt_chunk_out)
tt_chunk_out = ttnn.to_torch(tt_chunk_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
tt_chunk_out = tt_chunk_out.permute(2, 1, 0, 3).squeeze(1) # [batch, seq_len, hidden_dim]

# check outputs ----------------------------------------------------------------------
pytorch_chunk_out = pytorch_out[:, chunk_start:chunk_end]
does_pass, output_pcc = comp_pcc(pytorch_chunk_out, tt_chunk_out, pcc)
logger.info(f"Chunk {chunk_start} output: {output_pcc}")
all_pccs.append(extract_pcc_from_log(output_pcc))

else:

"""
In chunked prefill mode, we need to split the prefill input into chunks.
Each chunk will be processed sequentially. Each chunk must be given the appropriate
sin/cos values. Also, each chunk must be given a partial page table for paged_fill_cache
so that paged_fill_cache fills the current chunk properly.
Be very careful that we don't pick up cached sin/cos values, since they will be incorrect.
"""
for chunk_start in range(0, seq_len, chunk_size):
chunk_end = chunk_start + chunk_size
assert chunk_end <= seq_len, "Chunk end should not exceed seq_len"
chunk_page_table = page_table[
:,
chunk_start // paged_attention_config.block_size : chunk_end // paged_attention_config.block_size,
]
chunk_page_table_tt = ttnn.as_tensor(
chunk_page_table,
dtype=ttnn.int32,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=t3k_mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
)
# SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill
prefill_page_table = page_table[0:1, :]
prefill_page_table_tt = ttnn.as_tensor(
prefill_page_table,
dtype=ttnn.int32,
layout=ttnn.ROW_MAJOR_LAYOUT,
device=t3k_mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device),
)

chunk_tt_input = tt_input[:, chunk_start:chunk_end]
# TT hardware execution -------------------------------------------------------------
attention_input, start_pos, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
attention_input, _, rot_mat, cache_idxs = tt_llama_attention_prepare_inputs(
tt_LlamaAttention_model,
tt_input,
start_pos,
chunk_tt_input,
chunk_start,
mode,
configuration.rope_theta,
rope_setup=None,
use_scaled_rope=configuration.use_scaled_rope,
)

tt_out = tt_LlamaAttention_model(
tt_chunk_out = tt_LlamaAttention_model(
attention_input,
rot_mat,
start_pos,
None,
cache_idxs=None,
page_table=page_table_tt,
page_table=prefill_page_table_tt,
mode=mode,
chunk_page_table=chunk_page_table_tt,
chunk_start_idx=chunk_start,
)

tt_out = ttnn.from_device(tt_out)
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [batch, seq_len, hidden_dim]
tt_chunk_out = ttnn.from_device(tt_chunk_out)
tt_chunk_out = ttnn.to_torch(tt_chunk_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
tt_chunk_out = tt_chunk_out.permute(2, 1, 0, 3).squeeze(1) # [batch, seq_len, hidden_dim]

# check outputs ----------------------------------------------------------------------
does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc)
logger.info(f"Output: {output_pcc}")
pytorch_chunk_out = pytorch_out[:, chunk_start:chunk_end]
does_pass, output_pcc = comp_pcc(pytorch_chunk_out, tt_chunk_out, pcc)
logger.info(f"Chunk {chunk_start} output: {output_pcc}")
all_pccs.append(extract_pcc_from_log(output_pcc))

else:
@@ -436,7 +405,7 @@ def run_test_LlamaAttention_inference(
start_pos,
mode,
configuration.rope_theta,
rope_setup=rope_setup,
rope_setup=rope_setup if mode == "decode" else None,
use_scaled_rope=configuration.use_scaled_rope,
)

@@ -453,7 +422,8 @@ def run_test_LlamaAttention_inference(
tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=3))
tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [batch, seq_len, hidden_dim]

tt_out = tt_out[:batch]
if mode == "decode":
tt_out = tt_out[:batch]

# check outputs ----------------------------------------------------------------------
does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc)
@@ -485,7 +455,6 @@ def run_test_LlamaAttention_inference(
# concat the pasts by heads
tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past]
if paged_attention:
tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past]
tt_layer_present_all = [
(
ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[reverse_permutation]
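As background for the chunked-prefill slicing exercised above: each chunk is given only the slice of the page table that covers its own blocks, so paged_fill_cache writes the chunk into the correct physical blocks. A minimal torch-only sketch of that slicing (the make_chunk_page_tables helper is illustrative, not part of the test; it assumes chunk_size is a multiple of block_size):

import torch

def make_chunk_page_tables(page_table, seq_len, chunk_size, block_size):
    # page_table: [batch, max_num_blocks_per_seq], logical block -> physical block.
    assert chunk_size % block_size == 0, "chunk_size must be a multiple of block_size"
    assert seq_len % chunk_size == 0, "seq_len must be a multiple of chunk_size"
    for chunk_start in range(0, seq_len, chunk_size):
        chunk_end = chunk_start + chunk_size
        # Keep only the blocks this chunk writes.
        yield chunk_start, page_table[:, chunk_start // block_size : chunk_end // block_size]

# Example: 2048-token prefill in 512-token chunks over 64-token blocks.
page_table = torch.randperm(32, dtype=torch.int32).unsqueeze(0)  # [1, 32]
for start, chunk_pt in make_chunk_page_tables(page_table, seq_len=2048, chunk_size=512, block_size=64):
    print(start, chunk_pt.shape)  # each chunk spans 512 // 64 = 8 blocks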
@@ -76,6 +76,7 @@ def test_run_device_perf_llama(
t3k_mesh_device,
batch,
seq_len,
max_batch_size,
max_context_len,
N_LAYERS_TO_PCC[n_layers],
model_config,
60 changes: 38 additions & 22 deletions models/demos/t3000/llama2_70b/tests/test_llama_model.py
@@ -298,34 +298,50 @@ def run_test_LlamaModel_inference(
logger.info(f"Average Top-5 over {len(all_top5)} tokens: {sum(all_top5) / len(all_top5)}")
# Check kv cache
# PyTorch output --------------------------------------------------------------------
if chunk_size is None:
pytorch_layer_present = [
pytorch_model.model.layers[0]
.attention.cache_k.clone()
.permute(0, 2, 1, 3)[:batch, ...], # [batch, n_kv_heads, seq, head_dim]
pytorch_model.model.layers[0]
.attention.cache_v.clone()
.permute(0, 2, 1, 3)[:batch, ...], # [batch, n_kv_heads, seq, head_dim]
tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_model.layers[0].attention.layer_past]
if paged_attention:
tt_layer_present_all = [
(
ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[reverse_permutation]
.reshape(
max_batch_size,
paged_attention_config.max_num_blocks // max_batch_size,
configuration.n_kv_heads,
paged_attention_config.block_size,
tt_model.head_dim,
)
.transpose(1, 2)
.reshape(max_batch_size, configuration.n_kv_heads, -1, tt_model.head_dim)[:batch, ...]
)
for lp in tt_layer_present_all
]

tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_model.layers[0].attention.layer_past]
else:
tt_layer_present_all = [
ttnn.to_torch(lp, mesh_composer=ConcatMeshToTensor(t3k_mesh_device, dim=1))[:batch, ...]
for lp in tt_layer_present_all
]

cache_test_pass = check_kv_cache(
pytorch_layer_present,
tt_layer_present_all,
generation_start_pos,
generation_length,
seq_len,
mode == "prefill",
pcc,
)
all_tests_pass = all_tests_pass and cache_test_pass
if all_tests_pass:
logger.info(f"{llama_version} output Passed!")
pytorch_layer_present = [
pytorch_model.model.layers[0]
.attention.cache_k.clone()
.permute(0, 2, 1, 3)[:batch, ...], # [batch, n_kv_heads, seq, head_dim]
pytorch_model.model.layers[0]
.attention.cache_v.clone()
.permute(0, 2, 1, 3)[:batch, ...], # [batch, n_kv_heads, seq, head_dim]
]

cache_test_pass = check_kv_cache(
pytorch_layer_present,
tt_layer_present_all,
generation_start_pos,
generation_length,
seq_len,
mode == "prefill",
pcc,
)
all_tests_pass = all_tests_pass and cache_test_pass
if all_tests_pass:
logger.info(f"{llama_version} output Passed!")

assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"

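The paged-cache check above gathers the shuffled physical blocks back into logical order and flattens them into a contiguous [batch, n_kv_heads, seq, head_dim] cache. A torch-only sketch of that reconstruction, with an illustrative helper name and toy shapes (not the test's API):

import torch

def unshuffle_paged_cache(paged_cache, reverse_permutation, max_batch_size, n_kv_heads, block_size, head_dim):
    # paged_cache: [max_num_blocks, n_kv_heads, block_size, head_dim] in physical (shuffled) block order.
    # reverse_permutation maps logical block positions to physical rows.
    max_num_blocks = paged_cache.shape[0]
    blocks_per_seq = max_num_blocks // max_batch_size
    return (
        paged_cache[reverse_permutation]  # restore logical block order
        .reshape(max_batch_size, blocks_per_seq, n_kv_heads, block_size, head_dim)
        .transpose(1, 2)  # [batch, n_kv_heads, blocks, block_size, head_dim]
        .reshape(max_batch_size, n_kv_heads, -1, head_dim)  # concatenate blocks along seq
    )

# Toy example: 2 users, 4 blocks each, 1 KV head, block_size 2, head_dim 4.
paged = torch.arange(8 * 1 * 2 * 4, dtype=torch.float32).reshape(8, 1, 2, 4)
rev = torch.randperm(8)
print(unshuffle_paged_cache(paged, rev, 2, 1, 2, 4).shape)  # torch.Size([2, 1, 8, 4])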
@@ -525,6 +525,7 @@ def prefill_attn_mqa(
is_causal=True,
scale=self.scale,
program_config=pc_sdpa,
compute_kernel_config=self.model_config["SDPA_COMPUTE_KERNEL_CONFIG"],
)

# deallocate keys and values
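More generally, the chunk_start_idx plumbing in these tests reflects how chunked prefill attends: queries of the current chunk sit at global positions chunk_start_idx .. chunk_start_idx + chunk - 1 and may look at every cached key up to their own position. A torch-only sketch of that masking (purely illustrative, not the ttnn kernel's implementation):

import torch
import torch.nn.functional as F

def chunked_prefill_attention(q_chunk, k_cache, v_cache, chunk_start_idx):
    # q_chunk: [batch, heads, chunk, head_dim]; caches already hold the earlier chunks.
    chunk = q_chunk.shape[2]
    kv_len = chunk_start_idx + chunk
    k, v = k_cache[:, :, :kv_len], v_cache[:, :, :kv_len]
    q_pos = torch.arange(chunk_start_idx, kv_len).unsqueeze(1)  # [chunk, 1] global query positions
    k_pos = torch.arange(kv_len).unsqueeze(0)                   # [1, kv_len]
    mask = k_pos <= q_pos                                        # True where a query may attend
    return F.scaled_dot_product_attention(q_chunk, k, v, attn_mask=mask)

# Example: 128 cached tokens plus a 128-token chunk.
q = torch.randn(1, 8, 128, 64)
k_cache, v_cache = torch.randn(1, 8, 256, 64), torch.randn(1, 8, 256, 64)
print(chunked_prefill_attention(q, k_cache, v_cache, chunk_start_idx=128).shape)  # torch.Size([1, 8, 128, 64])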
