Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/swebench_di' into swebench_di
Browse files Browse the repository at this point in the history
merge upstream/swebench_di
  • Loading branch information
HuZixia committed Mar 19, 2024
2 parents e40fc66 + c52dcc7 commit 8887987
Show file tree
Hide file tree
Showing 9 changed files with 344 additions and 1 deletion.
53 changes: 53 additions & 0 deletions data/inference/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pandas as pd

from metagpt.const import METAGPT_ROOT

SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"

# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
# This collection represents a subset specifically related to scikit-learn content.
SCIKIT_LEARN_IDS = [
"scikit-learn__scikit-learn-11578",
"scikit-learn__scikit-learn-10297",
"scikit-learn__scikit-learn-25747",
"scikit-learn__scikit-learn-15512",
"scikit-learn__scikit-learn-15119",
"scikit-learn__scikit-learn-10870",
"scikit-learn__scikit-learn-15100",
"scikit-learn__scikit-learn-14496",
"scikit-learn__scikit-learn-14890",
"scikit-learn__scikit-learn-10428",
"scikit-learn__scikit-learn-25744",
"scikit-learn__scikit-learn-11542",
"scikit-learn__scikit-learn-10198",
"scikit-learn__scikit-learn-10459",
]


def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
try:
df = pd.read_excel(path)
# Filter for instances containing the tag in either column
pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
fail_filter = df["instance_id_fail"].str.contains(tag, na=False)

# Combine the filters using | (OR operator) for efficiency
combined_filter = pass_filter | fail_filter

# Apply combined filter and select the specific columns
filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]

# Flatten the DataFrame into a list and remove NaN values
subset_instance = filtered_df.stack().dropna().tolist()

return subset_instance
except FileNotFoundError:
print(f"File not found: {path}")
return []
except Exception as e:
print(f"An error occurred: {e}")
return []
28 changes: 28 additions & 0 deletions data/inference/make_datasets/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import re


def extract_diff(response):
"""
Extracts the diff from a response formatted in different ways
"""
if response is None:
return None
diff_matches = []
other_matches = []
pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
for code, match in pattern.findall(response):
if code in {"diff", "patch"}:
diff_matches.append(match)
else:
other_matches.append(match)
pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
for code, match in pattern.findall(response):
if code in {"diff", "patch"}:
diff_matches.append(match)
else:
other_matches.append(match)
if diff_matches:
return diff_matches[0]
if other_matches:
return other_matches[0]
return response.split("</s>")[0]
16 changes: 16 additions & 0 deletions data/inference/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import runpy
import sys

original_argv = sys.argv.copy()

try:
# 设置你想要传递给脚本的命令行参数
sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"]
# 执行脚本
runpy.run_path(path_name="run_api.py", run_name="__main__")
finally:
# 恢复原始的sys.argv以避免对后续代码的潜在影响
sys.argv = original_argv
157 changes: 157 additions & 0 deletions data/inference/run_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import json
import os
import traceback
from pathlib import Path

import fire
import numpy as np
from datasets import load_dataset, load_from_disk
from make_datasets.utils import extract_diff
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm

from data.inference.const import SCIKIT_LEARN_IDS
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils import count_string_tokens
from metagpt.utils.recovery_util import save_history

# Replace with your own
MAX_TOKEN = 128000


@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def call_chat(inputs, interpreter):
"""
Calls the openai API to generate completions for the given inputs.
Args:
inputs (str): The inputs to generate completions for.
interpreter (DataInterpreter): The data interpreter to use for execution.
"""
requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
try:
await interpreter.run([requirement, system_messages, user_message])
return interpreter.get_last_cell_source()
except Exception as e:
logger.error(f"Error: {e}\nInputs: {inputs}")
traceback.print_exc()
raise e


async def openai_inference(
test_dataset,
model_name_or_path,
output_file,
existing_ids,
use_reflection,
):
"""
Runs inference on a dataset using the openai API.
Args:
test_dataset (datasets.Dataset): The dataset to run inference on.
model_name_or_path (str): The name or path of the model to use.
output_file (str): The path to the output file.
existing_ids (set): A set of ids that have already been processed.
"""
test_dataset = test_dataset.filter(
lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
desc="Filtering",
load_from_cache_file=False,
)
basic_args = {
"model_name_or_path": model_name_or_path,
}
print(f"Filtered to {len(test_dataset)} instances")
with open(output_file, "a+") as f:
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]

if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}
output_dict.update(basic_args)
output_dict["text"] = f"{datum['text']}\n\n"
response = await call_chat(
output_dict["text"],
di,
)
logger.info(f"Final response: {response}")
save_history(di)
output_dict["full_output"] = response
output_dict["model_patch"] = extract_diff(response)
print(json.dumps(output_dict), file=f, flush=True)


