-
Notifications
You must be signed in to change notification settings - Fork 3.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecate nn.glob
package
#5039
Conversation
Codecov Report
@@ Coverage Diff @@
## master #5039 +/- ##
==========================================
- Coverage 84.74% 82.89% -1.85%
==========================================
Files 331 331
Lines 18162 18137 -25
==========================================
- Hits 15391 15035 -356
- Misses 2771 3102 +331
Help us with your feedback. Take ten seconds to tell us how you rate us. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! LGTM! Left a few minor comments.
torch_geometric/nn/pool/glob.py
Outdated
sum_aggr = SumAggregation() | ||
|
||
|
||
def global_add_pool(x: Tensor, index: Optional[Tensor], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we keep the original doc strings for global pools?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea
torch_geometric/nn/pool/glob.py
Outdated
|
||
@deprecated( | ||
details="use 'nn.aggr.GlobalSortAggr' instead", | ||
func_name='nn.glob.global_sort_pool', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is a bit confusing for me since the modules are now moved to nn.pool.glob.*
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks good point!
return scatter(x, batch, dim=-2, dim_size=size, reduce='max') | ||
|
||
|
||
class GlobalPooling(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please clarify why this is removed completely? If we need it can we implement it with MultiAggregation
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes removed completed (discussed with @rusty1s in earlier version of this PR - the code for this was added recently and we decided it was no longer needed)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sense
test/nn/aggr/test_sort.py
Outdated
@@ -1,78 +0,0 @@ | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please clarify why the tests are removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mistake, added back.
Co-authored-by: Guohao Li <lightaime@gmail.com>
Co-authored-by: Guohao Li <lightaime@gmail.com>
Co-authored-by: Guohao Li <lightaime@gmail.com>
Thanks for the review @lightaime |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Padarn. I think my comments earlier let to some misunderstanding, sorry about that. Clarified more in the comments. Let me know what you think!
test/nn/glob/test_glob.py
Outdated
@@ -1,72 +0,0 @@ | |||
import torch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this file to pool/test_glob.py
?
torch_geometric/nn/pool/__init__.py
Outdated
@@ -11,6 +11,8 @@ | |||
from .asap import ASAPooling | |||
from .pan_pool import PANPooling | |||
from .mem_pool import MemPooling | |||
from .glob import (global_max_pool, global_mean_pool, global_add_pool, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah sorry, this is not what I meant. I meant:
- Move
global_add_pool
,global_mean_pool
, andglobal_max_pool
toglob/glob.py
to pool/glob.py`. - Move
glob/__init__.py
tonn/glob.py
and keep the deprecations in there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea here is that we keep the implementations of global_*_pool
as they were (since they are so heavily used).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm sorry I don't fully understand the distinction between these. In both cases you can
from torch_geometric.nn import global_max_pool
is it that you want to keep this?
from torch_geometric.nn.glob import global_max_pool
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or you just want to clearly separate the deprecations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we only need to
deprecate the import of torch_geometric.nn.glob
.
Ah yeah I misunderstood. Thanks for clarifying, I'll make the changes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Left some small comments.
torch_geometric/nn/pool/__init__.py
Outdated
'MemPooling', 'max_pool', 'avg_pool', 'max_pool_x', 'max_pool_neighbor_x', | ||
'avg_pool_x', 'avg_pool_neighbor_x', 'graclus', 'voxel_grid', 'fps', 'knn', | ||
'knn_graph', 'radius', 'radius_graph', 'nearest', 'global_max_pool', | ||
'global_add_pool', 'global_mean_pool' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add back a comma for auto-formatting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. Let‘s also move the global pooling methods to the top.
torch_geometric/nn/glob.py
Outdated
|
||
Set2Set = deprecated( | ||
details="use 'nn.aggr.Set2Set' instead", | ||
func_name='nn.pool.Set2Set', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The deprecations are moved back to nn.glob.*
. Should we change back from nn.pool.*
to nn.glob.*
? Sorry about that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is correct. We provide backward compatibility in nn/glob.py
but deprecate its use.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! This is exactly what I had in mind. Left some last comments.
torch_geometric/nn/glob.py
Outdated
|
||
Set2Set = deprecated( | ||
details="use 'nn.aggr.Set2Set' instead", | ||
func_name='nn.pool.Set2Set', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is correct. We provide backward compatibility in nn/glob.py
but deprecate its use.
@@ -0,0 +1,35 @@ | |||
from torch_geometric.deprecation import deprecated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The global_*_pool
methods need to be added here as well.
torch_geometric/nn/pool/__init__.py
Outdated
'MemPooling', 'max_pool', 'avg_pool', 'max_pool_x', 'max_pool_neighbor_x', | ||
'avg_pool_x', 'avg_pool_neighbor_x', 'graclus', 'voxel_grid', 'fps', 'knn', | ||
'knn_graph', 'radius', 'radius_graph', 'nearest', 'global_max_pool', | ||
'global_add_pool', 'global_mean_pool' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. Let‘s also move the global pooling methods to the top.
torch_geometric/nn/pool/glob.py
Outdated
@@ -25,19 +29,19 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor], | |||
""" | |||
if batch is None: | |||
return x.sum(dim=-2, keepdim=x.dim() == 2) | |||
size = int(batch.max().item() + 1) if size is None else size | |||
return scatter(x, batch, dim=-2, dim_size=size, reduce='add') | |||
return sum_aggr(x, batch, dim_size=size) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leave the implementation as it is? I am not super happy with having global modules here.
Thanks for the reviews again! I've addressed the comments. Will merge later unless any objections |
I have some problem with the docs build. Will figure it out tonight |
could be sphinx-doc/sphinx#10705 .. may have to wait until they release the new version |
Yes, don‘t worry about it. |
Removes use of
nn.glob
package. #4712