diff --git a/expo/evaluation/visualize_mcts.py b/expo/evaluation/visualize_mcts.py index 6a8869670..44f5ec5f5 100644 --- a/expo/evaluation/visualize_mcts.py +++ b/expo/evaluation/visualize_mcts.py @@ -128,7 +128,7 @@ def visualize_tree(graph, show_instructions=False, save_path=""): plt.show() -def build_tree_recursive(graph, parent_id, node, start_task_id=2): +def build_tree_recursive(graph, parent_id, node, node_order, start_task_id=2): """ Recursively builds the entire tree starting from the root node. Adds nodes and edges to the NetworkX graph. @@ -143,9 +143,10 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2): # Add the current node with attributes to the graph dev_score = node.raw_reward.get("dev_score", 0) * 100 avg_score = node.avg_value() * 100 + order = node_order.index(node.id) if node.id in node_order else "" graph.add_node( parent_id, - label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}", + label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}\nOrder: {order}", avg_value=node.avg_value(), dev_score=dev_score, visits=node.visited, @@ -159,4 +160,4 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2): for i, child in enumerate(node.children): child_id = f"{parent_id}-{i}" graph.add_edge(parent_id, child_id) - build_tree_recursive(graph, child_id, child) + build_tree_recursive(graph, child_id, child, node_order) diff --git a/expo/scripts/visualize_experiment.py b/expo/scripts/visualize_experiment.py index e2443d0fd..6cd84a0de 100644 --- a/expo/scripts/visualize_experiment.py +++ b/expo/scripts/visualize_experiment.py @@ -17,7 +17,9 @@ ) mcts.load_tree() + mcts.load_node_order() root = mcts.root_node + node_order = mcts.node_order G = nx.DiGraph() - build_tree_recursive(G, "0", root) + build_tree_recursive(G, "0", root, node_order) visualize_tree(G, save_path=f"results/{args.task}-tree.png")