-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract_kuma_len.py
68 lines (56 loc) · 1.85 KB
/
extract_kuma_len.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import logging
import os
import argparse
# Command-line interface: a single option choosing which dataset / task to process.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="yelp", help="select dataset / task")
args = parser.parse_args()

# Results for each dataset are logged under their own directory.
dataset = str(args.dataset)
log_dir = f"saved_everything/{dataset}"
os.makedirs(log_dir, exist_ok=True)

# Route INFO-and-above records to <log_dir>/kuma_length.log with a timestamp prefix.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
    filename=f"{log_dir}/kuma_length.log",
)

# Header line marking the start of this dataset's section in the log.
logging.info(f"\n{dataset} ----")
def one_domain_len(domain, seeds=(5, 10, 15, 20, 25), dataset_name=None):
    """Report the average rationale-length ratio of the kuma model for one domain.

    For every seed, a pickled dictionary is loaded from ``kuma_model/...``.
    Each of its values contributes
    ``metadata['rationale'].sum() / metadata['full text length']`` capped at 1.
    The ratios are averaged per seed, and the mean / std across seeds is
    written to the log and printed to stdout.

    Args:
        domain: Domain tag. If its string form contains ``'ood'`` the ``-OOD-``
            output file is read; if it contains ``'full'`` the
            ``<dataset>_full`` directory is used; otherwise the plain
            per-dataset output file is read.
        seeds: Seeds whose output files are aggregated. Defaults to the five
            seeds this function has always used.
        dataset_name: Dataset directory name. ``None`` (the default) falls back
            to the module-level ``dataset`` variable.

    Returns:
        A 1-D float array with the mean ratio of each seed, in ``seeds`` order.
        Existing callers that ignore the return value are unaffected.

    Raises:
        Whatever ``np.load`` raises when a seed's output file cannot be read
        (e.g. a missing file).
    """
    name = dataset if dataset_name is None else dataset_name
    seeds = tuple(seeds)
    tag = str(domain)

    # Directory and file-name suffix depend only on the domain, so resolve them
    # once instead of on every seed iteration. 'ood' is checked before 'full',
    # matching the original precedence.
    if 'ood' in tag:
        folder, suffix = name, f'-OOD-{name}_{domain}'
    elif 'full' in tag:
        folder, suffix = f'{name}_full', ''
    else:
        folder, suffix = name, ''

    # One slot per seed; sized from `seeds` rather than a hard-coded 5.
    overall = np.zeros(len(seeds))
    for seed_pos, seed in enumerate(seeds):
        path_to_file = f'kuma_model/{folder}/kuma-bert-output_seed-kuma-bert{seed}{suffix}.npy'
        # NOTE(review): allow_pickle=True unpickles arbitrary objects -- only
        # point this at files produced by trusted runs.
        file_data = np.load(path_to_file, allow_pickle=True).item()

        aggregated_ratio = np.zeros(len(file_data))
        for doc_pos, metadata in enumerate(file_data.values()):
            # Cap at 1 so a rationale can never count as longer than its text.
            aggregated_ratio[doc_pos] = min(
                1.,
                metadata['rationale'].sum() / metadata['full text length'],
            )
        overall[seed_pos] = aggregated_ratio.mean()

    # The log entry is the printed summary preceded by a newline.
    summary = (
        f'{domain}\n'
        f'mean -> {overall.mean()}\n'
        f'std -> {overall.std()}\n'
        f'all -> {overall}\n'
    )
    logging.info('\n%s', summary)
    print(summary)
    return overall
# Run the report for every evaluated domain, in the same fixed order.
for _domain in ('full', 'InDomain', 'ood1', 'ood2'):
    one_domain_len(_domain)