-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
110 lines (83 loc) · 2.82 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
# Multi-layer graph convolutional network (GCN) encoder
class GCN(nn.Module):
    r"""
    Multi-layer graph convolutional encoder built from ``GraphConv`` layers.

    The first layer maps ``in_dim`` to ``out_dim * 2``, every intermediate
    layer keeps ``out_dim * 2`` features, and the last layer maps
    ``out_dim * 2`` to ``out_dim``. ``act_fn`` is applied after every layer,
    including the last one.

    Parameters
    -----------
    in_dim: int
        Input feature size.
    out_dim: int
        Output feature size of the last layer.
    act_fn: callable
        Activation applied to the output of each layer.
    num_layers: int
        Total number of ``GraphConv`` layers; must be at least 2.

    Raises
    ------
    ValueError
        If ``num_layers`` is smaller than 2.
    """

    def __init__(self, in_dim, out_dim, act_fn, num_layers=2):
        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, and a too-small ``num_layers`` would
        # then silently build a stack whose length disagrees with it.
        if num_layers < 2:
            raise ValueError(
                "num_layers must be at least 2, got {}".format(num_layers)
            )
        super().__init__()

        self.num_layers = num_layers
        wide_dim = out_dim * 2

        self.convs = nn.ModuleList()
        # input layer
        self.convs.append(GraphConv(in_dim, wide_dim))
        # intermediate layers (none when num_layers == 2)
        for _ in range(num_layers - 2):
            self.convs.append(GraphConv(wide_dim, wide_dim))
        # output layer
        self.convs.append(GraphConv(wide_dim, out_dim))

        self.act_fn = act_fn

    def forward(self, graph, feat):
        """Run ``feat`` through every layer on ``graph`` and return the result."""
        for conv in self.convs:
            feat = self.act_fn(conv(graph, feat))
        return feat
# Two-layer perceptron (MLP)
class MLP(nn.Module):
    r"""
    Two-layer perceptron whose output has the same size as its input.

    Parameters
    -----------
    in_dim: int
        Size of the input features (and of the returned features).
    out_dim: int
        Size of the hidden layer between the two linear maps.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # in_dim -> out_dim -> in_dim
        self.fc1 = nn.Linear(in_features=in_dim, out_features=out_dim)
        self.fc2 = nn.Linear(in_features=out_dim, out_features=in_dim)

    def forward(self, x):
        """Apply ``fc1``, an ELU non-linearity, then ``fc2``."""
        hidden = self.fc1(x)
        activated = F.elu(hidden)
        return self.fc2(activated)
class Grace(nn.Module):
    r"""
    GRACE model: a GCN encoder shared by two graph views, followed by an MLP
    projection head, trained with a contrastive loss between the views.

    Parameters
    -----------
    in_dim: int
        Input feature size.
    hid_dim: int
        Output size of the encoder, i.e. the size of the embeddings returned
        by ``get_embedding``. The projection head also returns this size.
    out_dim: int
        Second argument passed to ``MLP``; with the ``MLP`` defined in this
        file it is the width of the projection head's hidden layer.
    num_layers: int
        Number of the GNN encoder layers.
    act_fn: nn.Module
        Activation function used by the encoder.
    temp: float
        Temperature constant that divides the similarities in the loss.
    """

    def __init__(self, in_dim, hid_dim, out_dim, num_layers, act_fn, temp):
        super().__init__()
        self.encoder = GCN(in_dim, hid_dim, act_fn, num_layers)
        self.temp = temp
        self.proj = MLP(hid_dim, out_dim)

    def sim(self, z1, z2):
        """Return the matrix of cosine similarities between rows of ``z1`` and ``z2``."""
        # normalize embeddings across feature dimension
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return th.mm(z1, z2.t())

    def get_loss(self, z1, z2):
        """
        Return the per-row contrastive loss of view ``z1`` against ``z2``.

        For row ``i`` the positive pair is ``(z1[i], z2[i])``; the negatives
        are all other rows of ``z1`` (intra-view) and of ``z2`` (inter-view).
        The result is one loss value per row (not reduced).

        The value equals ``-log(pos / (sum_intra + sum_inter - self_sim))``
        over ``exp(sim / temp)`` terms, but it is evaluated in log-space:
        exponentiating directly overflows to ``inf`` for small temperatures
        and turns the loss into NaN.
        """
        refl_logits = self.sim(z1, z1) / self.temp  # intra-view pairs
        between_logits = self.sim(z1, z2) / self.temp  # inter-view pairs

        # between_logits.diag(): positive pairs
        pos_logits = between_logits.diag()

        # Drop each row's similarity with itself from the denominator by
        # sending its logit to -inf (it contributes exp(-inf) == 0).
        num_rows = refl_logits.size(0)
        self_mask = th.eye(num_rows, dtype=th.bool, device=refl_logits.device)
        refl_logits = refl_logits.masked_fill(self_mask, float("-inf"))

        all_logits = th.cat([refl_logits, between_logits], dim=1)
        log_denominator = th.logsumexp(all_logits, dim=1)

        # -log(exp(pos) / denominator) == log(denominator) - pos
        return log_denominator - pos_logits

    def get_embedding(self, graph, feat):
        """Return the encoder output for ``graph``/``feat``, detached from autograd."""
        # get embeddings from the model for evaluation
        h = self.encoder(graph, feat)
        return h.detach()

    def forward(self, graph1, graph2, feat1, feat2):
        """Return the scalar training loss for two views of a graph."""
        # encoding
        h1 = self.encoder(graph1, feat1)
        h2 = self.encoder(graph2, feat2)

        # projection
        z1 = self.proj(h1)
        z2 = self.proj(h2)

        # symmetric loss: each view acts as the anchor once
        l1 = self.get_loss(z1, z2)
        l2 = self.get_loss(z2, z1)

        ret = (l1 + l2) * 0.5
        return ret.mean()