diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index f42be8ead29ee..77b4b326bf645 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -157,7 +157,6 @@ def create_group_query_attention_graph_prompt(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
 ):
     past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0
     present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length
@@ -176,7 +175,7 @@ def create_group_query_attention_graph_prompt(
                 "sin_cache" if rotary else "",
             ],
             ["output", "present_key", "present_value"],
-            name=node_name,
+            "GroupQueryAttention_0",
             num_heads=config.num_heads,
             kv_num_heads=config.kv_num_heads,
             local_window_size=local_window_size,
@@ -187,7 +186,6 @@ def create_group_query_attention_graph_prompt(
             # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
             # kv_share_buffer=1 if share_buffer else 0,
             domain="com.microsoft",
-        ),
     ]
 
     graph_input = [
@@ -689,7 +687,6 @@ def gqa_prompt_func(
     rotary_interleaved=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
 ):
     onnx_model_str = create_group_query_attention_graph_prompt(
         config,
@@ -701,7 +698,6 @@ def gqa_prompt_func(
         packed=new_k is None,
         softcap=softcap,
         use_smooth_softmax=use_smooth_softmax,
-        node_name=node_name,
     )
     q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
     past_k = k.clone() if share_buffer else None
@@ -718,7 +714,6 @@ def gqa_prompt_func(
         "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
     }
     sess_options = SessionOptions()
-    sess_options.enable_profiling = True
     ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
     io_binding = ort_session.io_binding()
     if new_k is not None:
@@ -752,11 +747,6 @@ def gqa_prompt_func(
         ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
         ort_output = numpy.array(ort_output)
         output = torch.tensor(ort_output)
-
-        if sess_options.enable_profiling:
-            profile_file = ort_session.end_profiling()
-            print(f"Profiling data saved to: {profile_file}")
-
         return output, present_k, present_v
     else:
         ort_inputs = {
@@ -1076,7 +1066,6 @@ def parity_check_gqa_prompt(
     packed=False,
     softcap=0.0,
     use_smooth_softmax=False,
-    node_name=None,
     rtol=RTOL,
     atol=ATOL,
 ):
@@ -1196,7 +1185,6 @@ def parity_check_gqa_prompt(
 
     # Flash function
     if packed:
-        node_name = "packed_"
        packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
         out, present_k, present_v = gqa_prompt_func(
             packed_qkv,
@@ -1214,7 +1202,6 @@ def parity_check_gqa_prompt(
             rotary_interleaved,
             softcap,
             use_smooth_softmax=use_smooth_softmax,
-            node_name=node_name,
         )
     else:
         out, present_k, present_v = gqa_prompt_func(
@@ -1940,14 +1927,6 @@ def test_gqa_no_past(self):
                                 for packed in [False, True]:
                                     for softcap in [0.0, 50.0]:
                                         for use_smooth_softmax in [False, True]:
-                                            node_name = (
-                                                ("packed_" if packed else "") +
-                                                ("rotary_" if rotary else "") +
-                                                ("rotary_interleaved_" if rotary_interleaved else "") +
-                                                "softcap_" + str(softcap) + "_" +
-                                                "smooth_softmax_" + str(use_smooth_softmax) + "_" +
-                                                "b_" + str(b) + "_sq_" + str(sq) + "_skv_" + str(skv) + "_n_" + str(n) + "_n2_" + str(n2) + "_h_" + str(h)
-                                            )
                                             config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
                                             past_kv_format = Formats.BNSH
                                             all_close = parity_check_gqa_prompt(
@@ -1959,7 +1938,6 @@ def test_gqa_no_past(self):
                                                 packed=packed,
                                                 softcap=softcap,
                                                 use_smooth_softmax=use_smooth_softmax,
-                                                node_name=node_name,
                                             )
                                             self.assertTrue(all_close)
                                             all_close = parity_check_gqa_prompt_no_buff(
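
Note on the profiling hooks removed above: onnxruntime's built-in profiler is enabled on SessionOptions before the InferenceSession is created, and end_profiling() flushes the trace and returns its path. A minimal standalone sketch of that pattern follows; the "model.onnx" path is a placeholder, not part of the test file.

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.enable_profiling = True  # write a Chrome-trace JSON profile for this session

# "model.onnx" is a placeholder model path for illustration only
session = ort.InferenceSession("model.onnx", sess_options, providers=["CPUExecutionProvider"])
# ... session.run(...) calls to be profiled go here ...
profile_file = session.end_profiling()  # stops profiling and returns the trace file path
print(f"Profiling data saved to: {profile_file}")

The restored positional "GroupQueryAttention_0" argument in the first hunk relies on onnx.helper.make_node's signature, make_node(op_type, inputs, outputs, name=None, ...), where the node name may be passed either positionally after the outputs or via the name= keyword.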