-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfigure.py
40 lines (36 loc) · 866 Bytes
/
configure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Best hyperparameters found."""
import torch
# Best hyperparameters found for the MWE-GCN model on ogbn-proteins.
MWE_GCN_proteins = dict(
    # Model shape.
    num_ew_channels=8,
    in_feats=1,
    hidden_feats=10,
    out_feats=112,
    n_layers=3,
    # Optimisation schedule.
    num_epochs=2000,
    lr=2e-2,
    weight_decay=0,
    patience=1000,
    dropout=0.2,
    # Aggregation across edge-weight channels: 'sum' or 'concat'.
    aggr_mode="sum",
    ewnorm="both",
)
# Best hyperparameters found for the MWE-DGCN model on ogbn-proteins.
MWE_DGCN_proteins = dict(
    # Model shape.
    num_ew_channels=8,
    in_feats=1,
    hidden_feats=10,
    out_feats=112,
    n_layers=2,
    # Optimisation schedule.
    num_epochs=2000,
    lr=1e-2,
    weight_decay=0,
    patience=300,
    dropout=0.5,
    # Aggregation across edge-weight channels.
    aggr_mode="sum",
    residual=True,
    ewnorm="none",
)
def get_exp_configure(args):
    """Return the best-known hyperparameters for the model named in ``args``.

    Parameters
    ----------
    args : dict
        Must contain the key ``"model"``. Supported values are ``"MWE-GCN"``
        and ``"MWE-DGCN"``.

    Returns
    -------
    dict
        A shallow copy of the matching module-level configuration. The caller
        may update the copy without altering the shared defaults.

    Raises
    ------
    KeyError
        If ``args`` has no ``"model"`` entry.
    ValueError
        If ``args["model"]`` is not a supported model name.
    """
    configs = {
        "MWE-GCN": MWE_GCN_proteins,
        "MWE-DGCN": MWE_DGCN_proteins,
    }
    model = args["model"]
    try:
        config = configs[model]
    except KeyError:
        # Fail loudly here. Silently returning None only defers the error
        # to an unrelated-looking crash later.
        raise ValueError(
            "Unknown model {!r}; expected one of {}".format(model, sorted(configs))
        ) from None
    # Copy so a caller mutating the result cannot corrupt the module defaults.
    return dict(config)