diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9d7a922ae08..ddd3bcbeb156 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added a `max_sample` argument to `AddMetaPaths` to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))
 - Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
diff --git a/test/transforms/test_add_metapaths.py b/test/transforms/test_add_metapaths.py
index 2b5c44c9a524..7895ff3c721d 100644
--- a/test/transforms/test_add_metapaths.py
+++ b/test/transforms/test_add_metapaths.py
@@ -1,61 +1,89 @@
+import copy
+
 import torch
+from torch import tensor
 
 from torch_geometric.data import HeteroData
 from torch_geometric.transforms import AddMetaPaths
 
 
 def test_add_metapaths():
-    dblp = HeteroData()
-    dblp['paper'].x = torch.ones(5)
-    dblp['author'].x = torch.ones(6)
-    dblp['conference'].x = torch.ones(3)
-    dblp['paper', 'cites', 'paper'].edge_index = torch.tensor([[0, 1, 2, 3],
-                                                               [1, 2, 4, 2]])
-    dblp['paper', 'author'].edge_index = torch.tensor([[0, 1, 2, 3, 4],
-                                                       [2, 2, 5, 2, 5]])
-    dblp['author', 'paper'].edge_index = dblp['paper',
-                                              'author'].edge_index[[1, 0]]
-    dblp['conference', 'paper'].edge_index = torch.tensor([[0, 0, 1, 2, 2],
-                                                           [0, 1, 2, 3, 4]])
-    dblp['paper', 'conference'].edge_index = dblp['conference',
-                                                  'paper'].edge_index[[1, 0]]
+    data = HeteroData()
+    data['p'].x = torch.ones(5)
+    data['a'].x = torch.ones(6)
+    data['c'].x = torch.ones(3)
+    data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
+    data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
+    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
+    data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
+    data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])
 
     # Test transform options:
-    orig_edge_type = dblp.edge_types
-    metapaths = [[('paper', 'conference'), ('conference', 'paper')]]
-    meta1 = AddMetaPaths(metapaths)(dblp.clone())
-    meta2 = AddMetaPaths(metapaths, drop_orig_edges=True)(dblp.clone())
-    meta3 = AddMetaPaths(metapaths, drop_orig_edges=True,
-                         keep_same_node_type=True)(dblp.clone())
-    meta4 = AddMetaPaths(metapaths, drop_orig_edges=True,
-                         keep_same_node_type=True,
-                         drop_unconnected_nodes=True)(dblp.clone())
-
-    assert meta1['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
-    assert meta2['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
-    assert meta3['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
-    assert meta4['paper', 'metapath_0', 'paper'].edge_index.shape[-1] == 9
-
-    assert all([i in meta1.edge_types for i in orig_edge_type])
-    assert meta2.edge_types == [('paper', 'metapath_0', 'paper')]
-    assert meta3.edge_types == [('paper', 'cites', 'paper'),
-                                ('paper', 'metapath_0', 'paper')]
-    assert meta4.edge_types == [('paper', 'cites', 'paper'),
-                                ('paper', 'metapath_0', 'paper')]
-
-    assert meta3.node_types == ['paper', 'author', 'conference']
-    assert meta4.node_types == ['paper']
+    metapaths = [[('p', 'c'), ('c', 'p')]]
+
+    transform = AddMetaPaths(metapaths)
+    assert str(transform) == 'AddMetaPaths()'
+    meta1 = transform(copy.copy(data))
+
+    transform = AddMetaPaths(metapaths, drop_orig_edges=True)
+    assert str(transform) == 'AddMetaPaths()'
+    meta2 = transform(copy.copy(data))
+
+    transform = AddMetaPaths(metapaths, drop_orig_edges=True,
+                             keep_same_node_type=True)
+    assert str(transform) == 'AddMetaPaths()'
+    meta3 = transform(copy.copy(data))
+
+    transform = AddMetaPaths(metapaths, drop_orig_edges=True,
+                             keep_same_node_type=True,
+                             drop_unconnected_nodes=True)
+    assert str(transform) == 'AddMetaPaths()'
+    meta4 = transform(copy.copy(data))
+
+    assert meta1['metapath_0'].edge_index.size() == (2, 9)
+    assert meta2['metapath_0'].edge_index.size() == (2, 9)
+    assert meta3['metapath_0'].edge_index.size() == (2, 9)
+    assert meta4['metapath_0'].edge_index.size() == (2, 9)
+
+    assert all([i in meta1.edge_types for i in data.edge_types])
+    assert meta2.edge_types == [('p', 'metapath_0', 'p')]
+    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
+    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]
+
+    assert meta3.node_types == ['p', 'a', 'c']
+    assert meta4.node_types == ['p']
 
     # Test 4-hop metapath:
-    metapaths = [[('author', 'paper'), ('paper', 'conference')],
-                 [('author', 'paper'), ('paper', 'conference'),
-                  ('conference', 'paper'), ('paper', 'author')]]
-    meta1 = AddMetaPaths(metapaths)(dblp.clone())
-    new_edge_types = [('author', 'metapath_0', 'conference'),
-                      ('author', 'metapath_1', 'author')]
-    assert meta1[new_edge_types[0]].edge_index.shape[-1] == 4
-    assert meta1[new_edge_types[1]].edge_index.shape[-1] == 4
+    metapaths = [
+        [('a', 'p'), ('p', 'c')],
+        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],
+    ]
+    transform = AddMetaPaths(metapaths)
+    meta = transform(copy.copy(data))
+    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]
+    assert meta['metapath_0'].edge_index.size() == (2, 4)
+    assert meta['metapath_1'].edge_index.size() == (2, 4)
 
     # Test `metapath_dict` information:
-    assert list(meta1.metapath_dict.values()) == metapaths
-    assert list(meta1.metapath_dict.keys()) == new_edge_types
+    assert list(meta.metapath_dict.values()) == metapaths
+    assert list(meta.metapath_dict.keys()) == new_edge_types
+
+
+def test_add_metapaths_max_sample():
+    torch.manual_seed(12345)
+
+    data = HeteroData()
+    data['p'].x = torch.ones(5)
+    data['a'].x = torch.ones(6)
+    data['c'].x = torch.ones(3)
+    data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])
+    data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])
+    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
+    data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])
+    data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])
+
+    metapaths = [[('p', 'c'), ('c', 'p')]]
+    transform = AddMetaPaths(metapaths, max_sample=1)
+
+    meta = transform(data)
+    assert meta['metapath_0'].edge_index.size(1) < 9
diff --git a/torch_geometric/transforms/add_metapaths.py b/torch_geometric/transforms/add_metapaths.py
index e9e213dec42c..8f68bfe1c7a4 100644
--- a/torch_geometric/transforms/add_metapaths.py
+++ b/torch_geometric/transforms/add_metapaths.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch_sparse import SparseTensor
@@ -7,6 +7,7 @@
 from torch_geometric.data.datapipes import functional_transform
 from torch_geometric.transforms import BaseTransform
 from torch_geometric.typing import EdgeType
+from torch_geometric.utils import degree
 
 
 @functional_transform('add_metapaths')
@@ -83,11 +84,18 @@ class AddMetaPaths(BaseTransform):
             (default: :obj:`False`)
         drop_unconnected_nodes (bool, optional): If set to :obj:`True` drop
             node types not connected by any edge type. (default: :obj:`False`)
+        max_sample (int, optional): If set, will sample at most
+            :obj:`max_sample` neighbors (in expectation) within metapaths.
+            Useful to tackle very dense metapath edges. (default: :obj:`None`)
     """
-    def __init__(self, metapaths: List[List[EdgeType]],
-                 drop_orig_edges: bool = False,
-                 keep_same_node_type: bool = False,
-                 drop_unconnected_nodes: bool = False):
+    def __init__(
+        self,
+        metapaths: List[List[EdgeType]],
+        drop_orig_edges: bool = False,
+        keep_same_node_type: bool = False,
+        drop_unconnected_nodes: bool = False,
+        max_sample: Optional[int] = None,
+    ):
         for path in metapaths:
             assert len(path) >= 2, f"Invalid metapath '{path}'"
 
@@ -99,6 +107,7 @@ def __init__(self, metapaths: List[List[EdgeType]],
         self.drop_orig_edges = drop_orig_edges
         self.keep_same_node_type = keep_same_node_type
         self.drop_unconnected_nodes = drop_unconnected_nodes
+        self.max_sample = max_sample
 
     def __call__(self, data: HeteroData) -> HeteroData:
         edge_types = data.edge_types  # save original edge types
@@ -114,12 +123,19 @@ def __call__(self, data: HeteroData) -> HeteroData:
                 edge_index=data[edge_type].edge_index,
                 sparse_sizes=data[edge_type].size())
 
+            if self.max_sample is not None:
+                adj1 = self.sample_adj(adj1)
+
             for i, edge_type in enumerate(metapath[1:]):
                 adj2 = SparseTensor.from_edge_index(
                     edge_index=data[edge_type].edge_index,
                     sparse_sizes=data[edge_type].size())
+
                 adj1 = adj1 @ adj2
 
+                if self.max_sample is not None:
+                    adj1 = self.sample_adj(adj1)
+
             row, col, _ = adj1.coo()
             new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
             data[new_edge_type].edge_index = torch.vstack([row, col])
@@ -145,3 +161,10 @@ def __call__(self, data: HeteroData) -> HeteroData:
             del data[node]
 
         return data
+
+    def sample_adj(self, adj: SparseTensor) -> SparseTensor:
+        row, col, _ = adj.coo()
+        deg = degree(row, num_nodes=adj.size(0))
+        prob = (self.max_sample * (1. / deg))[row]
+        mask = torch.rand_like(prob) < prob
+        return adj.masked_select_nnz(mask, layout='coo')
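
For reference, a minimal usage sketch of the new `max_sample` option (the node types 'p'/'c' mirror the toy graph from the updated test, and `max_sample=1` is an arbitrary illustrative value). Since `sample_adj` keeps each edge with probability `max_sample / deg`, the per-node neighbor count is only bounded in expectation, not by a hard cap:

    import torch

    from torch_geometric.data import HeteroData
    from torch_geometric.transforms import AddMetaPaths

    data = HeteroData()
    data['p'].x = torch.ones(5)  # 5 'paper' nodes
    data['c'].x = torch.ones(3)  # 3 'conference' nodes
    data['p', 'c'].edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 2]])
    data['c', 'p'].edge_index = data['p', 'c'].edge_index.flip([0])

    # Add paper -> conference -> paper metapath edges. Without sampling, this
    # toy graph yields 9 such edges; with max_sample=1 each paper keeps roughly
    # one metapath neighbor in expectation, so fewer edges usually survive.
    transform = AddMetaPaths([[('p', 'c'), ('c', 'p')]], max_sample=1)
    data = transform(data)
    print(data['p', 'metapath_0', 'p'].edge_index)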