-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
Copy patheg_conv.py
260 lines (220 loc) · 10.5 KB
/
eg_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse
from torch_geometric.utils import add_remaining_self_loops, scatter, spmm
class EGConv(MessagePassing):
    r"""The Efficient Graph Convolution from the `"Adaptive Filters and
    Aggregator Fusion for Efficient Graph Convolutions"
    <https://arxiv.org/abs/2104.01481>`_ paper.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in
        \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \;
        \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus}
        \mathbf{W}_b \mathbf{x}_{j}

    with :math:`\mathbf{W}_b` denoting a basis weight,
    :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex
    weighting coefficients across different heads, bases and aggregators.

    EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a
    sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`,
    :class:`~torch_geometric.nn.conv.SAGEConv` or
    :class:`~torch_geometric.nn.conv.GINConv`.

    .. note::
        For an example of using :obj:`EGConv`, see `examples/egc.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/
        examples/egc.py>`_.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        aggregators (List[str], optional): Aggregators to be used.
            Supported aggregators are :obj:`"sum"`, :obj:`"mean"`,
            :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`,
            :obj:`"var"`.
            Multiple aggregators can be used to improve the performance.
            Must contain at least one entry.
            (default: :obj:`["symnorm"]`)
        num_heads (int, optional): Number of heads :math:`H` to use. Must have
            :obj:`out_channels % num_heads == 0`. It is recommended to set
            :obj:`num_heads >= num_bases`. (default: :obj:`8`)
        num_bases (int, optional): Number of basis weights :math:`B` to use.
            (default: :obj:`4`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of the edge index with added self loops on first
            execution, along with caching the calculation of the symmetric
            normalized edge weights if the :obj:`"symnorm"` aggregator is
            being used. This parameter should only be set to :obj:`True` in
            transductive learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Raises:
        ValueError: If :obj:`out_channels` is not divisible by
            :obj:`num_heads`, if :obj:`aggregators` is empty, or if it
            contains an unsupported aggregator name.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **output:** node features :math:`(|\mathcal{V}|, F_{out})`
    """

    # Cache for the dense ``edge_index`` path: (edge_index, symnorm weights).
    _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]
    # Cache for the ``SparseTensor`` adjacency path.
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        aggregators: List[str] = ['symnorm'],
        num_heads: int = 8,
        num_bases: int = 4,
        cached: bool = False,
        add_self_loops: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        if out_channels % num_heads != 0:
            raise ValueError(f"'out_channels' (got {out_channels}) must be "
                             f"divisible by the number of heads "
                             f"(got {num_heads})")

        # An empty list would only fail later, inside aggregation, with an
        # opaque ``IndexError``; reject it up front instead.
        if len(aggregators) == 0:
            raise ValueError("'aggregators' must contain at least one "
                             "aggregator")

        for a in aggregators:
            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:
                raise ValueError(f"Unsupported aggregator: '{a}'")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.num_bases = num_bases
        self.cached = cached
        self.add_self_loops = add_self_loops
        # Defensive copy: the default argument is a single shared list object,
        # so storing it by reference would let a mutation on one instance leak
        # into every later instance created with the default.
        self.aggregators = list(aggregators)

        # Shared basis transformation (one weight block per basis).
        self.bases_lin = Linear(in_channels,
                                (out_channels // num_heads) * num_bases,
                                bias=False, weight_initializer='glorot')
        # Per-node combination coefficients for every (head, basis, aggregator).
        self.comb_lin = Linear(in_channels,
                               num_heads * num_bases * len(self.aggregators))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both linear layers and the bias, and drop any cached
        graph structure.
        """
        super().reset_parameters()
        self.bases_lin.reset_parameters()
        self.comb_lin.reset_parameters()
        zeros(self.bias)
        self._cached_adj_t = None
        self._cached_edge_index = None

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        """Run the layer on node features :obj:`x` and connectivity
        :obj:`edge_index`.

        Args:
            x (Tensor): Input node features.
            edge_index (Adj): Graph connectivity, either as a dense edge
                index tensor or as a sparse adjacency.

        Returns:
            Tensor: Output node features, reshaped to
            :obj:`(-1, out_channels)`.
        """
        symnorm_weight: OptTensor = None
        if "symnorm" in self.aggregators:
            # The symmetric normalization (and optional self-loops) is
            # delegated to ``gcn_norm``; results are reused when cached.
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops,
                        flow=self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, symnorm_weight)
                else:
                    edge_index, symnorm_weight = cache

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, None, num_nodes=x.size(self.node_dim),
                        improved=False, add_self_loops=self.add_self_loops,
                        flow=self.flow, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        elif self.add_self_loops:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if self.cached and cache is not None:
                    edge_index = cache[0]
                else:
                    # Pass the node count explicitly, matching the
                    # ``symnorm`` branch above. Without it the count is
                    # inferred from ``edge_index`` and trailing isolated
                    # nodes would not receive a self-loop.
                    edge_index, _ = add_remaining_self_loops(
                        edge_index, num_nodes=x.size(self.node_dim))
                    if self.cached:
                        self._cached_edge_index = (edge_index, None)

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if self.cached and cache is not None:
                    edge_index = cache
                else:
                    edge_index = torch_sparse.fill_diag(edge_index, 1.0)
                    if self.cached:
                        self._cached_adj_t = edge_index

        # [num_nodes, (out_channels // num_heads) * num_bases]
        bases = self.bases_lin(x)
        # [num_nodes, num_heads * num_bases * num_aggrs]
        weightings = self.comb_lin(x)

        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]
        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)
        aggregated = self.propagate(edge_index, x=bases,
                                    symnorm_weight=symnorm_weight)

        weightings = weightings.view(-1, self.num_heads,
                                     self.num_bases * len(self.aggregators))
        aggregated = aggregated.view(
            -1,
            len(self.aggregators) * self.num_bases,
            self.out_channels // self.num_heads,
        )

        # [num_nodes, num_heads, out_channels // num_heads]
        out = torch.matmul(weightings, aggregated)
        out = out.view(-1, self.out_channels)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor) -> Tensor:
        """Return the source-node features unchanged as the message."""
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  dim_size: Optional[int] = None,
                  symnorm_weight: OptTensor = None) -> Tensor:
        """Apply every configured aggregator to the messages.

        Args:
            inputs (Tensor): Messages to aggregate.
            index (Tensor): Scatter index giving the target of each message.
            dim_size (int, optional): Size of the output along dimension 0.
            symnorm_weight (Tensor, optional): Per-message weights; required
                when the :obj:`"symnorm"` aggregator is configured.

        Returns:
            Tensor: The single aggregator's output, or the outputs of all
            aggregators stacked along dimension 1 when more than one is used.
        """
        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                assert symnorm_weight is not None
                out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,
                              dim_size, reduce='sum')
            elif aggr == 'var' or aggr == 'std':
                # Var[x] = E[x^2] - E[x]^2
                mean = scatter(inputs, index, 0, dim_size, reduce='mean')
                mean_squares = scatter(inputs * inputs, index, 0, dim_size,
                                       reduce='mean')
                out = mean_squares - mean * mean
                if aggr == 'std':
                    # NOTE(review): this clamps the variance at 1e-5, whereas
                    # `message_and_aggregate` uses relu(var) + 1e-5; the two
                    # paths can differ slightly for small variances.
                    out = out.clamp(min=1e-5).sqrt()
            else:
                out = scatter(inputs, index, 0, dim_size, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        """Fused message + aggregation for sparse adjacency inputs.

        Args:
            adj_t (Adj): Sparse adjacency used for the sparse-dense products.
            x (Tensor): Node features to propagate.

        Returns:
            Tensor: The single aggregator's output, or the outputs of all
            aggregators stacked along dimension 1 when more than one is used.
        """
        adj_t_2 = adj_t
        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:
            # `adj_t` is used by the 'symnorm' aggregator as-is; every other
            # aggregator gets a copy whose values are removed / reset to 1.
            if isinstance(adj_t, SparseTensor):
                adj_t_2 = adj_t.set_value(None)
            else:
                adj_t_2 = adj_t.clone()
                adj_t_2.values().fill_(1.0)

        outs = []
        for aggr in self.aggregators:
            if aggr == 'symnorm':
                out = spmm(adj_t, x, reduce='sum')
            elif aggr in ['var', 'std']:
                # Var[x] = E[x^2] - E[x]^2
                mean = spmm(adj_t_2, x, reduce='mean')
                mean_sq = spmm(adj_t_2, x * x, reduce='mean')
                out = mean_sq - mean * mean
                if aggr == 'std':
                    # NOTE(review): stabilised as relu(var) + 1e-5 here,
                    # unlike the clamp(min=1e-5) used in `aggregate`.
                    out = torch.sqrt(out.relu_() + 1e-5)
            else:
                out = spmm(adj_t_2, x, reduce=aggr)

            outs.append(out)

        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggregators={self.aggregators})')