undo test_gqa_cpu.py
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
liqunfu committed Feb 15, 2025
1 parent 3964acc commit 40a6854
Showing 1 changed file with 1 addition and 23 deletions.
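This commit reverts local debugging changes in test_gqa_cpu.py: it drops the node_name parameter that had been threaded through create_group_query_attention_graph_prompt, gqa_prompt_func, parity_check_gqa_prompt, and test_gqa_no_past; restores the hardcoded node name "GroupQueryAttention_0"; and removes the session profiling that gqa_prompt_func had turned on.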
24 changes: 1 addition & 23 deletions onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -157,7 +157,6 @@ def create_group_query_attention_graph_prompt(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
 ):
     past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0
     present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length
@@ -176,7 +175,7 @@ def create_group_query_attention_graph_prompt(
                 "sin_cache" if rotary else "",
             ],
             ["output", "present_key", "present_value"],
-            name=node_name,
+            "GroupQueryAttention_0",
             num_heads=config.num_heads,
             kv_num_heads=config.kv_num_heads,
             local_window_size=local_window_size,
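Note on the restored line: onnx.helper.make_node takes the node name as its fourth positional argument, after op_type, inputs, and outputs, so "GroupQueryAttention_0" above is the node's name rather than an input. A minimal sketch of that call shape (the input list and attribute values here are abbreviated, illustrative placeholders, not the real graph's wiring):

from onnx import helper

node = helper.make_node(
    "GroupQueryAttention",
    ["query", "key", "value"],                   # abbreviated; the real graph passes more inputs
    ["output", "present_key", "present_value"],
    "GroupQueryAttention_0",                     # node name, passed positionally
    num_heads=4,                                 # illustrative attribute values
    kv_num_heads=2,
    domain="com.microsoft",
)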
@@ -187,7 +186,6 @@ def create_group_query_attention_graph_prompt(
             # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
             # kv_share_buffer=1 if share_buffer else 0,
             domain="com.microsoft",
-
         ),
     ]
 
@@ -689,7 +687,6 @@ def gqa_prompt_func(
     rotary_interleaved=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
 ):
     onnx_model_str = create_group_query_attention_graph_prompt(
         config,
@@ -701,7 +698,6 @@ def gqa_prompt_func(
         packed=new_k is None,
         softcap=softcap,
         use_smooth_softmax=use_smooth_softmax,
-        node_name=node_name,
     )
     q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
     past_k = k.clone() if share_buffer else None
@@ -718,7 +714,6 @@ def gqa_prompt_func(
             "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
         }
         sess_options = SessionOptions()
-        sess_options.enable_profiling = True
         ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
         io_binding = ort_session.io_binding()
         if new_k is not None:
@@ -752,11 +747,6 @@ def gqa_prompt_func(
         ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
         ort_output = numpy.array(ort_output)
         output = torch.tensor(ort_output)
-
-        if sess_options.enable_profiling:
-            profile_file = ort_session.end_profiling()
-            print(f"Profiling data saved to: {profile_file}")
-
         return output, present_k, present_v
     else:
         ort_inputs = {
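The two hunks above undo session profiling. For reference, the removed code followed the standard ONNX Runtime profiling pattern; a minimal sketch of that pattern, with "model.onnx" as a placeholder path (the test builds its model in memory instead):

from onnxruntime import InferenceSession, SessionOptions

sess_options = SessionOptions()
sess_options.enable_profiling = True  # write a JSON trace of this session's runs

session = InferenceSession("model.onnx", sess_options, providers=["CPUExecutionProvider"])
# ... run inference via session.run(...) or IO binding ...

profile_file = session.end_profiling()  # stops profiling and returns the trace file path
print(f"Profiling data saved to: {profile_file}")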
@@ -1076,7 +1066,6 @@ def parity_check_gqa_prompt(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
     rtol=RTOL,
     atol=ATOL,
 ):
@@ -1196,7 +1185,6 @@ def parity_check_gqa_prompt(
 
     # Flash function
     if packed:
-        node_name = "packed_"
         packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
         out, present_k, present_v = gqa_prompt_func(
             packed_qkv,
@@ -1214,7 +1202,6 @@ def parity_check_gqa_prompt(
             rotary_interleaved,
             softcap,
             use_smooth_softmax=use_smooth_softmax,
-            node_name=node_name,
         )
     else:
         out, present_k, present_v = gqa_prompt_func(
@@ -1940,14 +1927,6 @@ def test_gqa_no_past(self):
                             for packed in [False, True]:
                                 for softcap in [0.0, 50.0]:
                                     for use_smooth_softmax in [False, True]:
-                                        node_name = (
-                                            ("packed_" if packed else "") +
-                                            ("rotary_" if rotary else "") +
-                                            ("rotary_interleaved_" if rotary_interleaved else "") +
-                                            "softcap_" + str(softcap) + "_" +
-                                            "smooth_softmax_" + str(use_smooth_softmax) + "_" +
-                                            "b_" + str(b) + "_sq_" + str(sq) + "_skv_" + str(skv) + "_n_" + str(n) + "_n2_" + str(n2) + "_h_" + str(h)
-                                        )
                                         config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
                                         past_kv_format = Formats.BNSH
                                         all_close = parity_check_gqa_prompt(
@@ -1959,7 +1938,6 @@ def test_gqa_no_past(self):
                                             packed=packed,
                                             softcap=softcap,
                                             use_smooth_softmax=use_smooth_softmax,
-                                            node_name=node_name,
                                         )
                                         self.assertTrue(all_close)
                                         all_close = parity_check_gqa_prompt_no_buff(
