Merge pull request #1514 from geekan/sela
SELA
Showing 36 changed files with 3,383 additions and 21 deletions.
@@ -189,3 +189,4 @@ cov.xml
*.dot
.python-version
*.csv
metagpt/ext/sela/results/*
@@ -0,0 +1,106 @@

# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning

Official implementation of the paper [SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning](https://arxiv.org/abs/2410.17238).
SELA is an innovative system that enhances Automated Machine Learning (AutoML) by integrating Monte Carlo Tree Search (MCTS) with LLM-based agents. Traditional AutoML methods often generate low-diversity and suboptimal code, limiting their effectiveness in model selection and ensembling. SELA addresses these challenges by representing pipeline configurations as trees, enabling agents to intelligently explore the solution space and iteratively refine their strategies based on experimental feedback.
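To make the tree-search idea concrete, here is a minimal, self-contained sketch (not the repository's implementation): each node fixes one more pipeline decision, UCT trades off exploration and exploitation, and the reward for a rollout stands in for the dev-set score of the generated pipeline. The candidate insights and the random reward below are placeholders for illustration only.

```python
import math
import random

# Toy search space: at each depth, pick one candidate "insight" for that stage.
# In SELA these would be LLM-proposed insights per pipeline stage (placeholders here).
CANDIDATE_INSIGHTS = [
    ["impute missing values", "drop sparse columns"],                # data preprocessing
    ["target-encode categoricals", "one-hot encode categoricals"],   # feature engineering
    ["gradient boosting", "random forest", "logistic regression"],   # modeling
]


class Node:
    def __init__(self, state, parent=None):
        self.state = state        # insights chosen so far, one per depth
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0          # cumulative reward from rollouts through this node

    def uct(self, c=1.41):
        # Unvisited children are tried first; otherwise balance mean reward
        # (exploitation) against the exploration bonus.
        if self.visits == 0:
            return float("inf")
        return self.value / self.visits + c * math.sqrt(math.log(self.parent.visits) / self.visits)


def select(node):
    # Walk down the tree, always following the child with the best UCT score.
    while node.children:
        node = max(node.children, key=Node.uct)
    return node


def expand(node):
    depth = len(node.state)
    if depth == len(CANDIDATE_INSIGHTS):
        return node               # complete pipeline configuration: nothing to expand
    node.children = [Node(node.state + [c], parent=node) for c in CANDIDATE_INSIGHTS[depth]]
    return random.choice(node.children)


def simulate(state):
    # Placeholder reward. In SELA this is where the pipeline described by
    # `state` would be generated, executed, and scored on the dev set.
    return random.random()


def backpropagate(node, reward):
    while node is not None:
        node.visits += 1
        node.value += reward
        node = node.parent


root = Node(state=[])
for _ in range(10):               # analogous to the --rollouts parameter below
    leaf = expand(select(root))
    backpropagate(leaf, simulate(leaf.state))

best_child = max(root.children, key=lambda n: n.value / n.visits)
print("Most promising first-stage insight:", best_child.state[0])
```

In SELA the `simulate` step corresponds to actually generating, running, and scoring the experiment code, and the number of search iterations corresponds to the `--rollouts` parameter described later in this README.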
## 1. Data Preparation

You can either download the prepared datasets from the link below or prepare them from scratch.

- **Download Datasets:** [Dataset Link](https://drive.google.com/drive/folders/151FIZoLygkRfeJgSI9fNMiLsixh1mK0r?usp=sharing)
- **Download and prepare datasets from scratch:**
  ```bash
  cd data
  python dataset.py --save_analysis_pool
  python hf_data.py --save_analysis_pool
  ```
## 2. Configurations

### Data Config

- **`datasets.yaml`:** Provides the base prompt, evaluation metric, and target column for each dataset (see the sketch below for the kind of fields it holds).
- **`data.yaml`:** Set `datasets_dir` to the base directory containing all prepared datasets.
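For illustration, an entry in `datasets.yaml` might look roughly like the following; the field names here are assumptions, so check the file shipped with SELA for the authoritative schema:

```yaml
# Hypothetical entry -- field names are illustrative, not the real schema.
titanic:
  dataset: titanic                # folder name under datasets_dir
  metric: f1                      # evaluation metric reported for this task
  target_col: Survived            # column the models must predict
  user_requirement: "Predict whether each passenger survived; optimize the metric on the test split."
```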
### LLM Config

```yaml
llm:
  api_type: 'openai'
  model: deepseek-coder
  base_url: "https://your_base_url"
  api_key: sk-xxx
  temperature: 0.5
```
## 3. SELA

### Run SELA

#### Setup

```bash
pip install -e .
cd metagpt/ext/sela
pip install -r requirements.txt
```
#### Running Experiments

- **Examples:**
  ```bash
  python run_experiment.py --exp_mode mcts --task titanic --rollouts 10
  python run_experiment.py --exp_mode mcts --task house-prices --rollouts 10 --low_is_better
  ```
#### Parameters

- **`--rollouts`:** The number of MCTS rollouts to perform.
- **`--use_fixed_insights`:** Also include the fixed insights saved in `expo/insights/fixed_insights.json`.
- **`--low_is_better`:** Use this flag when a lower metric value is better (e.g., RMSE for regression tasks).
- **`--from_scratch`:** Generate a new insight pool for the dataset before running MCTS.
- **`--role_timeout`:** Limits the duration of a single simulation, in seconds (e.g., 10 rollouts with a timeout of 1,000 give a maximum total runtime of 10,000 seconds).
- **`--max_depth`:** The maximum depth of the MCTS tree (default: 4).
- **`--load_tree`:** Load an existing MCTS tree if the previous experiment was interrupted.
  - Example of the original run:
    ```bash
    python run_experiment.py --exp_mode mcts --task titanic --rollouts 10
    ```
  - To resume it:
    ```bash
    python run_experiment.py --exp_mode mcts --task titanic --rollouts 7 --load_tree
    ```
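These flags compose. For instance, a run that rebuilds the insight pool, caps each simulation at 1,000 seconds, and searches one level deeper than the default might look like this (the specific values are illustrative, not recommendations):

```bash
python run_experiment.py --exp_mode mcts --task titanic --rollouts 10 \
    --from_scratch --role_timeout 1000 --max_depth 5
```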
### Ablation Study

**RandomSearch**

- **Use a single insight:**
  ```bash
  python run_experiment.py --exp_mode rs --task titanic --rs_mode single
  ```
- **Use a set of insights:**
  ```bash
  python run_experiment.py --exp_mode rs --task titanic --rs_mode set
  ```
## 4. Citation

Please cite our paper if you use SELA or find it cool or useful!

```bibtex
@misc{chi2024selatreesearchenhancedllm,
      title={SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning},
      author={Yizhou Chi and Yizhang Lin and Sirui Hong and Duyi Pan and Yaying Fei and Guanghao Mei and Bangbang Liu and Tianqi Pang and Jacky Kwok and Ceyao Zhang and Bang Liu and Chenglin Wu},
      year={2024},
      eprint={2410.17238},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2410.17238},
}
```
@@ -0,0 +1,3 @@
datasets_dir: "path/to/datasets"  # path to the datasets directory
work_dir: ../../../workspace  # path to the workspace directory
role_dir: storage/SELA  # path to the role directory
@@ -0,0 +1,74 @@
import os

from metagpt.ext.sela.data.dataset import SPECIAL_INSTRUCTIONS
from metagpt.ext.sela.runner.mle_bench.instructions import (
    ADDITIONAL_NOTES,
    INSTRUCTIONS,
    INSTRUCTIONS_OBFUSCATED,
)

MLE_BENCH_FILES = ["description.md", "description_obfuscated.md"]


MLE_REQUIREMENTS = """
{instructions}
{additional_notes}
COMPETITION INSTRUCTIONS
------
{task_description}
## More Instructions
- You should split the training data into train and dev set with a seed of 42.
- You should use the dev set to improve your model. Print the final dev set score after training.
- output_dir: {output_dir}
- Besides `submission.csv`, you should also save your `test_predictions.csv` and `dev_predictions.csv` in the output directory.
- Note that `test_predictions.csv` should be identical to `submission.csv`.
- Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. {special_instruction}
**Do not make any plots or visualizations.**
"""


def get_mle_task_id(dataset_dir):
    """Derive the MLE-bench competition id from the prepared dataset path."""
    return dataset_dir.split("/")[-3]


def get_mle_is_lower_better(task):
    """Return True if a lower leaderboard score is better for this competition."""
    from mlebench.data import get_leaderboard
    from mlebench.registry import registry

    competition = registry.get_competition(task)
    competition_leaderboard = get_leaderboard(competition)
    return competition.grader.is_lower_better(competition_leaderboard)


def get_mle_bench_requirements(dataset_dir, data_config, special_instruction, obfuscated=False):
    """Assemble the full requirement prompt for a single MLE-bench competition."""
    work_dir = data_config["work_dir"]
    task = get_mle_task_id(dataset_dir)
    output_dir = f"{work_dir}/{task}"
    final_output_dir = f"{work_dir}/submission"
    os.makedirs(output_dir, exist_ok=True)
    if special_instruction:
        special_instruction = SPECIAL_INSTRUCTIONS[special_instruction]
    else:
        special_instruction = ""
    if obfuscated:
        instructions = INSTRUCTIONS_OBFUSCATED.format(dataset_dir=dataset_dir, output_dir=final_output_dir)
        task_file = "description_obfuscated.md"
    else:
        instructions = INSTRUCTIONS.format(dataset_dir=dataset_dir, output_dir=output_dir)
        task_file = "description.md"

    with open(os.path.join(dataset_dir, task_file), encoding="utf-8") as f:
        task_description = f.read()
    mle_requirement = MLE_REQUIREMENTS.format(
        instructions=instructions,
        additional_notes=ADDITIONAL_NOTES,
        task_description=task_description,
        output_dir=output_dir,
        special_instruction=special_instruction,
    )
    print(mle_requirement)
    return mle_requirement
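For orientation, a small hypothetical example of how these helpers fit together; the directory layout shown is an assumption, not something defined by this file:

```python
# Hypothetical usage -- the path layout is illustrative only.
dataset_dir = "workspace/mle-bench/spaceship-titanic/prepared/public"
print(get_mle_task_id(dataset_dir))  # -> "spaceship-titanic" (third-from-last path segment)

# get_mle_bench_requirements(dataset_dir, {"work_dir": "../../../workspace"},
#                            special_instruction=None) would then read
# description.md from dataset_dir and fill MLE_REQUIREMENTS with it.
```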