Based on *Accessing GPT-4 level Mathematical Olympiad Solutions via Monte Carlo Tree Self-refine with LLaMa-3 8B* by Zhang et al.
At a high level, MCTSr iteratively generates solutions to a specified (math) problem.
In an MCTSr tree, nodes correspond to attempted answers, and edges correspond to attempts to improve an answer.
The search starts by generating an initial solution to the problem, which becomes the root node. The paper uses a "dummy" solution (e.g. "I don't know").
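To make the tree structure concrete, here is a minimal Python sketch of a node and of a root initialized with the paper's dummy answer. The `Node` class and its fields (`q`, `visits`) are hypothetical names chosen for illustration, not the library's actual classes.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass(eq=False)  # identity equality: distinct attempts stay distinct nodes
class Node:
    """One attempted answer in an MCTSr-style tree (illustrative sketch)."""
    answer: str
    parent: Optional[Node] = None
    children: list[Node] = field(default_factory=list)
    q: float = 0.0   # quality estimate derived from self-evaluation rewards
    visits: int = 0  # how many times this node has been updated

    def add_child(self, answer: str) -> Node:
        # Each edge is an attempt to improve this node's answer.
        child = Node(answer=answer, parent=self)
        self.children.append(child)
        return child


# Initialization with a "dummy" root answer, as described in the paper.
root = Node(answer="I don't know")
refined = root.add_child("x = 3")
```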
To select a node to expand, we gather a set of candidate nodes that haven't been fully expanded.
A node is fully expanded if either:
- it has `max_children` children, or
- any of its children has a Q value greater than its own.
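The candidate-gathering rule above can be sketched as a tree walk that keeps every node that is not yet fully expanded. The `Node` fields and the function names are hypothetical, chosen for this illustration.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass(eq=False)  # identity equality, so nodes with equal fields stay distinct
class Node:
    answer: str
    q: float = 0.0
    parent: Optional[Node] = None
    children: list[Node] = field(default_factory=list)


def is_fully_expanded(node: Node, max_children: int) -> bool:
    # Fully expanded if the node already has max_children children,
    # or if any child already has a higher Q than the node itself.
    return len(node.children) >= max_children or any(
        child.q > node.q for child in node.children
    )


def gather_candidates(root: Node, max_children: int) -> list[Node]:
    # Depth-first walk over the whole tree, keeping expandable nodes.
    candidates: list[Node] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if not is_fully_expanded(node, max_children):
            candidates.append(node)
        stack.extend(node.children)
    return candidates


# Example: the root (Q=10) has one weaker child, so both remain candidates.
root = Node("I don't know", q=10.0)
root.children.append(Node("attempt 1", q=5.0, parent=root))
candidates = gather_candidates(root, max_children=2)
```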
Once we've gathered the candidates, we compute UCT scores for each candidate node. There are a few ways we can make our selection:
- Greedily (choose the node with the highest UCT)
- Importance sampling (sample from the set of candidates, weighted by their UCT score)
- Pairwise importance sampling (sample a pair of candidates, weighted by the difference between the pair's UCT scores, and select the one with the higher UCT)
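A sketch of UCT scoring and the three selection policies follows. The UCT form in the comment, Q(a) + c * sqrt((ln N(parent(a)) + 1) / (N(a) + eps)), is our reading of the paper; the `Candidate` record, the constants `c=1.4` and `eps`, the shift that keeps importance weights positive, and the pair weights (absolute UCT difference plus a small constant) are all assumptions of this sketch rather than the library's code.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class Candidate:
    name: str
    q: float            # quality estimate Q(a)
    visits: int         # N(a)
    parent_visits: int  # N(parent(a)); assumed to be at least 1


def uct(node: Candidate, c: float = 1.4, eps: float = 1e-8) -> float:
    # UCT(a) = Q(a) + c * sqrt((ln N(parent(a)) + 1) / (N(a) + eps))
    log_n = math.log(max(node.parent_visits, 1))
    return node.q + c * math.sqrt((log_n + 1) / (node.visits + eps))


def select_greedy(cands: list[Candidate]) -> Candidate:
    # Always pick the highest UCT score.
    return max(cands, key=uct)


def select_importance(cands: list[Candidate], rng: random.Random) -> Candidate:
    scores = [uct(n) for n in cands]
    low = min(scores)
    # UCT can be negative, so shift all scores to get positive weights.
    weights = [s - low + 1e-6 for s in scores]
    return rng.choices(cands, weights=weights, k=1)[0]


def select_pairwise(cands: list[Candidate], rng: random.Random) -> Candidate:
    scores = [uct(n) for n in cands]
    pairs = [(i, j) for i in range(len(cands)) for j in range(i + 1, len(cands))]
    if not pairs:  # a single candidate: nothing to compare
        return cands[0]
    # Pairs with a larger UCT gap are more likely to be drawn.
    weights = [abs(scores[i] - scores[j]) + 1e-6 for i, j in pairs]
    i, j = rng.choices(pairs, weights=weights, k=1)[0]
    # Keep the better member of the sampled pair.
    return cands[i] if scores[i] >= scores[j] else cands[j]
```

Greedy selection always exploits the current best score; the two sampling variants give lower-scoring candidates some chance of being expanded.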
The paper states that the authors use greedy selection. In their repo, they also perform pairwise sampling and save (question, answer1, answer2) tuples for use in DPO (Direct Preference Optimization).
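One way to collect such tuples is to append each sampled pair to a JSON Lines file. The file format, the field names, and both helper functions are hypothetical; this sketch does not decide which answer counts as preferred, which is left to the caller.

```python
import json
from pathlib import Path


def save_preference_pair(path, question: str, answer1: str, answer2: str) -> None:
    # Append one (question, answer1, answer2) record as a single JSON line.
    record = {"question": question, "answer1": answer1, "answer2": answer2}
    with Path(path).open("a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


def load_preference_pairs(path) -> list[tuple[str, str, str]]:
    # Read the records back as (question, answer1, answer2) tuples.
    with Path(path).open(encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    return [(r["question"], r["answer1"], r["answer2"]) for r in records]
```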
Expansion involves several steps:
- Generate a critique of the current solution.
- Refine the solution based on the critique.
- Add a new child, corresponding to the refined solution.
- Self-evaluate the reward of the new child.
- Backpropagate the reward from the new child through its ancestors up to the root.
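The expansion steps can be sketched in Python with the LLM calls passed in as plain functions. `critique`, `refine`, and `evaluate` are hypothetical stand-ins for model prompts. The Q-value rules (a node's Q is the average of its minimum and mean reward; during backpropagation a parent's Q moves halfway toward its best child's Q) follow our reading of the paper in simplified form; the library's implementation may differ.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from statistics import mean
from typing import Callable, Optional


@dataclass(eq=False)
class Node:
    answer: str
    parent: Optional[Node] = None
    children: list[Node] = field(default_factory=list)
    rewards: list[float] = field(default_factory=list)
    q: float = 0.0
    visits: int = 0

    def record_reward(self, reward: float) -> None:
        # Q(a) = 1/2 * (min R_a + mean R_a)
        self.rewards.append(reward)
        self.q = 0.5 * (min(self.rewards) + mean(self.rewards))


def backpropagate(node: Node) -> None:
    # From the new node's parent up to the root:
    # Q'(a) = 1/2 * (Q(a) + max Q over a's children)
    current = node.parent
    while current is not None:
        best_child = max(child.q for child in current.children)
        current.q = 0.5 * (current.q + best_child)
        current.visits += 1
        current = current.parent


def expand(
    node: Node,
    problem: str,
    critique: Callable[[str, str], str],
    refine: Callable[[str, str, str], str],
    evaluate: Callable[[str, str], float],
) -> Node:
    feedback = critique(problem, node.answer)           # 1. critique
    new_answer = refine(problem, node.answer, feedback)  # 2. refine
    child = Node(answer=new_answer, parent=node)         # 3. add child
    node.children.append(child)
    child.record_reward(evaluate(problem, new_answer))   # 4. self-evaluate
    child.visits += 1
    backpropagate(child)                                 # 5. backpropagate
    return child
```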
Arena Hard scores for LLaMA 8B with and without MCTS:

| Model | Arena Hard RU | Arena Hard EN |
|---|---|---|
| LLaMA 8B | 35.07 | 20.6 |
| LLaMA 8B MCTS | 45.71 | 32.1 |
- Clone the repository:

```bash
git clone https://github.com/VikhrModels/mctslib.git
cd mctslib
```
- Install the dependencies from `requirements.txt`:

```bash
pip install -r requirements.txt
```
```python
import os

# Set your Together API key before importing the library,
# in case it reads the key at import time
os.environ["TOGETHER_API_KEY"] = "<your key>"

from src.mcts_llm.mctsr import MCTSrLlama38B

# Define the problem/question
q = """Your Task"""

# Initialize the MCTSr model with the given parameters
mctsr = MCTSrLlama38B(
    problem=q,
    max_rollouts=4,         # number of search rollouts
    max_children=8,         # maximum number of children per node
    selection_policy=2,     # integer code for the selection policy
    initialize_strategy=2,  # integer code for how the root answer is created
)

# Run the search
best_answer = mctsr.run()

# Print the best answer found
print(best_answer)
```
```python
max_rollouts=8
max_children=2
```