inference.py
"""
Open-Domain Question Answering 을 수행하는 inference 코드 입니다.
대부분의 로직은 train.py 와 비슷하나 retrieval, predict 부분이 추가되어 있습니다.
"""
import argparse
import datetime
import logging
import os
import sys
from typing import Callable, List

import pytz
from datasets import Dataset, DatasetDict, Features, Value, load_from_disk
from mrc import MRC
from omegaconf import OmegaConf, dictconfig
from retrieval import DenseRetrieval, HybridRetrieval, SparseRetrieval
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, TrainingArguments, set_seed

logger = logging.getLogger(__name__)

def main(args):
    config = OmegaConf.load(f"./config/{args.config}.yaml")

    now_time = datetime.datetime.now(pytz.timezone("Asia/Seoul")).strftime("%m-%d-%H-%M")
    if config.train.output_dir is None:
        trained_model = config.model.name_or_path
        if trained_model.startswith("./saved_models"):
            trained_model = trained_model.replace("./saved_models/", "")  # drop the "saved_models/" prefix when building the save path
        elif trained_model.startswith("saved_models"):
            trained_model = trained_model.replace("saved_models/", "")
        config.train.output_dir = os.path.join("predictions", trained_model, now_time)
    print(f"You can find the outputs in {config.train.output_dir}")

    training_args = TrainingArguments(**config.train)

    # logging configuration
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # verbosity setting: used for the Transformers logger info (on main process only)
    # logger.info("Training/evaluation parameters %s", training_args)

    # Fix the random seed before initializing the model.
    set_seed(config.utils.seed)

    datasets = load_from_disk(config.path.predict)
    print(datasets)

    tokenizer = AutoTokenizer.from_pretrained(
        config.model.name_or_path,
        from_tf=bool(".ckpt" in config.model.name_or_path),
        use_fast=True,
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        config.model.name_or_path,
        from_tf=bool(".ckpt" in config.model.name_or_path),
    )
    print(f"Loaded the pretrained model {config.model.name_or_path}")

    reader = MRC(
        config,
        training_args,
        tokenizer,
        model,
    )

    # Run passage retrieval according to the configured retriever type.
    if config.retriever.type == "sparse":
        datasets = run_sparse_retrieval(
            tokenize_fn=tokenizer.tokenize,
            datasets=datasets,
            config=config,
        )
    elif config.retriever.type == "dense":
        datasets = run_dense_retrieval(datasets, config)
    elif config.retriever.type == "hybrid":
        datasets = run_hybrid_retrieval(tokenize_fn=tokenizer.tokenize, datasets=datasets, config=config)

    #### eval dataset & eval example - predictions.json is generated
    reader.predict(predict_dataset=datasets["validation"])

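
# Illustrative config sketch for the fields this script reads (the key names follow
# the attribute accesses above; the concrete values are hypothetical and the actual
# YAML files under ./config/ may differ):
#
#   model:
#     name_or_path: ./saved_models/my-model   # hypothetical checkpoint path
#   path:
#     predict: ./data/test_dataset            # hypothetical dataset path
#   train:
#     output_dir: null                        # forwarded to TrainingArguments
#   utils:
#     seed: 42
#   retriever:
#     type: sparse                            # sparse | dense | hybrid
#   sparse:
#     embedding_type: tfidf
#   faiss:
#     use_faiss: false
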
def run_sparse_retrieval(
    tokenize_fn: Callable[[str], List[str]],
    datasets: DatasetDict,
    config: dictconfig.DictConfig,
) -> DatasetDict:
    # Retrieve the passages that match each query.
    retriever = SparseRetrieval(
        tokenize_fn=tokenize_fn,
        config=config,
    )

    if config.sparse.embedding_type == "tfidf":
        retriever.get_sparse_embedding()

    if config.faiss.use_faiss:
        retriever.build_faiss()
        df = retriever.retrieve_faiss(datasets["validation"])
    else:
        df = retriever.retrieve(datasets["validation"])

    # The test data has no answers, so the dataset consists only of id, question, and context.
    f = Features(
        {
            "context": Value(dtype="string", id=None),
            "id": Value(dtype="string", id=None),
            "question": Value(dtype="string", id=None),
        }
    )
    datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
    return datasets

def run_dense_retrieval(datasets, config):
    retriever = DenseRetrieval(config)
    retriever.get_dense_passage_embedding()
    df = retriever.retrieve(datasets["validation"])

    f = Features(
        {
            "context": Value(dtype="string", id=None),
            "id": Value(dtype="string", id=None),
            "question": Value(dtype="string", id=None),
        }
    )
    datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
    return datasets

def run_hybrid_retrieval(tokenize_fn, datasets, config):
    retriever = HybridRetrieval(tokenize_fn, config)
    df = retriever.retrieve(datasets["validation"])

    f = Features(
        {
            "context": Value(dtype="string", id=None),
            "id": Value(dtype="string", id=None),
            "question": Value(dtype="string", id=None),
        }
    )
    datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
    return datasets

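
# Minimal sketch of an optional shared helper (hypothetical; not referenced by the
# functions above): the three run_*_retrieval functions build the same single-split
# DatasetDict, so the schema construction could be factored out like this.
def build_validation_datasets(df) -> DatasetDict:
    features = Features(
        {
            "context": Value(dtype="string", id=None),
            "id": Value(dtype="string", id=None),
            "question": Value(dtype="string", id=None),
        }
    )
    return DatasetDict({"validation": Dataset.from_pandas(df, features=features)})
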
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, default="custom_config")
    args, _ = parser.parse_known_args()

    main(args)