diff --git a/autogen/graph_utils.py b/autogen/graph_utils.py index 88c218fde5e..d36b47a12ed 100644 --- a/autogen/graph_utils.py +++ b/autogen/graph_utils.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List +from typing import Dict, List, Optional from autogen.agentchat import Agent @@ -110,7 +110,9 @@ def invert_disallowed_to_allowed(disallowed_speaker_transitions_dict: dict, agen return allowed_speaker_transitions_dict -def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: List[Agent]): +def visualize_speaker_transitions_dict( + speaker_transitions_dict: dict, agents: List[Agent], export_path: Optional[str] = None +): """ Visualize the speaker_transitions_dict using networkx. """ @@ -133,4 +135,8 @@ def visualize_speaker_transitions_dict(speaker_transitions_dict: dict, agents: L # Visualize nx.draw(G, with_labels=True, font_weight="bold") - plt.show() + + if export_path is not None: + plt.savefig(export_path) + else: + plt.show()