async def main(
dataset_name_or_path,
split="test",
model_name_or_path=config.llm.model,
output_dir="outputs",
use_reflection=True,
):
"""
Performs inference on SWE-bench dataset using the Data Interpreter.
Args:
dataset_name_or_path: HuggingFace dataset name or local path
split: Dataset split to use (default: test)
model_name_or_path: Name of the model to use (default: config.llm.model)
param output_dir: Path to the output directory (default: outputs)
"""
model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path
output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
output_file = Path(output_dir, output_file + ".jsonl")
output_file.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"Will write to {output_file}")
existing_ids = set()
if os.path.exists(output_file):
with open(output_file, "r") as f:
for line in f:
data = json.loads(line)
instance_id = data["instance_id"]
existing_ids.add(instance_id)
logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
if Path(dataset_name_or_path).exists():
dataset = load_from_disk(dataset_name_or_path)
else:
dataset = load_dataset(dataset_name_or_path)
if split not in dataset:
raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
dataset = dataset[split]
lens = np.array(list(map(len, dataset["text"])))
dataset = dataset.select(np.argsort(lens))

if len(existing_ids) > 0:
dataset = dataset.filter(
lambda x: x["instance_id"] not in existing_ids,
desc="Filtering out existing ids",
load_from_cache_file=False,
)
if len(SCIKIT_LEARN_IDS) > 0:
dataset = dataset.filter(
lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
desc="Filtering out subset_instance_ids",
load_from_cache_file=False,
)
inference_args = {
"test_dataset": dataset,
"model_name_or_path": model_name_or_path,
"output_file": output_file,
"existing_ids": existing_ids,
"use_reflection": use_reflection,
}
if model_name_or_path.startswith("gpt"):
await openai_inference(**inference_args)
else:
raise ValueError(f"Invalid model name or path {model_name_or_path}")
logger.info("Done!")


if __name__ == "__main__":
fire.Fire(main)
3 changes: 3 additions & 0 deletions metagpt/roles/di/data_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,6 @@ async def _check_data(self):
print(result)
data_info = DATA_INFO.format(info=result)
self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))

def get_last_cell_source(self):
return self.execute_code.nb.cells[-1].source
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
jieba==0.42.1 # for tool recommendation
datasets==2.18.0
85 changes: 85 additions & 0 deletions sub_swebench_dataset/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Dataset Description

The index of sub_swebench is a subset of swebench, with two columns in total, each column containing 50 id_instance.

The id_instance is a balanced subset of pass and fail samples for CognitionAI on swebench.

The index of scikit-learn-68 is another subset of CognitionAI in swebench (all tasks of the scikit-learn type), with a total of two columns:

- pass:12
- fail:56

Sampling list:https://github.com/CognitionAI/devin-swebench-results/tree/main/
Original dataset:https://huggingface.co/datasets/princeton-nlp/SWE-bench/

## fail dataset Description:

There are a total of 491 txt files listed.
In the original dataset, the distribution of pass case categories is:

- astropy: 24
- django: 160
- matplotlib: 42
- mwaskom: 4
- pallets: 3
- psf: 9
- pydata: 29
- pylint-dev: 13
- pytest-dev: 20
- scikit-learn: 56
- sphinx-doc: 46
- sympy: 85

### After balanced sampling:

There are a total of 50 txt files listed.

- Django: 16
- Scikit-Learn: 6
- Sympy: 10
- sphinx-doc:5
- matplotlib: 4
- pydata: 3
- astropy: 2
- pytest-dev: 2
- psf: 1
- pylint-dev: 1



## pass dataset Description:



There are a total of 79 txt files listed.
In the original dataset, the distribution of pass case categories is:

- astropy: 4
- django: 38
- matplotlib: 3
- pydata: 3
- pytest-dev: 6
- scikit-learn: 12
- sphinx-doc: 2
- sympy: 11

### After balanced sampling:

There are a total of 50 txt files listed.

- Django: 23
- Scikit-Learn: 8
- Sympy: 7
- Pytest: 4
- Astropy: 3
- Xarray (pydata): 2
- Matplotlib: 2
- Sphinx: 1



## scikit-learn-68 dataset Description:

instance_id_pass:12

instance_id_fail:56
Binary file added sub_swebench_dataset/scikit-learn-68.csv
Binary file not shown.
Binary file added sub_swebench_dataset/sub_swebench.csv
Binary file not shown.

0 comments on commit 8887987

Please sign in to comment.