diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py index 77bcc732..4d320acd 100644 --- a/tests/trace_link/test_trace_linker.py +++ b/tests/trace_link/test_trace_linker.py @@ -469,6 +469,7 @@ def test_process_dependent_gpu_ops(trace_linker, orig_op_id, cpu_op, kineto_gpu_ gpu_op.inclusive_dur = gpu_op_data["inclusive_dur"] gpu_op.exclusive_dur = gpu_op_data["exclusive_dur"] gpu_op.stream = gpu_op_data["stream"] + gpu_op.pg_name = gpu_op_data.get("pg_name", None) kineto_gpu_op_objects.append(gpu_op) host_op_id_to_kineto_ops_map = {orig_op_id: kineto_gpu_op_objects} @@ -497,6 +498,8 @@ def test_process_dependent_gpu_ops(trace_linker, orig_op_id, cpu_op, kineto_gpu_ assert updated_gpu_op["exclusive_dur"] == kineto_gpu_op_objects[i].exclusive_dur assert updated_gpu_op["ts"] == kineto_gpu_op_objects[i].timestamp assert updated_gpu_op["stream"] == kineto_gpu_op_objects[i].stream + if kineto_gpu_op_objects[i].is_inter_gpu_comms_op() and kineto_gpu_op_objects[i].pg_name is not None: + assert updated_gpu_op["pg_name"] == kineto_gpu_op_objects[i].pg_name @patch("builtins.open", new_callable=MagicMock)