-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain.py
62 lines (52 loc) · 2.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import sys
sys.path.append('./src')
from dgld.utils.evaluation import split_auc
from dgld.utils.common import seed_everything
from dgld.utils.argparser import parse_all_args
from dgld.utils.load_data import load_data, load_custom_data, load_truth_data
from dgld.utils.inject_anomalies import inject_contextual_anomalies, inject_structural_anomalies
from dgld.utils.common_params import Q_MAP, K, P
from dgld.utils.log import Dgldlog
from dgld.models import *
import random
import os
truth_list = ['weibo', 'tfinance', 'tsocial', 'reddit', 'Amazon', 'Class', 'Disney', 'elliptic', 'Enron']
if __name__ == "__main__":
args_dict, args = parse_all_args()
data_name = args_dict['dataset']
save_path = args.save_path
exp_name = args.exp_name
log = Dgldlog(save_path, exp_name, args)
res_list_final = []
res_list_attrb = []
res_list_struct = []
seed_list = [random.randint(0, 99999) for i in range(args.runs)]
seed_list[0] = args_dict['seed']
for runs in range(args.runs):
log.update_runs()
seed = seed_list[runs]
seed_everything(seed)
args_dict['seed'] = seed
if data_name in truth_list:
graph = load_truth_data(data_path=args.data_path, dataset_name=data_name)
elif data_name == 'custom':
graph = load_custom_data(data_path=args.data_path)
else:
graph = load_data(data_name)
graph = inject_contextual_anomalies(graph=graph, k=K, p=P, q=Q_MAP[data_name], seed=seed)
graph = inject_structural_anomalies(graph=graph, p=P, q=Q_MAP[data_name], seed=seed)
label = graph.ndata['label']
if args.model in ['DOMINANT', 'AnomalyDAE', 'ComGA', 'DONE', 'AdONE', 'CONAD', 'ALARM', 'ONE', 'GAAN', 'GUIDE',
'CoLA', 'AAGNN', 'SLGAD', 'ANEMONE', 'GCNAE', 'MLPAE', 'SCAN']:
model = eval(f'{args.model}(**args_dict["model"])')
else:
raise ValueError(f"{args.model} is not implemented!")
model.fit(graph, **args_dict["fit"])
result = model.predict(graph, **args_dict["predict"])
final_score, a_score, s_score = split_auc(label, result)
res_list_final.append(final_score)
res_list_attrb.append(a_score)
res_list_struct.append(s_score)
print(args_dict)
log.save_result(res_list_final, res_list_attrb, res_list_struct, seed_list, args)
os._exit(0)