# run_openai.py
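"""Run few-shot evaluation through the OpenAI batch interface.

A rough summary of the code below:
  1. submit every request file found in <result_dir>/inputs/permutations as a
     batch job via the repo's OpenAIBatch wrapper,
  2. track submitted jobs in <result_dir>/status.json so interrupted runs can
     resume without resubmitting,
  3. poll for finished batches, score each test task (accuracy, macro-F1,
     OOD rate), and append the scores to <result_dir>/results.csv.
"""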
import argparse
import logging
import os
import pandas as pd
import json
import sys
import time
from openai import OpenAI
from types import SimpleNamespace
from dataloader.single_task import SingleTaskDataloader
from dataloader.utils import read_json
from model.openai_batch import OpenAIBatch
from utils import evaluate
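

# Note: the result_list consumed by evaluate_result is assumed to hold parsed batch
# response records returned by OpenAIBatch.retrieve_batch_request, each exposing
# .custom_id and .response.body (inferred from how they are accessed below).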
# evaluate results of all tasks in one permutation
def evaluate_result(model, result_list, df, results_file, args, logger):
    split_file = os.path.join("configs", "{}.json".format(args.split))
    all_test_task_list = read_json(split_file)["test"]
    # collect all results for each task in this split
    test_task_list = set()
    task_results = {task_name: [] for task_name in all_test_task_list}
    for result in result_list:
        # {n_task}#{n_shot}#{perm_id}#{few-shot split id}#{test-task-name}#{case-id}#{global_prefix_n_tokens}
        unique_id = result.custom_id.split("#")
        n_task, n_shot, perm_id, few_shot_sample_id, task_name, case_id, global_prefix_n_tokens = unique_id
        output = result.response.body
        test_task_list.add(task_name)
        task_results[task_name].append({"case_id": case_id, "result": output})
    for task_name in test_task_list:
        cases = task_results[task_name]
        sorted_cases = sorted(cases, key=lambda x: int(x["case_id"]))
        results = [item["result"] for item in sorted_cases]
        # collect all preds for each task
        test_dataloader = SingleTaskDataloader(args, logger)
        dataset_config, _, test_set = test_dataloader.load_data(task_name)
        if args.inference_method == "rank":
            raw_predictions = model.get_predictions_rank(results, dataset_config["options"])
        elif args.inference_method == "greedy":
            raw_predictions = model.get_predictions_greedy(results, results)
        true_labels = [item["label"] for item in test_set]
        predictions = [item["prediction"] for item in raw_predictions]
        acc, macro_f1, ood_rate = evaluate(args, true_labels, predictions, dataset_config["options"])
        logger.info("n_task: {}, n_shot: {}, perm_id: {}, fewshot_sample_id: {}, task: {}, acc: {:.2f}, macro_f1: {:.4f}, ood_rate: {:.2f}".format(
            n_task, n_shot, perm_id, few_shot_sample_id, task_name, acc, macro_f1, ood_rate))
        df.loc[len(df.index)] = [n_task, n_shot, perm_id, few_shot_sample_id, global_prefix_n_tokens, task_name, acc, macro_f1, ood_rate]
    df.to_csv(results_file)


def load_job_status(logger, args):
    file_name = os.path.join(args.result_dir, "status.json")
    if not os.path.exists(file_name) or args.restart:
        return []
    batch_list = []
    with open(file_name, 'r') as file:
        for line in file.readlines():
            batch_list.append(json.loads(line))
    return batch_list


def save_job_status(batch_job, logger, args):
    file_name = os.path.join(args.result_dir, "status.json")
    with open(file_name, 'w') as file:
        for obj in batch_job:
            file.write(json.dumps(obj) + '\n')
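

# status.json (read by load_job_status, written by save_job_status) is a JSONL file
# with one record per submitted batch; an illustrative line:
# {"file_name": "<input file name>", "batch_job_id": "<id from send_batch_request>",
#  "status": "pending", "request_counts": {"total": 0, "completed": 0, "failed": 0}}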


def run_and_evaluate(logger, args):
    model = OpenAIBatch(args, logger)
    model.set_up_model()
    # load existing results when resuming, otherwise start a fresh dataframe
    results_file = os.path.join(args.result_dir, "results.csv")
    if os.path.exists(results_file) and not args.restart:
        df = pd.read_csv(results_file, index_col=0)
    else:
        # df = pd.DataFrame(columns=["permutation_id", "task_name", "n_shot_per_class", "accuracy", "macro_f1", "ood_rate"])
        df = pd.DataFrame(columns=["n_task", "n_shot", "permutation_id", "fewshot_sample_id",
                                   "global_prefix_n_tokens", "task_name", "accuracy", "macro_f1", "ood_rate"])
    # send requests
    batch_job_list = load_job_status(logger, args)
    # delete failed jobs from the queue
    batch_job_list = [batch_job for batch_job in batch_job_list if batch_job["status"] != "failed"]
    save_job_status(batch_job_list, logger, args)
    for file_name in os.listdir(args.input_dir):
        if file_name not in [x["file_name"] for x in batch_job_list]:
            logger.info(f"Sending batch request: {file_name}")
            full_name = os.path.join(args.input_dir, file_name)
            batch_job_id = model.send_batch_request(full_name)
            batch_job_list.append({"file_name": file_name,
                                   "batch_job_id": batch_job_id,
                                   "status": "pending",
                                   "request_counts": {"total": 0, "completed": 0, "failed": 0}})
            save_job_status(batch_job_list, logger, args)
    # retrieve requests in a loop
    while any(batch_job["status"] != "saved" for batch_job in batch_job_list):
        for i, batch_job in enumerate(batch_job_list):
            batch_job_list[i], content = model.retrieve_batch_request(batch_job, args.output_dir)
            if content is not None:
                evaluate_result(model, content, df, results_file, args, logger)
                batch_job_list[i]["status"] = "saved"
        # log the status of every batch
        for batch_job in batch_job_list:
            file_name, status = batch_job["file_name"], batch_job["status"]
            total = batch_job["request_counts"]["total"]
            completed = batch_job["request_counts"]["completed"]
            failed = batch_job["request_counts"]["failed"]
            logger.info(f"batch: {file_name}, status: {status}, {completed} / {total} ({failed} failed)")
        save_job_status(batch_job_list, logger, args)
        # batch jobs may take hours if there is a long queue,
        # so we might not want to keep this program running
        if not args.keep_waiting:
            break
        if any(batch_job["status"] != "saved" for batch_job in batch_job_list):
            time.sleep(30)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="data")
    parser.add_argument("--result_dir", default="output/pilot/openai")
    parser.add_argument("--split", default="pilot", type=str,
                        help="Name of the split config (configs/<split>.json) listing the tasks to run")
    parser.add_argument("--model", default="gpt-4o", choices=["gpt-3.5", "gpt-4o"])
    parser.add_argument("--inference_method", default="rank", choices=["rank", "greedy"])
    parser.add_argument("--restart", action='store_true')
    parser.add_argument("--keep_waiting", action="store_true")
    args = parser.parse_args()
    args.input_dir = os.path.join(args.result_dir, "inputs", "permutations")
    args.output_dir = os.path.join(args.result_dir, "outputs")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO,
                        handlers=[logging.FileHandler(os.path.join(args.result_dir, "log.txt")),
                                  logging.StreamHandler()])
    logger = logging.getLogger(__name__)
    logger.info(args)
    run_and_evaluate(logger, args)


if __name__ == "__main__":
    main()
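

# Example invocation (illustrative; assumes the batch input files have already been
# generated under <result_dir>/inputs/permutations):
#   python run_openai.py --result_dir output/pilot/openai --split pilot \
#       --model gpt-4o --inference_method rank --keep_waiting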