run_inference.py
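"""Hydra entry point that evaluates a pre-trained encoder (RoBERTa-style) or
decoder (GPT-2-style) language model on perplexity, API call completion, or
API usage completion over a Hugging Face dataset."""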
import logging
import os

import hydra
import omegaconf
from datasets import load_dataset
from transformers import (
    GPT2LMHeadModel,
    RobertaForCausalLM,
    RobertaTokenizer,
    GPT2Tokenizer,
    AutoConfig,
    set_seed
)

from tasks.api_completion import evaluate_api_call_completion, evaluate_api_usage_completion
from tasks.perplexity import evaluate_perplexity

logger = logging.getLogger(__name__)

# (config class, model class, tokenizer class) per architecture type
MODEL_CLS = {
    'encoder': (AutoConfig, RobertaForCausalLM, RobertaTokenizer),
    'decoder': (AutoConfig, GPT2LMHeadModel, GPT2Tokenizer)
}


@hydra.main(config_path='configuration', config_name='defaults', version_base='1.1')
def main(cfg: omegaconf.DictConfig):
    if cfg.run.seed is not None:
        set_seed(cfg.run.seed)

    # hydra changes the current working dir, so we have to keep in memory the base path of the project
    cfg.run.base_path = hydra.utils.get_original_cwd()

    model_path = os.path.join(cfg.run.base_path, cfg.model.model_name_or_path)
    config_cls, model_cls, tokenizer_cls = MODEL_CLS[cfg.model.model_type]
    try:
        logger.info(f"Attempting to load pre-trained model from local checkpoint ({cfg.model.model_name_or_path}).")
        config = config_cls.from_pretrained(model_path)
        model = model_cls.from_pretrained(model_path, config=config)
        tokenizer = tokenizer_cls.from_pretrained(model_path, use_fast=True)
    except Exception:
        # no local checkpoint found: fall back to the Hugging Face hub
        logger.info(f"Loading pre-trained model from hub ({cfg.model.model_name_or_path}).")
        model = model_cls.from_pretrained(cfg.model.model_name_or_path)
        tokenizer = tokenizer_cls.from_pretrained(cfg.model.model_name_or_path)
    model.to(cfg.device)

    if cfg.model.model_type == 'decoder':
        # GPT-2 has no padding token, so reuse the end-of-sequence token
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    logger.info(f"Loading test dataset: ({cfg.run.dataset_name}).")
    if cfg.run.hf_user is not None:
        # load dataset from HF hub
        dataset_url = os.path.join(cfg.run.hf_user, cfg.run.dataset_name)
        dataset = load_dataset(dataset_url, split='train', use_auth_token=True)
    else:
        # load dataset locally
        dataset_path = os.path.join(cfg.run.base_path, cfg.run.dataset_name)
        dataset = load_dataset(dataset_path, split='train')
    dataset = dataset.remove_columns(['repo_name', 'method_path', 'method_name', 'docstring'])

    if cfg.run.domain != 'all':
        logger.info(f"Filtering dataset to keep samples of domain: `{cfg.run.domain}`")
        dataset = dataset.filter(lambda e: e['domain'] == cfg.run.domain, num_proc=cfg.run.preprocessing_num_workers)

    if cfg.run.task == 'perplexity':
        logger.info("***** Evaluating loss and perplexity on input dataset *****")
        logger.info(f"  Num test samples: {len(dataset)}")
        loss, perplexity = evaluate_perplexity(cfg, model, tokenizer, dataset)
        logger.info(f"Loss: {round(loss, 4)} | perplexity: {round(perplexity, 4)}")
    elif cfg.run.task == 'call':
        logger.info("***** Evaluating API call completion on input dataset *****")
        cfg.run.batch_size = 1
        n_test, pass_1, pass_5, pass_10 = evaluate_api_call_completion(cfg, model, tokenizer, dataset)
        logger.info(f"Number of test calls: {n_test}")
        logger.info(f"Pass@1: {round(pass_1 / n_test, 4)}")
        logger.info(f"Pass@5: {round(pass_5 / n_test, 4)}")
        logger.info(f"Pass@10: {round(pass_10 / n_test, 4)}")
    elif cfg.run.task == 'usage':
        logger.info("***** Evaluating API usage completion on input dataset *****")
        cfg.run.batch_size = 1
        evaluate_api_usage_completion(cfg, model, tokenizer, dataset)
    else:
        raise ValueError("Please select an evaluation task "
                         "(perplexity | call | usage).")


if __name__ == '__main__':
    main()
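A minimal invocation sketch, assuming the keys read by the script (run.task, run.dataset_name, run.hf_user, model.model_type, model.model_name_or_path, device) are defined in the Hydra config under configuration/ (config_name 'defaults'); the paths and names below are placeholders, not values shipped with the project:

    # evaluate perplexity with a decoder checkpoint stored inside the repository
    python run_inference.py run.task=perplexity model.model_type=decoder \
        model.model_name_or_path=<local-checkpoint-dir> run.dataset_name=<dataset-dir> device=cuda

    # evaluate API call completion on a dataset hosted on the Hugging Face hub
    python run_inference.py run.task=call model.model_type=encoder \
        model.model_name_or_path=<hub-model-id> run.hf_user=<hf-username> run.dataset_name=<dataset-name>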