-
Notifications
You must be signed in to change notification settings - Fork 2
/
models.py
75 lines (65 loc) · 2.54 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphConv(nn.Module):
    """One graph-convolution layer: dropout -> neighborhood aggregation
    (adj @ x) -> linear projection -> optional activation.

    Args:
        in_dim: input feature dimension.
        out_dim: output feature dimension.
        drop: dropout probability applied to the input features.
        bias: whether the linear projection carries a bias term.
        activation: callable applied to the output, or None for identity.
    """

    def __init__(self, in_dim, out_dim, drop=0.5, bias=False, activation=None):
        super(GraphConv, self).__init__()
        self.w = nn.Linear(in_dim, out_dim, bias=bias)
        nn.init.xavier_uniform_(self.w.weight)
        # Remember the flag so we only touch w.bias when it exists.
        self.bias = bias
        if self.bias:
            nn.init.zeros_(self.w.bias)
        self.dropout = nn.Dropout(drop)
        self.activation = activation

    def forward(self, adj, x):
        """Propagate features over the graph.

        `adj` is expected to be an (n, n) adjacency/propagation matrix and
        `x` an (n, in_dim) feature matrix — presumably adj is normalized by
        the caller; cannot be confirmed from this file.
        """
        h = self.w(adj.mm(self.dropout(x)))
        if self.activation is None:
            return h
        return self.activation(h)
class AE(nn.Module):
    """Variational graph autoencoder.

    Encoder: two-layer GCN producing mu and logvar heads; decoder: two-layer
    GCN reconstructing node features from the sampled latent z.

    Args:
        feat_dim: input feature dimension.
        hid_dim: hidden dimension shared by encoder and decoder.
        out_dim: latent dimension.
        bias: whether the GraphConv linear layers carry bias terms.
    """

    def __init__(self, feat_dim, hid_dim, out_dim, bias=False):
        super(AE, self).__init__()
        self.conv1 = GraphConv(feat_dim, hid_dim, bias=bias, activation=F.relu)
        self.mu = GraphConv(hid_dim, out_dim, bias=bias, activation=torch.sigmoid)
        self.conv3 = GraphConv(out_dim, hid_dim, bias=bias, activation=F.relu)
        self.conv4 = GraphConv(hid_dim, feat_dim, bias=bias, activation=torch.sigmoid)
        self.logvar = GraphConv(hid_dim, out_dim, bias=bias, activation=torch.sigmoid)

    def encoder(self, g, x):
        """Return (mu, logvar) of q(z|x), each of shape (n_nodes, out_dim)."""
        x = self.conv1(g, x)
        h = self.mu(g, x)
        std = self.logvar(g, x)
        return h, std

    def decoder(self, g, x):
        """Reconstruct node features from latent codes z."""
        x = self.conv3(g, x)
        x = self.conv4(g, x)
        return x

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) during training; return mu at eval.

        `logvar` is interpreted as the log-variance, so std = exp(0.5*logvar).
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def kl_loss(self, g, mu, logvar):
        """KL(q(z|x) || N(0, I)) averaged over nodes.

        BUG FIX: the original used the log-*std* convention
        (1 + 2*logvar - mu^2 - exp(logvar)^2), which contradicted
        reparameterize's log-variance convention (std = exp(0.5*logvar)).
        This is the standard Gaussian KL for logvar = log-variance.
        """
        n_nodes = g.shape[0]
        # NOTE(review): mean over nodes combined with the extra 1/n_nodes
        # double-normalizes; kept as in the original to preserve loss scale.
        return -0.5 / n_nodes * torch.mean(torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), 1))

    def forward(self, g, x):
        """Return (reconstruction, latent z, KL loss)."""
        mu, logvar = self.encoder(g, x)
        z = self.reparameterize(mu, logvar)
        kl = self.kl_loss(g, mu, logvar)
        return self.decoder(g, z), z, kl
class LP(nn.Module):
    """Four-layer GCN head operating on latent codes.

    Maps an (n, out_dim) latent matrix through a relu/tanh pair into a
    hidden representation, then through a relu/sigmoid pair back to
    out_dim. Returns both the final output and the intermediate hidden
    representation.
    """

    def __init__(self, hid_dim, out_dim, bias=False):
        super(LP, self).__init__()
        self.res1 = GraphConv(out_dim, hid_dim, bias=bias, activation=F.relu)
        self.res2 = GraphConv(hid_dim, hid_dim, bias=bias, activation=torch.tanh)
        self.res3 = GraphConv(hid_dim, hid_dim, bias=bias, activation=F.relu)
        self.res4 = GraphConv(hid_dim, out_dim, bias=bias, activation=torch.sigmoid)

    def forward(self, g, z):
        """Return (final output, intermediate hidden features)."""
        hidden = self.res1(g, z)
        hidden = self.res2(g, hidden)
        out = self.res3(g, hidden)
        out = self.res4(g, out)
        return out, hidden