Skip to content

Commit

Permalink
Sampling according to max_sample within AddMetaPaths (#4750)
Browse files Browse the repository at this point in the history
* sampling according to max_sample

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* debugging prints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove coalesce

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use sparse adj for sampling

* doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

* update

* linting

* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Jun 3, 2022
1 parent 09f25e9 commit 8bd9ae4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
Expand Down
124 changes: 76 additions & 48 deletions test/transforms/test_add_metapaths.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,89 @@
import copy

import torch
from torch import tensor

from torch_geometric.data import HeteroData
from torch_geometric.transforms import AddMetaPaths


def test_add_metapaths():
dblp = HeteroData()
dblp['paper'].x = torch.ones(5)
dblp['author'].x = torch.ones(6)
dblp['conference'].x = torch.ones(3)
dblp['paper', 'cites', 'paper'].edge_index = torch.tensor([[0, 1, 2, 3],
[1, 2, 4, 2]])
dblp['paper', 'author'].edge_index = torch.tensor([[0, 1, 2, 3, 4],
[2, 2, 5, 2, 5]])
dblp['author', 'paper'].edge_index = dblp['paper',
'author'].edge_index[[1, 0]]
dblp['conference', 'paper'].edge_index = torch.tensor([[0, 0, 1, 2, 2],
[0, 1, 2, 3, 4]])
dblp['paper', 'conference'].edge_index = dblp['conference',
'paper'].edge_index[[1, 0]]
data = HeteroData()
data['p'].x = torch.ones(5)
data['a'].x = torch.ones(6)
data['c'].x = torch.ones(3)
data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])

# Test transform options:
orig_edge_type = dblp.edge_types
metapaths = [[('paper', 'conference'), ('conference', 'paper')]]
meta1 = AddMetaPaths(metapaths)(dblp.clone())
meta2 = AddMetaPaths(metapaths, drop_orig_edges=True)(dblp.clone())
meta3 = AddMetaPaths(metapaths, drop_orig_edges=True,
keep_same_node_type=True)(dblp.clone())
meta4 = AddMetaPaths(metapaths, drop_orig_edges=True,
keep_same_node_type=True,
drop_unconnected_nodes=True)(dblp.clone())

assert meta1['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
assert meta2['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
assert meta3['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
assert meta4['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9

assert all([i in meta1.edge_types for i in orig_edge_type])
assert meta2.edge_types == [('paper', 'metapath_0', 'paper')]
assert meta3.edge_types == [('paper', 'cites', 'paper'),
('paper', 'metapath_0', 'paper')]
assert meta4.edge_types == [('paper', 'cites', 'paper'),
('paper', 'metapath_0', 'paper')]

assert meta3.node_types == ['paper', 'author', 'conference']
assert meta4.node_types == ['paper']
metapaths = [[('p', 'c'), ('c', 'p')]]

transform = AddMetaPaths(metapaths)
assert str(transform) == 'AddMetaPaths()'
meta1 = transform(copy.copy(data))

transform = AddMetaPaths(metapaths, drop_orig_edges=True)
assert str(transform) == 'AddMetaPaths()'
meta2 = transform(copy.copy(data))

transform = AddMetaPaths(metapaths, drop_orig_edges=True,
keep_same_node_type=True)
assert str(transform) == 'AddMetaPaths()'
meta3 = transform(copy.copy(data))

transform = AddMetaPaths(metapaths, drop_orig_edges=True,
keep_same_node_type=True,
drop_unconnected_nodes=True)
assert str(transform) == 'AddMetaPaths()'
meta4 = transform(copy.copy(data))

assert meta1['metapath_0'].edge_index.size() == (2, 9)
assert meta2['metapath_0'].edge_index.size() == (2, 9)
assert meta3['metapath_0'].edge_index.size() == (2, 9)
assert meta4['metapath_0'].edge_index.size() == (2, 9)

assert all([i in meta1.edge_types for i in data.edge_types])
assert meta2.edge_types == [('p', 'metapath_0', 'p')]
assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]

assert meta3.node_types == ['p', 'a', 'c']
assert meta4.node_types == ['p']

# Test 4-hop metapath:
metapaths = [[('author', 'paper'), ('paper', 'conference')],
[('author', 'paper'), ('paper', 'conference'),
('conference', 'paper'), ('paper', 'author')]]
meta1 = AddMetaPaths(metapaths)(dblp.clone())
new_edge_types = [('author', 'metapath_0', 'conference'),
('author', 'metapath_1', 'author')]
assert meta1[new_edge_types[0]].edge_index.shape[-1] == 4
assert meta1[new_edge_types[1]].edge_index.shape[-1] == 4
metapaths = [
[('a', 'p'), ('p', 'c')],
[('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],
]
transform = AddMetaPaths(metapaths)
meta = transform(copy.copy(data))
new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
assert meta['metapath_0'].edge_index.size() == (2, 4)
assert meta['metapath_1'].edge_index.size() == (2, 4)

# Test `metapath_dict` information:
assert list(meta1.metapath_dict.values()) == metapaths
assert list(meta1.metapath_dict.keys()) == new_edge_types
assert list(meta.metapath_dict.values()) == metapaths
assert list(meta.metapath_dict.keys()) == new_edge_types


def test_add_metapaths_max_sample():
    """Check that `max_sample` thins out the generated metapath edges.

    Without sampling, the `p -> c -> p` metapath on this toy graph yields
    nine edges; with `max_sample=1` strictly fewer must remain.
    """
    # Sampling is random, so pin the seed before anything else happens.
    torch.manual_seed(12345)

    graph = HeteroData()

    # Node features: 5 'p' nodes, 6 'a' nodes and 3 'c' nodes.
    for node_type, num_nodes in (('p', 5), ('a', 6), ('c', 3)):
        graph[node_type].x = torch.ones(num_nodes)

    # Edges, with the reverse direction obtained by flipping the two rows.
    graph['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
    p_to_a = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
    graph['p', 'a'].edge_index = p_to_a
    graph['a', 'p'].edge_index = graph['p', 'a'].edge_index.flip([0])
    c_to_p = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
    graph['c', 'p'].edge_index = c_to_p
    graph['p', 'c'].edge_index = graph['c', 'p'].edge_index.flip([0])

    sampler = AddMetaPaths([[('p', 'c'), ('c', 'p')]], max_sample=1)
    sampled = sampler(graph)

    num_metapath_edges = sampled['metapath_0'].edge_index.size(1)
    assert num_metapath_edges < 9
33 changes: 28 additions & 5 deletions torch_geometric/transforms/add_metapaths.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional

import torch
from torch_sparse import SparseTensor
Expand All @@ -7,6 +7,7 @@
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.typing import EdgeType
from torch_geometric.utils import degree


@functional_transform('add_metapaths')
Expand Down Expand Up @@ -83,11 +84,18 @@ class AddMetaPaths(BaseTransform):
(default: :obj:`False`)
drop_unconnected_nodes (bool, optional): If set to :obj:`True` drop
node types not connected by any edge type. (default: :obj:`False`)
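max_sample (int, optional): If set, randomly subsamples the edges
within metapaths: every edge of a source node with degree
:obj:`deg` is kept independently with probability
:obj:`max_sample / deg`, so each source node keeps at most
:obj:`max_sample` neighbors in expectation. Useful in order to
tackle very dense metapath edges. (default: :obj:`None`)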
"""
def __init__(self, metapaths: List[List[EdgeType]],
drop_orig_edges: bool = False,
keep_same_node_type: bool = False,
drop_unconnected_nodes: bool = False):
def __init__(
self,
metapaths: List[List[EdgeType]],
drop_orig_edges: bool = False,
keep_same_node_type: bool = False,
drop_unconnected_nodes: bool = False,
max_sample: Optional[int] = None,
):

for path in metapaths:
assert len(path) >= 2, f"Invalid metapath '{path}'"
Expand All @@ -99,6 +107,7 @@ def __init__(self, metapaths: List[List[EdgeType]],
self.drop_orig_edges = drop_orig_edges
self.keep_same_node_type = keep_same_node_type
self.drop_unconnected_nodes = drop_unconnected_nodes
self.max_sample = max_sample

def __call__(self, data: HeteroData) -> HeteroData:
edge_types = data.edge_types # save original edge types
Expand All @@ -114,12 +123,19 @@ def __call__(self, data: HeteroData) -> HeteroData:
edge_index=data[edge_type].edge_index,
sparse_sizes=data[edge_type].size())

if self.max_sample is not None:
adj1 = self.sample_adj(adj1)

for i, edge_type in enumerate(metapath[1:]):
adj2 = SparseTensor.from_edge_index(
edge_index=data[edge_type].edge_index,
sparse_sizes=data[edge_type].size())

adj1 = adj1 @ adj2

if self.max_sample is not None:
adj1 = self.sample_adj(adj1)

row, col, _ = adj1.coo()
new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
data[new_edge_type].edge_index = torch.vstack([row, col])
Expand All @@ -145,3 +161,10 @@ def __call__(self, data: HeteroData) -> HeteroData:
del data[node]

return data

def sample_adj(self, adj: SparseTensor) -> SparseTensor:
    """Randomly drop non-zero entries of :obj:`adj` row-wise.

    Every entry in a row with :obj:`deg` non-zero entries is kept
    independently with probability :obj:`self.max_sample / deg`, so rows
    with at most :obj:`self.max_sample` entries are left untouched.

    Args:
        adj (SparseTensor): The adjacency matrix to subsample.

    Returns:
        SparseTensor: :obj:`adj` restricted to the entries that were kept.
    """
    src = adj.coo()[0]

    # Number of non-zero entries per row; empty rows give `inf` below but
    # are never looked up since they do not occur in `src`.
    inv_deg = 1. / degree(src, num_nodes=adj.size(0))
    keep_prob = (self.max_sample * inv_deg)[src]

    # One uniform draw per non-zero entry decides whether it survives.
    keep = torch.rand_like(keep_prob) < keep_prob
    return adj.masked_select_nnz(keep, layout='coo')

0 comments on commit 8bd9ae4

Please sign in to comment.