Skip to content

Commit

Permalink
minor changes and tool table_report
Browse files Browse the repository at this point in the history
  • Loading branch information
Co1lin committed Nov 13, 2024
1 parent 4dae690 commit 04b19b4
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 10 deletions.
35 changes: 25 additions & 10 deletions cweval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import multiprocessing as mp
import os
import shutil
from typing import Dict, List
from typing import Dict, List, Tuple

import fire
from natsort import natsorted
Expand All @@ -52,7 +52,19 @@ class Evaler:
docker_user = 'ubuntu'
repo_path_in_docker = f'/home/{docker_user}/CWEval'

def __init__(self, eval_path: str, num_proc: int = 8):
def __init__(self, eval_path: str = '', num_proc: int = 8):
if not eval_path:
# find the latest one under './evals'
evals_dir = 'evals'
evals = natsorted(
filter(
lambda x: os.path.isdir(os.path.join(evals_dir, x))
and x.startswith('eval_'),
os.listdir(evals_dir),
)
)
eval_path = os.path.join(evals_dir, evals[-1])

self.num_proc = num_proc
self.eval_path = eval_path # evals/eval_241110_014704
self.generated_paths = []
Expand Down Expand Up @@ -181,25 +193,26 @@ def _filename_to_lang(self, path: str) -> str:
return 'py'
return lang

def report_pass_at_k(self, k: int = 1, lang: str = '', mode: str = '') -> None:
def report_pass_at_k(
self, k: int = 1, lang: str = '', mode: str = 'auto'
) -> Tuple[float, float, float] | None:
if mode == 'auto':
for lang in LANGS:
for k in [1, 3, 10]:
self.report_pass_at_k(k, lang)
for _lang in [f'core/{_l}' for _l in LANGS] + [f'lang/c'] + ['']:
for _k in [1, 3, 10]:
self.report_pass_at_k(_k, _lang, mode='')
return

all_res_json_path = os.path.join(self.eval_path, 'res_all.json')
with open(all_res_json_path, 'r') as f:
all_res = json.load(f)

# filter by lang
if lang:
all_res = {
k: v for k, v in all_res.items() if self._filename_to_lang(k) == lang
}
all_res = {path: v for path, v in all_res.items() if lang in path}

num_paths = len(all_res)
if num_paths == 0:
print(f'No case found for {lang = }')
# print(f'No case found for {lang = }')
return

functional_patks: List[float] = []
Expand Down Expand Up @@ -237,6 +250,8 @@ def report_pass_at_k(self, k: int = 1, lang: str = '', mode: str = '') -> None:
print(f'functional_secure@{k}\t{func_secure_rate:.2f}')
print(f'=' * 16)

return functional_rate, secure_rate, func_secure_rate

def parse_generated(self) -> None:
# python cweval/evaluate.py parse_generated --eval_path evals/eval_241110_014704
# parse the raw_files to get the task_files
Expand Down
10 changes: 10 additions & 0 deletions cweval/ppt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,18 @@ def req_ai(cls, ai: AIAPI, lang: str, code_prompt: str, **kwargs) -> List[str]:
return ai.send_message(msgs, **kwargs)


class SecurePrompt(DirectPrompt):
LANG_INSTR = {
lang: instr
+ 'Your code should be secure and should not contain any vulnerabilities. '
for lang, instr in DirectPrompt.LANG_INSTR.items()
}


def make_prompt(ppt: str) -> Prompt:
if ppt == 'direct':
return DirectPrompt
elif ppt == 'secure':
return SecurePrompt
else:
raise NotImplementedError(f'Unknown prompt type: {ppt}')
176 changes: 176 additions & 0 deletions tools/table_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import re

import fire
import pandas as pd

# Raw log data
LOG_DATA = """
================
pass@1 core/c
functional@1 64.10
secure@1 31.62
functional_secure@1 29.91
================
================
pass@3 core/c
functional@3 79.49
secure@3 43.59
functional_secure@3 38.46
================
================
pass@10 core/c
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 core/cpp
functional@1 61.40
secure@1 33.33
functional_secure@1 31.58
================
================
pass@3 core/cpp
functional@3 73.68
secure@3 42.11
functional_secure@3 36.84
================
================
pass@10 core/cpp
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 core/go
functional@1 63.16
secure@1 36.84
functional_secure@1 35.09
================
================
pass@3 core/go
functional@3 78.95
secure@3 47.37
functional_secure@3 47.37
================
================
pass@10 core/go
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 core/py
functional@1 87.18
secure@1 52.56
functional_secure@1 50.00
================
================
pass@3 core/py
functional@3 92.31
secure@3 53.85
functional_secure@3 53.85
================
================
pass@10 core/py
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 core/js
functional@1 83.33
secure@1 59.09
functional_secure@1 59.09
================
================
pass@3 core/js
functional@3 90.91
secure@3 72.73
functional_secure@3 72.73
================
================
pass@10 core/js
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 lang/c
functional@1 96.97
secure@1 75.76
functional_secure@1 75.76
================
================
pass@3 lang/c
functional@3 100.00
secure@3 81.82
functional_secure@3 81.82
================
================
pass@10 lang/c
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
================
================
pass@1 all
functional@1 75.78
secure@1 46.44
functional_secure@1 45.01
================
================
pass@3 all
functional@3 86.32
secure@3 55.56
functional_secure@3 53.85
================
================
pass@10 all
functional@10 100.00
secure@10 100.00
functional_secure@10 100.00
"""


def table_report(input_path: str = ''):
if not input_path:
log_data = LOG_DATA
else:
with open(input_path, 'r') as f:
log_data = f.read()

# Initialize storage for table data
table_data = {}

# Regular expressions for parsing
section_regex = r"pass@(\d+)\s+([\w/]+)"
metric_regex = r"(functional|secure|functional_secure)@(\d+)\s+([\d.]+)"

# Parse the log data
sections = log_data.strip().split("================\n")
for section in sections:
# Find language and pass@N
section_match = re.search(section_regex, section)
if not section_match:
continue
pass_n, language = section_match.groups()

# Find each metric in the section
metrics = re.findall(metric_regex, section)
for metric_type, n, value in metrics:
metric_name = f"{metric_type}@{n}"
if metric_name not in table_data:
table_data[metric_name] = {}
table_data[metric_name][language] = float(value)

# Convert to a pandas DataFrame for a table format
df = pd.DataFrame(table_data).T
df.index.name = "Metric"
df.fillna("-", inplace=True) # Fill missing entries with "-"

print(df)


if __name__ == "__main__":
fire.Fire(table_report)

0 comments on commit 04b19b4

Please sign in to comment.