Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FusedAggregation of simple scatter reductions #6036

Merged
merged 10 commits into from
Nov 23, 2022
Merged

FusedAggregation of simple scatter reductions #6036

merged 10 commits into from
Nov 23, 2022

Conversation

rusty1s
Copy link
Member

@rusty1s rusty1s commented Nov 22, 2022

+------------------------------+---------+---------+
| Aggregators                  | Vanilla | Fusion  |
+==============================+=========+=========+
| :obj:`[sum, mean]`           | 0.4019s | 0.1666s |
+------------------------------+---------+---------+
| :obj:`[sum, mean, min, max]` | 0.7841s | 0.4223s |
+------------------------------+---------+---------+
| :obj:`[sum, mean, var]`      | 0.9711s | 0.3614s |
+------------------------------+---------+---------+
| :obj:`[sum, mean, var, std]` | 1.5994s | 0.3722s |
+------------------------------+---------+---------+

@codecov
Copy link

codecov bot commented Nov 22, 2022

Codecov Report

Merging #6036 (db2cbae) into master (8f2dc12) will decrease coverage by 1.83%.
The diff coverage is 97.79%.

❗ Current head db2cbae differs from pull request most recent head bcab4a9. Consider uploading reports for the commit bcab4a9 to get more accurate results

@@            Coverage Diff             @@
##           master    #6036      +/-   ##
==========================================
- Coverage   86.70%   84.86%   -1.84%     
==========================================
  Files         360      361       +1     
  Lines       19967    20097     +130     
==========================================
- Hits        17312    17055     -257     
- Misses       2655     3042     +387     
Impacted Files Coverage Δ
torch_geometric/nn/aggr/fused.py 97.77% <97.77%> (ø)
torch_geometric/nn/aggr/basic.py 97.77% <100.00%> (ø)
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) ⬇️
torch_geometric/nn/models/dimenet.py 14.90% <0.00%> (-52.76%) ⬇️
torch_geometric/profile/profile.py 36.73% <0.00%> (-27.56%) ⬇️
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) ⬇️
torch_geometric/nn/pool/asap.py 92.10% <0.00%> (-7.90%) ⬇️
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) ⬇️
torch_geometric/nn/dense/linear.py 87.40% <0.00%> (-5.93%) ⬇️
torch_geometric/transforms/add_self_loops.py 94.44% <0.00%> (-5.56%) ⬇️
... and 14 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Contributor

@lightaime lightaime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing. I really like it!

An open discussion: I wonder if it is possible to move these caching logics into these aggregations (MeanAggregation, VarAggregation, StdAggregation) that may utilize caches. We can first sort the executions of aggregations based on caching dependence and maintain caching results for the degree, sum, mean, and var into a buffer that the later aggregations can read. That may make the control flow a bit clear. Just a random thought. Essentially they are the same but moving the caching logic to the Aggregation may be easier to read. I think I am wrong. It is not efficient to parallelize the computation in this way.

torch_geometric/nn/aggr/fused.py Outdated Show resolved Hide resolved
torch_geometric/nn/aggr/fused.py Outdated Show resolved Hide resolved
@rusty1s
Copy link
Member Author

rusty1s commented Nov 23, 2022

@lightaime Yeah, I would like to avoid using torch.nn.Module as a cache. If you have better ideas how to implement this, I am all ears. I am not super happy with the current implementation, but I failed to come up with a better approach.

@rusty1s rusty1s merged commit 4c1c66f into master Nov 23, 2022
@rusty1s rusty1s deleted the fusion branch November 23, 2022 08:40
JakubPietrakIntel pushed a commit to JakubPietrakIntel/pytorch_geometric that referenced this pull request Nov 25, 2022
+------------------------------+---------+---------+
    | Aggregators                  | Vanilla | Fusion  |
    +==============================+=========+=========+
    | :obj:`[sum, mean]`           | 0.4019s | 0.1666s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, min, max]` | 0.7841s | 0.4223s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, var]`      | 0.9711s | 0.3614s |
    +------------------------------+---------+---------+
    | :obj:`[sum, mean, var, std]` | 1.5994s | 0.3722s |
    +------------------------------+---------+---------+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants