Skip to content

Commit

Permalink
RandomLinkSplit: Allow edge_type == rev_edge_type (#4757)
Browse files Browse the repository at this point in the history
* initial commit

* changelog
  • Loading branch information
rusty1s authored Jun 2, 2022
1 parent cf2010b commit 09f25e9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757))
- Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754))
- Fixed a bug in `HANConv` in which destination node features rather than source node features were propagated ([#4753](https://github.com/pyg-team/pytorch_geometric/pull/4753))
- Fixed versions of `checkout` and `setup-python` in CI ([#4751](https://github.com/pyg-team/pytorch_geometric/pull/4751))
Expand Down
16 changes: 16 additions & 0 deletions test/transforms/test_random_link_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,19 @@ def test_random_link_split_on_hetero_data():
train_data['p', 'p'].edge_attr)
assert train_data['p', 'a'].edge_index.size() == (2, 600)
assert train_data['a', 'p'].edge_index.size() == (2, 600)


def test_random_link_split_on_undirected_hetero_data():
data = HeteroData()
data['p'].x = torch.arange(100)
data['p', 'p'].edge_index = get_edge_index(100, 100, 500)
data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index)

transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'))
train_data, val_data, test_data = transform(data)
assert train_data['p', 'p'].is_undirected()

transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'),
rev_edge_types=('p', 'p'))
train_data, val_data, test_data = transform(data)
assert train_data['p', 'p'].is_undirected()
3 changes: 2 additions & 1 deletion torch_geometric/transforms/random_link_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def __call__(self, data: Union[Data, HeteroData]):

is_undirected = self.is_undirected
is_undirected &= not store.is_bipartite()
is_undirected &= rev_edge_type is None
is_undirected &= (rev_edge_type is None
or store._key == data[rev_edge_type]._key)

edge_index = store.edge_index
if is_undirected:
Expand Down

0 comments on commit 09f25e9

Please sign in to comment.