Skip to content

Commit

Permalink
Merge pull request #1526 from garylin2099/sela-lyz
Browse files Browse the repository at this point in the history
add visit order
  • Loading branch information
garylin2099 authored Oct 21, 2024
2 parents d304fc3 + 3a8fdc6 commit 4bed19b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 4 additions & 3 deletions expo/evaluation/visualize_mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def visualize_tree(graph, show_instructions=False, save_path=""):
plt.show()


def build_tree_recursive(graph, parent_id, node, start_task_id=2):
def build_tree_recursive(graph, parent_id, node, node_order, start_task_id=2):
"""
Recursively builds the entire tree starting from the root node.
Adds nodes and edges to the NetworkX graph.
Expand All @@ -143,9 +143,10 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
# Add the current node with attributes to the graph
dev_score = node.raw_reward.get("dev_score", 0) * 100
avg_score = node.avg_value() * 100
order = node_order.index(node.id) if node.id in node_order else ""
graph.add_node(
parent_id,
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}\nOrder: {order}",
avg_value=node.avg_value(),
dev_score=dev_score,
visits=node.visited,
Expand All @@ -159,4 +160,4 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
for i, child in enumerate(node.children):
child_id = f"{parent_id}-{i}"
graph.add_edge(parent_id, child_id)
build_tree_recursive(graph, child_id, child)
build_tree_recursive(graph, child_id, child, node_order)
4 changes: 3 additions & 1 deletion expo/scripts/visualize_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
)

mcts.load_tree()
mcts.load_node_order()
root = mcts.root_node
node_order = mcts.node_order
G = nx.DiGraph()
build_tree_recursive(G, "0", root)
build_tree_recursive(G, "0", root, node_order)
visualize_tree(G, save_path=f"results/{args.task}-tree.png")

0 comments on commit 4bed19b

Please sign in to comment.