Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sampling according to max_sample within AddMetaPaths #4750

Merged
merged 19 commits into from
Jun 3, 2022
Merged
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
125 changes: 77 additions & 48 deletions test/transforms/test_add_metapaths.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,90 @@
import copy

import torch
from torch import tensor
from torch_sparse import SparseTensor

from torch_geometric.data import HeteroData
from torch_geometric.transforms import AddMetaPaths


def test_add_metapaths():
    """End-to-end test of `AddMetaPaths` on a tiny heterogeneous graph with
    paper ('p'), author ('a') and conference ('c') node types.

    NOTE(review): the extracted diff interleaved the pre-change and
    post-change bodies of this test; this is the reconstructed post-change
    version only.
    """
    data = HeteroData()
    data['p'].x = torch.ones(5)
    data['a'].x = torch.ones(6)
    data['c'].x = torch.ones(3)
    data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
    data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
    # Reverse edge types are obtained by flipping source/destination rows:
    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
    data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
    data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])

    # Test transform options:
    metapaths = [[('p', 'c'), ('c', 'p')]]

    transform = AddMetaPaths(metapaths)
    assert str(transform) == 'AddMetaPaths()'
    meta1 = transform(copy.copy(data))

    transform = AddMetaPaths(metapaths, drop_orig_edges=True)
    assert str(transform) == 'AddMetaPaths()'
    meta2 = transform(copy.copy(data))

    transform = AddMetaPaths(metapaths, drop_orig_edges=True,
                             keep_same_node_type=True)
    assert str(transform) == 'AddMetaPaths()'
    meta3 = transform(copy.copy(data))

    transform = AddMetaPaths(metapaths, drop_orig_edges=True,
                             keep_same_node_type=True,
                             drop_unconnected_nodes=True)
    assert str(transform) == 'AddMetaPaths()'
    meta4 = transform(copy.copy(data))

    # The p->c->p metapath always yields 9 edges on this graph:
    assert meta1['metapath_0'].edge_index.size() == (2, 9)
    assert meta2['metapath_0'].edge_index.size() == (2, 9)
    assert meta3['metapath_0'].edge_index.size() == (2, 9)
    assert meta4['metapath_0'].edge_index.size() == (2, 9)

    assert all([i in meta1.edge_types for i in data.edge_types])
    assert meta2.edge_types == [('p', 'metapath_0', 'p')]
    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]

    assert meta3.node_types == ['p', 'a', 'c']
    assert meta4.node_types == ['p']

    # Test 4-hop metapath:
    metapaths = [
        [('a', 'p'), ('p', 'c')],
        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],
    ]
    transform = AddMetaPaths(metapaths)
    meta = transform(copy.copy(data))
    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
    assert meta['metapath_0'].edge_index.size() == (2, 4)
    assert meta['metapath_1'].edge_index.size() == (2, 4)

    # Test `metapath_dict` information:
    assert list(meta.metapath_dict.values()) == metapaths
    assert list(meta.metapath_dict.keys()) == new_edge_types


def test_add_metapaths_max_sample():
    """With ``max_sample=1``, the added metapath edge type must contain
    fewer than the 9 edges the dense (unsampled) computation produces."""
    torch.manual_seed(12345)

    graph = HeteroData()
    for node_type, num_nodes in [('p', 5), ('a', 6), ('c', 3)]:
        graph[node_type].x = torch.ones(num_nodes)

    graph['p', 'p'].edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
    graph['p', 'a'].edge_index = torch.tensor([[0, 1, 2, 3, 4],
                                               [2, 2, 5, 2, 5]])
    # Reverse relations: swap the source and destination rows.
    graph['a', 'p'].edge_index = graph['p', 'a'].edge_index[[1, 0]]
    graph['c', 'p'].edge_index = torch.tensor([[0, 0, 1, 2, 2],
                                               [0, 1, 2, 3, 4]])
    graph['p', 'c'].edge_index = graph['c', 'p'].edge_index[[1, 0]]

    transform = AddMetaPaths([[('p', 'c'), ('c', 'p')]], max_sample=1)
    out = transform(graph)
    assert out['metapath_0'].edge_index.size(1) < 9
32 changes: 27 additions & 5 deletions torch_geometric/transforms/add_metapaths.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import time
from typing import List, Optional

import torch
from torch_sparse import SparseTensor
@@ -83,11 +84,18 @@ class AddMetaPaths(BaseTransform):
(default: :obj:`False`)
drop_unconnected_nodes (bool, optional): If set to :obj:`True`, drops
node types not connected by any edge type. (default: :obj:`False`)
max_sample (int, optional): If set, will sample at maximum
:obj:`max_sample` neighbors within metapaths. Useful in order to
tackle very dense metapath edges. (default: :obj:`None`)
"""
def __init__(self, metapaths: List[List[EdgeType]],
drop_orig_edges: bool = False,
keep_same_node_type: bool = False,
drop_unconnected_nodes: bool = False):
def __init__(
self,
metapaths: List[List[EdgeType]],
drop_orig_edges: bool = False,
keep_same_node_type: bool = False,
drop_unconnected_nodes: bool = False,
max_sample: Optional[int] = None,
):

for path in metapaths:
assert len(path) >= 2, f"Invalid metapath '{path}'"
@@ -99,6 +107,7 @@ def __init__(self, metapaths: List[List[EdgeType]],
self.drop_orig_edges = drop_orig_edges
self.keep_same_node_type = keep_same_node_type
self.drop_unconnected_nodes = drop_unconnected_nodes
self.max_sample = max_sample

def __call__(self, data: HeteroData) -> HeteroData:
edge_types = data.edge_types # save original edge types
@@ -114,12 +123,19 @@ def __call__(self, data: HeteroData) -> HeteroData:
edge_index=data[edge_type].edge_index,
sparse_sizes=data[edge_type].size())

if self.max_sample is not None:
adj1 = self.sample_adj(adj1)

for i, edge_type in enumerate(metapath[1:]):
adj2 = SparseTensor.from_edge_index(
edge_index=data[edge_type].edge_index,
sparse_sizes=data[edge_type].size())

adj1 = adj1 @ adj2

if self.max_sample is not None:
adj1 = self.sample_adj(adj1)

row, col, _ = adj1.coo()
new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
data[new_edge_type].edge_index = torch.vstack([row, col])
@@ -145,3 +161,9 @@ def __call__(self, data: HeteroData) -> HeteroData:
del data[node]

return data

def sample_adj(self, adj: SparseTensor) -> SparseTensor:
    """Randomly subsamples the non-zero entries of ``adj`` so that each row
    keeps roughly ``self.max_sample`` entries in expectation.

    Each entry of a row with ``c`` non-zeros is kept independently with
    probability ``max_sample / c`` (clamped to 1 implicitly, since a
    probability above 1 always passes the ``rand`` test). Rows with zero
    entries contribute nothing to the mask, so the division by zero they
    produce is never materialized.
    """
    row_count = adj.sum(dim=1)
    # Per-entry keep probability, expanded to one value per non-zero.
    keep_prob = (self.max_sample * (1. / row_count))
    keep_prob = keep_prob.repeat_interleave(row_count.long())
    keep_mask = torch.rand_like(keep_prob) < keep_prob
    return adj.masked_select_nnz(keep_mask, layout='coo')