diff --git a/tests/test_mla_decode_kernel.py b/tests/test_mla_decode_kernel.py
index c83185fc..3f5ba57c 100644
--- a/tests/test_mla_decode_kernel.py
+++ b/tests/test_mla_decode_kernel.py
@@ -5,6 +5,12 @@ import torch.nn.functional as F
 import numpy as np
 import flashinfer
 
+
+def wmape(target: torch.Tensor, preds: torch.Tensor):
+    sum_abs_error = (preds - target).abs().sum().detach().item()
+    sum_scale = target.abs().sum().detach().item()
+    return sum_abs_error/sum_scale
+
 from rope_reference import *
 
 class DeepseekV2RMSNorm(nn.Module):
@@ -189,7 +195,6 @@ def __init__(self, mla_vanilla: DeepseekV2AttentionVanilla):
 
         self.softmax_scale = mla_vanilla.softmax_scale
         self.rope_theta = mla_vanilla.rope_theta
-        # self.rotary_emb = mla_vanilla.rotary_emb
 
         # W^DQ ~ [5120, 1536]
         self.W_DQ = mla_vanilla.q_a_proj.weight.transpose(0,1)
@@ -222,7 +227,8 @@ def run_proof_of_concept(
         hidden_states: torch.Tensor,
         compressed_kv_normed_cache: torch.Tensor,
         k_pe_cache: torch.Tensor,
-        use_flashinfer_kernel: bool
+        use_flashinfer_kernel: bool,
+        convert_float16: bool
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
         c_Q = torch.matmul(hidden_states, self.W_DQ)
@@ -239,6 +245,13 @@ def run_proof_of_concept(
         # q_nope ~ [bsz, 128, 512]
         q_nope = q_nope.reshape(bsz, self.num_heads, self.kv_lora_rank)
 
+        q_kv_dtype = torch.float16
+        if convert_float16:
+            q_nope = q_nope.to(q_kv_dtype)
+            q_pe = q_pe.to(q_kv_dtype)
+            compressed_kv_normed_cache = compressed_kv_normed_cache.to(q_kv_dtype)
+            k_pe_cache = k_pe_cache.to(q_kv_dtype)
+
         if not use_flashinfer_kernel:
             freqs_cis = precompute_freqs_cis(
                 self.qk_rope_head_dim, kv_len, self.rope_theta, use_scaled=False
@@ -275,11 +288,6 @@ def run_proof_of_concept(
         else:
             print("Now use MLA decode kernel!\n")
 
-            q_kv_dtype = torch.float16
-
-            q_nope = q_nope.to(q_kv_dtype)
-            q_pe = q_pe.to(q_kv_dtype)
-
             if kv_len % page_size != 0:
                 raise ValueError("For simplicity, kv_len should be multiple of page_size.")
             num_pages_per_seq = kv_len // page_size
@@ -291,8 +299,8 @@ def run_proof_of_concept(
                 (bsz,), page_size, dtype=torch.int32
             ).to(dev_id)
 
-            paged_ckv_cache = compressed_kv_normed_cache.to(q_kv_dtype).reshape(total_num_pages, page_size, self.kv_lora_rank)
-            paged_kpe_cache = k_pe_cache.to(q_kv_dtype).reshape(total_num_pages, page_size, self.qk_rope_head_dim)
+            paged_ckv_cache = compressed_kv_normed_cache.reshape(total_num_pages, page_size, self.kv_lora_rank)
+            paged_kpe_cache = k_pe_cache.reshape(total_num_pages, page_size, self.qk_rope_head_dim)
 
             workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(dev_id)
             wrapper = flashinfer.BatchDecodeMlaWithPagedKVCacheWrapper(workspace_buffer)
@@ -308,12 +316,10 @@ def run_proof_of_concept(
                 data_type=q_kv_dtype,
                 q_data_type=q_kv_dtype)
             attn_output = wrapper.run(q_nope, q_pe, paged_ckv_cache, paged_kpe_cache)
-
-        attn_output = attn_output.to(self.W_UV_O.dtype)
 
         # output ~ [bsz, 5120]
         output = torch.matmul(
-            attn_output.reshape(bsz, self.num_heads*self.kv_lora_rank),
+            attn_output.to(self.W_UV_O.dtype).reshape(bsz, self.num_heads*self.kv_lora_rank),
             self.W_UV_O) # W_UV_O ~ [65536, 5120]
 
         return output
@@ -323,7 +329,7 @@ def run_proof_of_concept(
 
     dev_id = 1
 
-    torch.manual_seed(666)
+    # torch.manual_seed(666)
    torch.set_grad_enabled(False)
 
    mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id)
@@ -339,15 +345,44 @@ def run_proof_of_concept(
     output_vanilla = mla_vanilla.run_decode(hidden_states, compressed_kv_normed_cache, k_pe_cache)
 
     mla_mat_absorb = DeepseekV2AttentionMatAbsorbDecode(mla_vanilla).cuda(device=dev_id)
-    output_mat_absorbed_use_torch = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=False)
-    output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache, use_flashinfer_kernel=True)
+    output_mat_absorbed_use_torch_f32 = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
+                                                                            use_flashinfer_kernel=False, convert_float16=False)
+    output_mat_absorbed_use_torch_f16 = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
+                                                                            use_flashinfer_kernel=False, convert_float16=True)
+    output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(hidden_states.squeeze(1), compressed_kv_normed_cache, k_pe_cache,
+                                                                             use_flashinfer_kernel=True, convert_float16=True)
 
-    cos_sim_use_torch = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch.reshape(-1), dim=0)
-    print(f"cos_sim_use_torch={cos_sim_use_torch}")
-    assert cos_sim_use_torch > 0.99
+    cos_use_torch_f32 = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1), dim=0)
+    print(f"cos_use_torch_f32 = {cos_use_torch_f32}")
+    assert cos_use_torch_f32 > 0.99
 
-    cos_sim_use_flashinfer = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1), dim=0)
-    print(f"cos_sim_use_flashinfer={cos_sim_use_flashinfer}\n")
-    assert cos_sim_use_flashinfer > 0.99
-
-
+    wmape_use_torch_f32 = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1))
+    print(f"wmape_use_torch_f32 = {wmape_use_torch_f32}")
+    assert wmape_use_torch_f32 < 0.02
+
+    mse_use_torch_f32 = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f32.reshape(-1))
+    print(f"mse_use_torch_f32={mse_use_torch_f32}\n")
+
+
+    cos_use_torch_f16 = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1), dim=0)
+    print(f"cos_use_torch_f16 = {cos_use_torch_f16}")
+    assert cos_use_torch_f16 > 0.99
+
+    wmape_use_torch_f16 = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1))
+    print(f"wmape_use_torch_f16 = {wmape_use_torch_f16}")
+    assert wmape_use_torch_f16 < 0.02
+
+    mse_use_torch_f16 = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_torch_f16.reshape(-1))
+    print(f"mse_use_torch_f16 = {mse_use_torch_f16}\n")
+
+
+    cos_use_flashinfer = F.cosine_similarity(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1), dim=0)
+    print(f"cos_use_flashinfer = {cos_use_flashinfer}")
+    assert cos_use_flashinfer > 0.99
+
+    wmape_use_flashinfer = wmape(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1))
+    print(f"wmape_use_flashinfer = {wmape_use_flashinfer}")
+    assert wmape_use_flashinfer < 0.02
+
+    mse_use_flashinfer = F.mse_loss(output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1))
+    print(f"mse_use_flashinfer = {mse_use_flashinfer}")
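For reference: the wmape helper added above computes the weighted mean absolute percentage error, sum(|preds - target|) / sum(|target|). Cosine similarity only checks direction, so a systematic scale error could still pass the 0.99 check; WMAPE bounds the total relative magnitude error. Running the torch path with convert_float16=True additionally separates plain fp16 quantization error from error introduced by the FlashInfer kernel itself. A minimal standalone sanity check of the helper (hypothetical values, not part of the patch):

import torch

def wmape(target: torch.Tensor, preds: torch.Tensor):
    # Total absolute error, normalized by the total absolute target mass.
    sum_abs_error = (preds - target).abs().sum().detach().item()
    sum_scale = target.abs().sum().detach().item()
    return sum_abs_error / sum_scale

# A uniform +1% relative perturbation yields WMAPE of roughly 0.01,
# comfortably under the 0.02 threshold the test asserts.
target = torch.randn(4096)
preds = target * 1.01
assert abs(wmape(target, preds) - 0.01) < 1e-3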