models.py
import math
import torch
import torch.nn.functional as F
from torch import nn
from layers import GraphConvolution, Proto_GraphConvolution

class GCN(nn.Module):
    """Three-layer GCN that returns class log-probabilities, or the
    second-layer node embeddings when `eval=True`."""

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)
        self.gc3 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj, eval=False):
        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        if eval:
            # Return the hidden node embeddings without the classifier layer.
            return x
        else:
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc3(x, adj)
            return F.log_softmax(x, dim=1)

    def functional_forward(self, x, adj, weights, eval=False):
        # Same computation as `forward`, but with externally supplied layer
        # weights (used for inner-loop, meta-learning style updates).
        x = F.relu(self.gc1.functional_forward(x, adj, id=1, weights=weights))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2.functional_forward(x, adj, id=2, weights=weights)
        if eval:
            return x
        else:
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc3.functional_forward(x, adj, id=3, weights=weights)
            return F.log_softmax(x, dim=1)
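
# A minimal usage sketch for GCN (not part of the original training code): the
# node count, feature sizes, and identity adjacency below are illustrative
# assumptions, and GraphConvolution is assumed to accept dense (N, F) features
# and an (N, N) normalized adjacency, as the forward pass above suggests.
def _demo_gcn():
    model = GCN(nfeat=16, nhid=32, nclass=4, dropout=0.5)
    x = torch.randn(10, 16)            # 10 nodes, 16 input features
    adj = torch.eye(10)                # placeholder normalized adjacency
    log_probs = model(x, adj)          # (10, 4) class log-probabilities
    hidden = model(x, adj, eval=True)  # (10, 32) embeddings, classifier skipped
    return log_probs, hidden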

class GCN_Proto(nn.Module):
    """Prototype GCN layer that pools node features into a single,
    optionally task-conditioned, prototype vector."""

    def __init__(self, args, nfeat, dropout):
        super(GCN_Proto, self).__init__()
        self.gc_proto1 = Proto_GraphConvolution(args, nfeat, nfeat)
        # Gates over the prototype layer's weight and bias, generated from the
        # task representation.
        self.fc_gate_weight = nn.Linear(args.structure_dim, nfeat * nfeat)
        self.fc_gate_bias = nn.Linear(args.structure_dim, nfeat)
        self.nfeat = nfeat
        self.dropout = dropout
        self.args = args

    def reset_parameters(self):
        # Note: not invoked in __init__ and refers to `self.memory`, which is
        # not defined in this class.
        stdv = 1. / math.sqrt(self.memory.size(0))
        self.memory.data.uniform_(-stdv, stdv)

    def forward(self, x, adj, task_repr=None):
        if task_repr is not None:
            task_repr_cat = task_repr
            if self.args.module_type == 'sigmoid':
                # Build a (nfeat x nfeat) weight gate and a (1 x nfeat) bias
                # gate from the task representation.
                task_gate_weight = torch.squeeze(
                    torch.sigmoid(self.fc_gate_weight(task_repr_cat))).view(self.nfeat, self.nfeat)
                task_gate_bias = torch.squeeze(
                    torch.sigmoid(self.fc_gate_bias(task_repr_cat))).view(1, self.nfeat)
                x = torch.tanh(self.gc_proto1(x, adj, task_gate_weight, task_gate_bias))
        else:
            x = torch.tanh(self.gc_proto1(x, adj))
        # Average node features into one prototype vector.
        return torch.mean(x, dim=0)
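
# A minimal usage sketch for GCN_Proto (illustrative, not from the repo): the
# SimpleNamespace stands in for the repo's argparse args and only sets the
# fields read in this file; Proto_GraphConvolution may require additional
# fields, and its (x, adj, gate_weight, gate_bias) call signature is assumed
# from the forward pass above.
def _demo_gcn_proto():
    from types import SimpleNamespace
    args = SimpleNamespace(module_type='sigmoid', structure_dim=8)
    model = GCN_Proto(args, nfeat=16, dropout=0.5)
    x = torch.randn(10, 16)            # 10 nodes, 16 features
    adj = torch.eye(10)                # placeholder normalized adjacency
    task_repr = torch.randn(1, 8)      # assumed task embedding of size structure_dim
    proto = model(x, adj, task_repr)   # (16,) task-conditioned prototype
    return proto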

class GCN_Structure(nn.Module):
    """Structure heads over hidden node features: adjacency reconstruction,
    community pooling, and hop-wise aggregation."""

    def __init__(self, args, nfeat, nhid, dropout):
        super(GCN_Structure, self).__init__()
        self.gc_decode_structure1 = GraphConvolution(nhid, nhid)
        self.gc_community_prob = GraphConvolution(nhid, nhid)
        self.gc_community_value = GraphConvolution(nhid, nhid)
        self.gc_structure3 = GraphConvolution(nhid, nhid)
        self.nhid = nhid
        self.nfeat = nfeat
        self.dropout = dropout
        self.args = args
        if args.hop_concat_type == 'fc':
            self.concat_weight = nn.Linear(2 * nhid, nhid)
        elif args.hop_concat_type == 'attention':
            self.concat_weight = nn.Parameter(torch.FloatTensor(nhid, 1))
            self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.nhid)
        self.concat_weight.data.uniform_(-stdv, stdv)

    def forward(self, x, adj, adj_gd):
        # Reconstruct the adjacency from node embeddings and return the MSE
        # against the ground-truth adjacency `adj_gd`.
        gc_z = self.gc_decode_structure1(x, adj)
        decoder_adj = torch.sigmoid(torch.mm(gc_z, gc_z.transpose(0, 1)))
        return torch.mean((adj_gd - decoder_adj).pow(2))

    def forward_community(self, x, adj):
        # Soft community assignments (prob head) weight the value embeddings,
        # which are then pooled into a single community representation.
        gc_z = self.gc_community_value(x, adj)
        gc_s = F.softmax(self.gc_community_prob(x, adj), dim=1)
        x = F.normalize(torch.mm(gc_s.transpose(0, 1), gc_z), dim=0)
        return torch.mean(x, dim=0, keepdim=True)

    def forward_concat(self, x):
        # Combine hop-wise representations via a linear layer, a mean, or
        # attention pooling, depending on args.hop_concat_type.
        if self.args.hop_concat_type == 'fc':
            return self.concat_weight(x)
        elif self.args.hop_concat_type == 'mean':
            return torch.mean(x, dim=0, keepdim=True)
        elif self.args.hop_concat_type == 'attention':
            att_weight = F.softmax(torch.mm(x, self.concat_weight), dim=0)
            return torch.sum(att_weight * x, dim=0, keepdim=True)
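
# A minimal usage sketch for GCN_Structure (illustrative, not from the repo):
# the sizes, identity adjacencies, and the 'attention' hop_concat_type are
# assumptions; inputs are hidden node features of width nhid, as the layer
# shapes above imply.
def _demo_gcn_structure():
    from types import SimpleNamespace
    args = SimpleNamespace(hop_concat_type='attention')
    model = GCN_Structure(args, nfeat=16, nhid=32, dropout=0.5)
    h = torch.randn(10, 32)                        # hidden node features (width nhid)
    adj = torch.eye(10)                            # placeholder normalized adjacency
    adj_gd = torch.eye(10)                         # target adjacency to reconstruct
    recon_loss = model(h, adj, adj_gd)             # scalar edge-reconstruction MSE
    community = model.forward_community(h, adj)    # (1, 32) pooled community code
    hop_mix = model.forward_concat(torch.randn(3, 32))  # (1, 32) attention-pooled hops
    return recon_loss, community, hop_mix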