Skip to content

Commit

Permalink
Changed view to reshape in LCMAggregation (#8026)
Browse files Browse the repository at this point in the history
`view` resulted in an error in some cases, so I changed it to `reshape`.
I also modified the tests to reflect such a scenario.
  • Loading branch information
ArchieGertsman authored Sep 14, 2023
1 parent 3e69022 commit 13cbae6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))
- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))
- Added `SparseTensor` support to `WLConvContinuous`, `GeneralConv`, `PDNConv` and `ARMAConv` ([#8013](https://github.com/pyg-team/pytorch_geometric/pull/8013))
- Added `LCMAggregation`, an implementation of Learnable Communitive Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023))
- Added `LCMAggregation`, an implementation of Learnable Communitive Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023), [#8026](https://github.com/pyg-team/pytorch_geometric/pull/8026))
- Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995))
- Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994))
- Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905))
Expand Down
4 changes: 2 additions & 2 deletions test/nn/aggr/test_lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_lcm_aggregation_with_project():


def test_lcm_aggregation_without_project():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
x = torch.randn(5, 16)
index = torch.tensor([0, 1, 1, 2, 2])

aggr = LCMAggregation(16, 16, project=False)
assert str(aggr) == 'LCMAggregation(16, 16, project=False)'
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/nn/aggr/lcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def forward(
if x.size(0) % 2 == 1:
# This level of the tree has an odd number of nodes, so the
# remaining unmatched node gets moved to the next level.
x, remainder = x[:-1].contiguous(), x[-1:]
x, remainder = x[:-1], x[-1:]
else:
remainder = None

left_right = x.view(-1, 2, num_nodes, num_features)
right_left = left_right.flip(dims=[1])

left_right = left_right.view(-1, num_features)
right_left = right_left.view(-1, num_features)
left_right = left_right.reshape(-1, num_features)
right_left = right_left.reshape(-1, num_features)

# Execute the GRUCell for all (left, right) pairs in the current
# level of the tree in parallel:
Expand Down

0 comments on commit 13cbae6

Please sign in to comment.