models.py

import numpy as np
import os
import csv
from tqdm import tqdm
import pandas as pd
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch_geometric.utils
from torch import Tensor
from torch.nn import Parameter
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptPairTensor,
    OptTensor,
    Size,
)
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax
from torch_geometric.nn.models import MLP, GCN, GAT
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.nn.conv import GATConv
from torch_geometric.utils import k_hop_subgraph
import torch_geometric.transforms as T


class GATConv_ad(GATConv):
    """GATConv variant that adds small uniform noise to the softmax-normalized
    attention coefficients in edge_update() and caches them in `self.alpha`.
    The forward pass otherwise mirrors the upstream GATConv implementation."""

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None,
                return_attention_weights=None):
        H, C = self.heads, self.out_channels

        # We first transform the input node features. If a tuple is passed, we
        # transform source and target node features via separate weights:
        if isinstance(x, Tensor):
            assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = x_dst = self.lin_src(x).view(-1, H, C)
        else:  # Tuple of source and target node features:
            x_src, x_dst = x
            assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
            x_src = self.lin_src(x_src).view(-1, H, C)
            if x_dst is not None:
                x_dst = self.lin_dst(x_dst).view(-1, H, C)

        x = (x_src, x_dst)

        # Next, we compute node-level attention coefficients, both for source
        # and target nodes (if present):
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr)

        # propagate_type: (x: OptPairTensor, alpha: Tensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias
        # out += (torch.rand(out.shape[0], 1) - 0.5) * out.mean(-1, keepdim=True) * 5e-4

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
                    edge_attr: OptTensor, index: Tensor, size_i: Optional[int],
                    ptr: OptTensor) -> Tensor:
        # Given edge-level attention coefficients for source and target nodes,
        # we simply need to sum them up to "emulate" concatenation:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        if edge_attr is not None and self.lin_edge is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            edge_attr = self.lin_edge(edge_attr)
            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
            alpha = alpha + alpha_edge

        # alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        # Perturb the normalized coefficients with small uniform noise,
        # clamp them back into [0, 1], and cache them for later inspection:
        alpha += (torch.rand_like(alpha) - 0.5) * 5e-2
        alpha = torch.clamp(alpha, 0, 1)
        self.alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return alpha.unsqueeze(-1) * x_j

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
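

# Usage sketch (illustrative, not part of the original module): build a single
# noisy-attention layer and retrieve the perturbed coefficients. The graph
# size and hyperparameters below are assumptions chosen for demonstration, and
# the call assumes the same PyG API the class above already relies on
# (GATConv exposing `lin_src`/`lin_dst`).
def _demo_gatconv_ad(num_nodes: int = 8, in_dim: int = 16, hid_dim: int = 32):
    x = torch.randn(num_nodes, in_dim)
    edge_index = torch.randint(0, num_nodes, (2, 3 * num_nodes))
    conv = GATConv_ad(in_dim, hid_dim, heads=2, concat=True, dropout=0.1)
    # With return_attention_weights=True the layer also returns
    # (edge_index, alpha); alpha already contains the uniform noise injected
    # in edge_update(), and its pre-dropout version is cached on conv.alpha.
    out, (ei, alpha) = conv(x, edge_index, return_attention_weights=True)
    return out, ei, alpha  # out: [num_nodes, heads * hid_dim]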


class GAT_ad(GAT):
    """GAT model whose message-passing layers are built with GATConv_ad
    instead of the stock GATConv."""

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> MessagePassing:
        heads = kwargs.pop('heads', 1)
        concat = kwargs.pop('concat', True)
        return GATConv_ad(in_channels, out_channels, heads=heads,
                          concat=concat, dropout=self.dropout, **kwargs)
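

# Minimal smoke test (illustrative, not part of the original module). All
# sizes and hyperparameters are assumptions for demonstration, and the snippet
# assumes a PyG version matching the API used above (e.g. BasicGNN storing
# `dropout` as a float). heads=1 keeps every layer's output width equal to
# hidden_channels, since init_conv above does not divide out_channels by the
# head count.
if __name__ == "__main__":
    torch.manual_seed(0)

    out, ei, alpha = _demo_gatconv_ad()
    print(out.shape, alpha.shape)  # e.g. torch.Size([8, 64]) and [E, 2]

    num_nodes, in_dim, hid_dim, out_dim = 8, 16, 32, 4
    x = torch.randn(num_nodes, in_dim)
    edge_index = torch.randint(0, num_nodes, (2, 24))
    model = GAT_ad(in_channels=in_dim, hidden_channels=hid_dim, num_layers=2,
                   out_channels=out_dim, heads=1, dropout=0.1)
    print(model(x, edge_index).shape)  # torch.Size([8, 4])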