[Roadmap] torch_geometric.nn.aggr 🚀 #4712
Looks great @rusty1s! I'll try to pick up some of the smaller tasks this weekend.
Added some simple ones.
I planned to pick up a couple of the tasks; hope you guys don't mind me editing the issue to make it clear what I plan on doing (I'll stick to smaller PRs if possible, since I typically don't have much time during the week).
Do we have an open issue on this? I'd be interested to understand a bit more about what we're thinking here. Is it mostly for the case where we want to (for example) compute both a sum and a mean?
Yes, indeed. There are no clear plans for the implementation yet, though. It will likely depend on PyTorch fusing these ops as part of TorchScript, or on us providing specialized CUDA kernels.
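To make the sum-and-mean case concrete: a mean is a sum divided by a count, so a fused implementation can share one scatter pass between both outputs instead of reducing the inputs twice. The sketch below shows that sharing on plain Python lists with scalar values; `fused_sum_mean` is a hypothetical helper, and this is not how PyG or PyTorch implement (or plan to implement) kernel fusion.

```python
from typing import List, Sequence, Tuple


def fused_sum_mean(
    values: Sequence[float], index: Sequence[int], num_groups: int
) -> Tuple[List[float], List[float]]:
    """Return per-group sums and means from a single pass over the inputs.

    Hypothetical helper for illustration only; it is not part of PyG.
    """
    sums = [0.0] * num_groups
    counts = [0] * num_groups
    # One scatter pass feeds both aggregations: the mean only needs sum and count.
    for value, group in zip(values, index):
        sums[group] += value
        counts[group] += 1
    # Groups that receive no values get a mean of 0.0 (a choice made for this sketch).
    means = [s / c if c else 0.0 for s, c in zip(sums, counts)]
    return sums, means


sums, means = fused_sum_mean([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1], num_groups=3)
print(sums)   # [3.0, 7.0, 0.0]
print(means)  # [1.5, 3.5, 0.0]
```

The same idea extends to variance and standard deviation, which can additionally accumulate a running sum of squares in the same pass.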
Thanks everyone for the hard work @lightaime @Padarn. I think that the final outcome looks fantastic: many cool things to promote in our upcoming release :)
🚀 The feature, motivation and pitch
The goal of this roadmap is to unify the concepts of aggregation inside GNNs across both MessagePassing and global readouts. Currently, these concepts are separated, e.g., via MessagePassing.aggr = "mean" and global_mean_pool(...), while the underlying implementation is the same. In addition, some aggregations are only available as global pooling operators (global_sort_pool, Set2Set, ...), while, in theory, they are also applicable during MessagePassing (and vice versa, e.g., SAGEConv.aggr = "lstm"). One additional feature is the combination of aggregations, which is useful both in MessagePassing (PNAConv, EGConv, ...) and in global readouts.

As such, we want to provide re-usable aggregations as part of a newly defined torch_geometric.nn.aggr.* package. Unifying these concepts also helps us to perform optimizations and specialized implementations in a single place (e.g., fused kernels for multiple aggregations). After integration, the following functionality is applicable:

Roadmap
The general roadmap looks as follows (at best, each implemented in a separate PR):
- torch_geometric.nn.aggr.* and implement a BaseAggr abstract class (torch_geometric.nn.aggr package with base class #4687)
- DeepGCN aggregations SoftmaxAggr, PowerMeanAggr, cf. GENConv (torch_geometric.nn.aggr package with base class #4687)
- LSTMAggr, cf. SAGEConv (LSTMAggregation #4731)
- MultiAggr class (MultiAggregation and aggregation_resolver #4749)
- class-resolver, similar to here (MultiAggregation and aggregation_resolver #4749, Add class-resolver for Aggregation #4716)
- torch.jit.script support (Integration of nn.aggr within MessagePassing #4779)
- MessagePassing interface (Integration of nn.aggr within MessagePassing #4779)
- torch_geometric.nn.glob to torch_geometric.nn.aggr (respecting the new interface), deprecate old implementation:
  - MeanAggr, SumAggr, MaxAggr, MinAggr, MulAggr, VarAggr, StdAggr (torch_geometric.nn.aggr package with base class #4687, MultiAggregation and aggregation_resolver #4749)
  - MedianAggr (Add MedianAggregation and QuantileAggregation #5098)
  - AttentionalAggr (Move nn.glob.attention.GlobalAttention to nn.aggr.attention.AttentionalAggregation #4986)
  - Set2Set (nn.aggr.Set2Set #4762)
  - GlobalSortAggr (Add GlobalSortAggr to nn.aggr package. #4957)
  - GraphMultiSetTransformer (Move GraphMultisetTransformer to nn.aggr package #4973)
  - EquilibriumAggr (EquilibriumAggregation global aggregation layer #4522)
- torch_geometric.nn.glob (Deprecate nn.glob package #5039)
- SAGEConv (Refactor SAGEConv to use LSTMAggregation #4863)
- PNAConv (Refactor PNAConv to rely on new Aggregation #4864)
- GravNetConv (Refactor GravNetConv to rely on new Aggregation #4865)
- GENConv (Refactor GENConv to rely on new Aggregation #4866)
- message_and_aggregate functionality intact (Add reverse support in aggregation_resolver #5084)
- SAGEConv (Support for multiple aggregations in SAGEConv #5033)
- MultiAggregation: Support for concat, concat+transform, sum, mean, max, attention (Add combine support to MultiAggregation #5000, Raise error for one single aggregation when proj or attn mode is used in MultiAggregation #5034)
- semi_grad functionality to SoftmaxAggregation (Add semi_grad and docs to SoftmaxAggregtion #4995)
- torch_geometric.nn.aggr.* documentation (Update SoftmaxAggregation and PowerMeanAggregation doc strings #5036, Fix documentation and README.md regarding new nn.aggr package #5097, Update nn.aggr documentation #5099, Add missing doc-strings to nn.aggr #5104)
- torch_geometric.nn.aggr package (#4927)

Any feedback and help from the community is highly appreciated!
cc: @lightaime @Padarn
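The unification described in the pitch (one aggregation implementation serving both MessagePassing and global readouts, resolvable from a string, and combinable) can be sketched without any framework. Below, BaseAggr, SumAggr, MeanAggr, MaxAggr and MultiAggr reuse names from the roadmap, but they are hypothetical plain-Python stand-ins operating on lists of feature vectors, not PyG's actual classes; resolve_aggr is likewise a made-up helper playing the role of aggregation_resolver. The only thing that differs between neighborhood aggregation and a global readout is what index means: the destination node of each message, or the graph id of each node.

```python
from itertools import chain
from typing import Dict, Iterable, List, Sequence, Type, Union

Vector = List[float]


class BaseAggr:
    """Hypothetical base class: reduces rows of `x` into `dim_size` groups given by `index`."""

    def __call__(self, x: Sequence[Vector], index: Sequence[int], dim_size: int) -> List[Vector]:
        num_features = len(x[0]) if x else 0
        groups: List[List[Vector]] = [[] for _ in range(dim_size)]
        for row, group in zip(x, index):
            groups[group].append(row)
        return [self.reduce(rows, num_features) for rows in groups]

    def reduce(self, rows: List[Vector], num_features: int) -> Vector:
        raise NotImplementedError


class SumAggr(BaseAggr):
    def reduce(self, rows: List[Vector], num_features: int) -> Vector:
        # Empty groups are filled with zeros (an assumption of this sketch).
        return [sum(col) for col in zip(*rows)] if rows else [0.0] * num_features


class MeanAggr(BaseAggr):
    def reduce(self, rows: List[Vector], num_features: int) -> Vector:
        return [sum(col) / len(rows) for col in zip(*rows)] if rows else [0.0] * num_features


class MaxAggr(BaseAggr):
    def reduce(self, rows: List[Vector], num_features: int) -> Vector:
        return [max(col) for col in zip(*rows)] if rows else [0.0] * num_features


class MultiAggr(BaseAggr):
    """Hypothetical combinator: concatenates several aggregations along the feature axis."""

    def __init__(self, aggrs: Iterable[BaseAggr]) -> None:
        self.aggrs = list(aggrs)

    def __call__(self, x: Sequence[Vector], index: Sequence[int], dim_size: int) -> List[Vector]:
        outs = [aggr(x, index, dim_size) for aggr in self.aggrs]
        # zip(*outs) yields, per group, one output vector from each aggregation.
        return [list(chain.from_iterable(parts)) for parts in zip(*outs)]


# Hypothetical registry standing in for a string-based resolver such as aggr="mean".
_REGISTRY: Dict[str, Type[BaseAggr]] = {"sum": SumAggr, "mean": MeanAggr, "max": MaxAggr}

AggrLike = Union[str, BaseAggr, Sequence[Union[str, BaseAggr]]]


def resolve_aggr(aggr: AggrLike) -> BaseAggr:
    """Turn a string, an instance, or a list of either into an aggregation module."""
    if isinstance(aggr, BaseAggr):
        return aggr
    if isinstance(aggr, str):
        return _REGISTRY[aggr.lower()]()
    return MultiAggr(resolve_aggr(a) for a in aggr)


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Message passing: `index` holds the destination node of each message.
print(resolve_aggr("mean")(x, [0, 0, 1], dim_size=2))  # [[2.0, 3.0], [5.0, 6.0]]

# Global readout: `index` holds the graph id of each node (a "batch" vector).
print(resolve_aggr(["mean", "max"])(x, [0, 1, 1], dim_size=2))
# [[1.0, 2.0, 1.0, 2.0], [4.0, 5.0, 5.0, 6.0]]
```

Filling empty groups with zeros is a choice of this sketch; any real implementation needs an explicit policy for nodes without incoming messages and for graphs without nodes.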