-
Notifications
You must be signed in to change notification settings - Fork 1
/
SIM-GNN-SUBNET.py
81 lines (61 loc) · 2.46 KB
/
SIM-GNN-SUBNET.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
"""This script is an adaptation of the original file "example_4_b.py" created by Roman Martin."""
from GNNSubNet import GNNSubNet as gnn
import ensemble_gnn as egnn
import copy
import random
import time
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import matthews_corrcoef
# Seed handed to the K-fold split so the folds are reproducible
RANDOM_SEED: int = 1717100

# --- Alternative input set (disabled): TCGA kidney, random PPI ---------------
# loc = "/home/bastian/GitHub/GNN-SubNet/TCGA"
# ppi = f'{loc}/KIDNEY_RANDOM_PPI.txt'
#   single-omic features:
# feats = [f'{loc}/KIDNEY_RANDOM_mRNA_FEATURES.txt']
#   multi-omic features:
# feats = [f'{loc}/KIDNEY_RANDOM_mRNA_FEATURES.txt', f'{loc}/KIDNEY_RANDOM_Methy_FEATURES.txt']
# targ = f'{loc}/KIDNEY_RANDOM_TARGET.txt'

# --- Active input set: TCGA-BRCA ---------------------------------------------
# Directory that holds all input files
loc: str = "/home/bastian/TCGA-BRCA"

# PPI network file
ppi: str = f"{loc}/HRPD_brca_subtypes.csv"

# Feature files (a single omics layer is configured here)
feats: list = [f"{loc}/GE_brca_subtypes.csv"]

# Outcome class file
targ: str = f"{loc}/binary_target_brca_subtypes.csv"
# Number of folds for the K-fold cross validation
splits: int = 10

# NOTE(review): these two lists are never read or filled anywhere in the
# visible script; kept so the set of module-level names is unchanged.
avg_local_performance: list = []
avg_ensemble_performance: list = []

# Disabled extra seeding of the `random` module. `rounds` is not defined in
# this script, so the second line cannot be re-enabled as-is.
# random.seed(RANDOM_SEED)
# random_seeds: list = random.sample(range(100, 999), 2*rounds)

# perf_counter() is monotonic, so the elapsed time reported at the end cannot
# be skewed by system clock adjustments during a long training run.
start = time.perf_counter()

# Build the GNNSubNet object from the configured PPI, feature and target files
g = gnn.GNNSubNet(loc, ppi, feats, targ, normalize=True)

# Get some general information about the data dimension
# g.summary()

# Balanced accuracy reached in each fold, in fold order
accuracy_single: list = []

# Iterated below as (train, test) pairs; the split is seeded with RANDOM_SEED
model_pairs: list = egnn.split_n_fold_cv(g, n_splits=splits, random_seed=RANDOM_SEED)

for fold, (g_train, g_test) in enumerate(model_pairs, start=1):
    print("## Training fold %d" % fold)
    g_train.train(method='chebconv', epoch_nr=20)

    # Only the first element of predict()'s result is used as the predicted
    # classes -- presumably the remaining elements are other outputs; confirm
    # against the GNNSubNet API.
    predicted_local_classes = g_train.predict(g_test)[0]

    # Score once and reuse the value for both the log line and the summary
    fold_score = balanced_accuracy_score(g_test.true_class, predicted_local_classes)
    print("### Balanced accuracy: fold %d score: %.3f" % (fold, fold_score))
    print("## Finished training fold %d" % fold)
    print("")

    # Keep the fold's score for the average reported after the loop
    accuracy_single.append(fold_score)

# An empty split list would otherwise raise ZeroDivisionError; report NaN instead
avg_local: float = (
    sum(accuracy_single) / len(accuracy_single) if accuracy_single else float("nan")
)
print("# All balanced accuracy values from local tests: %s" % str(accuracy_single))
print("# Average performance with local model: %.3f" % (avg_local))

end = time.perf_counter()
print("\n\tTime to go through:", end-start)