-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
28 lines (25 loc) · 1.13 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
from torch.nn import Parameter
import torch
class SAGPool(torch.nn.Module):
    """Score-based top-k node pooling layer.

    Every node is scored by a graph convolution with a single output channel
    (``Conv(in_channels, 1)``). The highest-scoring nodes of each graph in the
    batch are selected with ``topk``. Their features are gated by
    ``non_linearity(score)``, and the edge list is filtered down to edges
    between kept nodes.

    Args:
        in_channels: Number of input node features, passed to ``Conv``.
        ratio: Pooling ratio forwarded unchanged to ``topk``. In PyTorch
            Geometric a float is the fraction of nodes kept per graph and an
            int is an absolute count (confirm against the installed version).
        Conv: Graph-convolution class used to build the scoring layer. It is
            constructed as ``Conv(in_channels, 1)`` and called as
            ``conv(x, edge_index)``.
        non_linearity: Callable applied to the kept scores before they scale
            the node features.
    """

    def __init__(self, in_channels, ratio=0.8, Conv=GCNConv,
                 non_linearity=torch.tanh):
        super().__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        # One output channel: a single scalar score per node.
        self.score_layer = Conv(in_channels, 1)
        self.non_linearity = non_linearity

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Pool the input graph(s) down to their top-scoring nodes.

        Args:
            x: Node feature tensor, indexed by node along dim 0; presumably
                ``(num_nodes, in_channels)`` -- TODO confirm.
            edge_index: Edge index tensor of the graph(s).
            edge_attr: Optional edge attributes, filtered together with
                ``edge_index``.
            batch: Optional per-node graph-membership vector. When ``None``,
                all nodes are treated as one graph (an all-zeros vector).

        Returns:
            Tuple ``(x, edge_index, edge_attr, batch, perm)``:

            - ``x``: features of the kept nodes, each row scaled by its
              activated score.
            - ``edge_index``, ``edge_attr``: the edges filtered to kept nodes.
            - ``batch``: graph membership of the kept nodes.
            - ``perm``: indices of the kept nodes in the input ordering.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        # Flatten the per-node score to shape (N,). Unlike .squeeze(),
        # .reshape(-1) never collapses a single-node (1, 1) output into a
        # 0-d tensor, which would make score[perm] raise IndexError.
        score = self.score_layer(x, edge_index).reshape(-1)
        perm = topk(score, self.ratio, batch)
        # Gate the surviving features by their activated scores (broadcast
        # over the feature dimension).
        x = x[perm] * self.non_linearity(score[perm]).view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))
        return x, edge_index, edge_attr, batch, perm