external_models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn

class Gat(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Gat, self).__init__()
        # 8 heads of 8 channels each; with concat=True the head outputs are
        # concatenated, so gat1 produces 8 * 8 = 64 features per node.
        self.gat1 = gnn.GATConv(nfeat, 8, heads=8, concat=True)
        self.gat2 = gnn.GATConv(64, nclass)

    def forward(self, x, edge_index):
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = self.gat2(x, edge_index)
        return x

class Sage(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Sage, self).__init__()
        self.sage1 = gnn.SAGEConv(nfeat, 64)
        self.sage2 = gnn.SAGEConv(64, nclass)

    def forward(self, x, edge_index):
        x = self.sage1(x, edge_index)
        # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
        x = torch.sigmoid(x)
        x = self.sage2(x, edge_index)
        return x

class Gcn(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Gcn, self).__init__()
        hdim = 64
        self.gcn1 = gnn.GCNConv(nfeat, hdim)
        self.gcn2 = gnn.GCNConv(hdim, nclass)

    def forward(self, x, edge_index):
        x = self.gcn1(x, edge_index)
        x = self.gcn2(x, edge_index)
        # Softmax needs an explicit dim; the implicit form is deprecated.
        x = F.softmax(x, dim=-1)
        return x

class Sgc(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Sgc, self).__init__()
        # K=2 propagation hops; cached=True reuses the propagated features
        # after the first call (suited to transductive, full-batch training).
        self.sgc = gnn.SGConv(nfeat, nclass, K=2, cached=True)

    def forward(self, x, edge_index):
        x = self.sgc(x, edge_index)
        return x

class Appnp(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Appnp, self).__init__()
        self.linear = nn.Linear(nfeat, 64)
        self.linear2 = nn.Linear(64, nclass)
        # APPNP propagates the MLP predictions for K=5 steps with teleport
        # probability alpha=0.1; the propagation itself has no trainable weights.
        self.appnp = gnn.APPNP(K=5, alpha=0.1)

    def forward(self, x, edge_index):
        x = F.relu(self.linear(x))
        x = self.linear2(x)
        x = self.appnp(x, edge_index)
        return F.log_softmax(x, dim=-1)

class Agnn(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Agnn, self).__init__()
        self.linear = nn.Linear(nfeat, 16)
        # AGNNConv keeps the 16-dim width; each layer only learns an attention scalar.
        self.agnn1 = gnn.AGNNConv()
        self.agnn2 = gnn.AGNNConv()
        self.agnn3 = gnn.AGNNConv()
        self.agnn4 = gnn.AGNNConv()  # defined but left disabled below
        self.linear2 = nn.Linear(16, nclass)

    def forward(self, x, edge_index):
        x = F.relu(self.linear(x))
        x = self.agnn1(x, edge_index)
        x = self.agnn2(x, edge_index)
        x = self.agnn3(x, edge_index)
        # x = self.agnn4(x, edge_index)
        # Softmax needs an explicit dim; the implicit form is deprecated.
        x = F.softmax(self.linear2(x), dim=-1)
        return x

class Arma(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Arma, self).__init__()
        self.arma1 = gnn.ARMAConv(nfeat, 16, num_stacks=2)
        self.arma2 = gnn.ARMAConv(16, nclass, num_stacks=2)

    def forward(self, x, edge_index):
        x = self.arma1(x, edge_index)
        x = self.arma2(x, edge_index)
        return x

class Gated(nn.Module):
    def __init__(self, nfeat, nclass):
        super(Gated, self).__init__()
        self.linear1 = nn.Linear(nfeat, 64)
        # GatedGraphConv(out_channels=64, num_layers=3): three GRU-style
        # message-passing steps over 64-dim node states.
        self.gated = gnn.GatedGraphConv(64, 3)
        self.linear = nn.Linear(64, nclass)

    def forward(self, x, edge_index):
        x = F.relu(self.linear1(x))
        x = self.gated(x, edge_index)
        x = self.linear(x)
        # Softmax needs an explicit dim; the implicit form is deprecated.
        return F.softmax(x, dim=-1)
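
# --- Minimal usage sketch (an illustrative addition, not part of the original
# module). Assumes torch and torch_geometric are installed; the toy graph and
# the feature/class sizes below are made up for the example.
if __name__ == "__main__":
    # Toy graph: 4 nodes with 5 features each; undirected edges listed in both directions.
    x = torch.randn(4, 5)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]], dtype=torch.long)

    # Every model here maps nfeat input features to nclass per-node scores.
    for Model in (Gat, Sage, Gcn, Sgc, Appnp, Agnn, Arma, Gated):
        model = Model(nfeat=5, nclass=3)
        out = model(x, edge_index)
        print(Model.__name__, tuple(out.shape))  # expected: (4, 3)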