update testcase, add WMAPE and MSE metrics
tsu-bin committed Nov 1, 2024
1 parent d8f10ee commit 89edcda
Showing 1 changed file with 58 additions and 23 deletions.
81 changes: 58 additions & 23 deletions tests/test_mla_decode_kernel.py
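
For context, here is a minimal sketch (not part of the commit) of how the two newly added metrics behave on toy tensors; the names ref and approx are illustrative only:

import torch
import torch.nn.functional as F

ref = torch.tensor([1.0, -2.0, 3.0])
approx = torch.tensor([1.1, -1.9, 2.9])
wmape = (approx - ref).abs().sum() / ref.abs().sum()  # 0.3 / 6.0 = 0.05
mse = F.mse_loss(approx, ref)                         # mean of squared errors ≈ 0.01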
@@ -5,6 +5,12 @@
import torch.nn.functional as F
import numpy as np
import flashinfer

def wmape(target: torch.Tensor, preds: torch.Tensor):
    """Weighted mean absolute percentage error: sum(|preds - target|) / sum(|target|)."""
    sum_abs_error = (preds - target).abs().sum().detach().item()
    sum_scale = target.abs().sum().detach().item()
    return sum_abs_error / sum_scale

from rope_reference import *

class DeepseekV2RMSNorm(nn.Module):
@@ -189,7 +195,6 @@ def __init__(self, mla_vanilla: DeepseekV2AttentionVanilla):
self.softmax_scale = mla_vanilla.softmax_scale

self.rope_theta = mla_vanilla.rope_theta
# self.rotary_emb = mla_vanilla.rotary_emb

# W^DQ ~ [5120, 1536]
self.W_DQ = mla_vanilla.q_a_proj.weight.transpose(0,1)
@@ -222,7 +227,8 @@ def run_proof_of_concept(
hidden_states: torch.Tensor,
compressed_kv_normed_cache: torch.Tensor,
k_pe_cache: torch.Tensor,
use_flashinfer_kernel: bool
use_flashinfer_kernel: bool,
convert_float16: bool
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

c_Q = torch.matmul(hidden_states, self.W_DQ)
@@ -239,6 +245,13 @@
# q_nope ~ [bsz, 128, 512]
q_nope = q_nope.reshape(bsz, self.num_heads, self.kv_lora_rank)

q_kv_dtype = torch.float16
if convert_float16:
q_nope = q_nope.to(q_kv_dtype)
q_pe = q_pe.to(q_kv_dtype)
compressed_kv_normed_cache = compressed_kv_normed_cache.to(q_kv_dtype)
k_pe_cache = k_pe_cache.to(q_kv_dtype)
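# Casting q_nope/q_pe and both caches to float16 here (when convert_float16 is set)
# lets the pure-PyTorch path run at the same precision as the FlashInfer kernel,
# so the f16-vs-f16 comparison isolates kernel error from dtype rounding.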

if not use_flashinfer_kernel:
freqs_cis = precompute_freqs_cis(
self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
@@ -275,11 +288,6 @@ def run_proof_of_concept(

else:
print("Now use MLA decode kernel!\n")
q_kv_dtype = torch.float16

q_nope = q_nope.to(q_kv_dtype)
q_pe = q_pe.to(q_kv_dtype)

if kv_len % page_size != 0:
raise ValueError("For simplicity, kv_len should be multiple of page_size.")
num_pages_per_seq = kv_len // page_size
@@ -291,8 +299,8 @@
(bsz,), page_size, dtype=torch.int32
).to(dev_id)

paged_ckv_cache = compressed_kv_normed_cache.to(q_kv_dtype).reshape(total_num_pages, page_size, self.kv_lora_rank)
paged_kpe_cache = k_pe_cache.to(q_kv_dtype).reshape(total_num_pages, page_size, self.qk_rope_head_dim)
paged_ckv_cache = compressed_kv_normed_cache.reshape(total_num_pages, page_size, self.kv_lora_rank)
paged_kpe_cache = k_pe_cache.reshape(total_num_pages, page_size, self.qk_rope_head_dim)
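# Because kv_len is an exact multiple of page_size, "paging" the caches is just a
# reshape of the contiguous tensors into (total_num_pages, page_size, head_dim) blocks.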

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(dev_id)
wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper(workspace_buffer)
@@ -308,12 +316,10 @@
data_type=q_kv_dtype,
q_data_type=q_kv_dtype)
attn_output = wrapper.run(q_nope, q_pe, paged_ckv_cache, paged_kpe_cache)
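# The MLA decode kernel attends in the compressed latent space: q_nope pairs with the
# paged ckv cache (kv_lora_rank) and q_pe with the paged kpe cache (qk_rope_head_dim),
# so attn_output stays per-head in kv_lora_rank until W_UV_O projects it back below.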

attn_output = attn_output.to(self.W_UV_O.dtype)

# output ~ [bsz, 5120]
output = torch.matmul(
attn_output.reshape(bsz, self.num_heads*self.kv_lora_rank),
attn_output.to(self.W_UV_O.dtype).reshape(bsz, self.num_heads*self.kv_lora_rank),
self.W_UV_O) # W_UV_O ~ [65536, 5120]

return output
@@ -323,7 +329,7 @@ def run_proof_of_concept(

dev_id = 1

torch.manual_seed(666)
# torch.manual_seed(666)
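# With the manual seed commented out, the random inputs differ from run to run.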
torch.set_grad_enabled(False)

mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id)
@@ -339,15 +345,44 @@ def run_proof_of_concept(
output_vanilla = mla_vanilla.run_decode(hidden_states, compressed_kv_normed_cache, k_pe_cache)

mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode(mla_vanilla).cuda(device=dev_id)
output_mat_absorbed_use_torch = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False)
output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=True)
output_mat_absorbed_use_torch_f32 = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
use_flashinfer_kernel=False, convert_float16=False)
output_mat_absorbed_use_torch_f16 = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
use_flashinfer_kernel=False, convert_float16=True)
output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
use_flashinfer_kernel=True, convert_float16=True)
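# Three decode paths are compared against the vanilla reference: matrix-absorption in
# float32 PyTorch, the same path in float16, and the float16 FlashInfer MLA kernel.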

cos_sim_use_torch = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch.reshape(-1), dim=0)
print(f"cos_sim_use_torch={cos_sim_use_torch}")
assert cos_sim_use_torch > 0.99
cos_use_torch_f32 = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1), dim=0)
print(f"cos_use_torch_f32 = {cos_use_torch_f32}")
assert cos_use_torch_f32 > 0.99

cos_sim_use_flashinfer = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1), dim=0)
print(f"cos_sim_use_flashinfer={cos_sim_use_flashinfer}\n")
assert cos_sim_use_flashinfer > 0.99


wmape_use_torch_f32 = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1))
print(f"wmape_use_torch_f32 = {wmape_use_torch_f32}")
assert wmape_use_torch_f32 < 0.02

mse_use_torch_f32 = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1))
print(f"mse_use_torch_f32={mse_use_torch_f32}\n")


cos_use_torch_f16 = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1), dim=0)
print(f"cos_use_torch_f16 = {cos_use_torch_f16}")
assert cos_use_torch_f16 > 0.99

wmape_use_torch_f16 = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1))
print(f"wmape_use_torch_f16 = {wmape_use_torch_f16}")
assert wmape_use_torch_f16 < 0.02

mse_use_torch_f16 = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1))
print(f"mse_use_torch_f16 = {mse_use_torch_f16}\n")


cos_use_flashinfer = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1), dim=0)
print(f"cos_use_flashinfer = {cos_use_flashinfer}")
assert cos_use_flashinfer > 0.99

wmape_use_flashinfer = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1))
print(f"wmape_use_flashinfer = {wmape_use_flashinfer}")
assert wmape_use_flashinfer < 0.02

mse_use_flashinfer = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1))
print(f"mse_use_flashinfer = {mse_use_flashinfer}")
