From 080a6e9246c040ac77516862b5729caf05996641 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Fri, 15 Sep 2023 11:17:30 +0200 Subject: [PATCH] Weighted sampling in `NeighborLoader` and `LinkNeighborLoader` (#8038) --- CHANGELOG.md | 1 + test/loader/test_neighbor_loader.py | 70 ++++++++++++++++++- .../loader/link_neighbor_loader.py | 9 +++ torch_geometric/loader/neighbor_loader.py | 9 +++ torch_geometric/sampler/neighbor_sampler.py | 45 ++++++++++-- 5 files changed, 128 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index baa53f116a81..5cb3241ae843 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038)) - Added the `MixHopConv` layer and an corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025)) - Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033)) - Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230)) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 17e2dd326544..434558a8a7eb 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -20,7 +20,11 @@ withCUDA, withPackage, ) -from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_SPARSE +from torch_geometric.typing import ( + WITH_PYG_LIB, + WITH_TORCH_SPARSE, + WITH_WEIGHTED_NEIGHBOR_SAMPLE, +) from torch_geometric.utils import ( is_undirected, sort_edge_index, @@ -714,3 +718,67 @@ def test_neighbor_loader_mapping(): batch.n_id[batch.edge_index], data.edge_index[:, batch.e_id], ) + + +@pytest.mark.skipif( + not WITH_WEIGHTED_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_weighted_homo_neighbor_loader(): + edge_index = torch.tensor([ + [1, 3, 0, 4], + [2, 2, 1, 3], + ]) + edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0]) + + data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight) + + loader = NeighborLoader( + data, + input_nodes=torch.tensor([2]), + num_neighbors=[1] * 2, + batch_size=1, + weight_attr='edge_weight', + ) + assert len(loader) == 1 + + batch = next(iter(loader)) + + assert batch.num_nodes == 3 + assert batch.n_id.tolist() == [2, 3, 4] + assert batch.num_edges == 2 + assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]] + + +@pytest.mark.skipif( + not WITH_WEIGHTED_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_weighted_hetero_neighbor_loader(): + edge_index = torch.tensor([ + [1, 3, 0, 4], + [2, 2, 1, 3], + ]) + edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0]) + + data = HeteroData() + data['paper'].num_nodes = 5 + data['paper', 'to', 'paper'].edge_index = edge_index + data['paper', 'to', 'paper'].edge_weight = edge_weight + + loader = NeighborLoader( + data, + input_nodes=('paper', torch.tensor([2])), + num_neighbors=[1] * 2, + batch_size=1, + weight_attr='edge_weight', + ) + assert len(loader) == 1 + + batch = next(iter(loader)) + + assert batch['paper'].num_nodes == 3 + assert batch['paper'].n_id.tolist() == [2, 3, 4] + assert batch['paper', 'paper'].num_edges == 2 + global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index] + assert global_edge_index.tolist() == [[3, 4], [2, 3]] diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 7d9e8c79397f..9871088eef23 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -165,6 +165,13 @@ class LinkNeighborLoader(LinkLoader): guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. Only used if :obj:`edge_label_time` is set. (default: :obj:`None`) + weight_attr (str, optional): The name of the attribute that denotes + edge weights in the graph. + If set, weighted/biased sampling will be used such that neighbors + are more likely to get sampled the higher their edge weights are. + Edge weights do not need to sum to one, but must be non-negative, + finite and have a non-zero sum within local neighborhoods. + (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) @@ -207,6 +214,7 @@ def __init__( neg_sampling: Optional[NegativeSampling] = None, neg_sampling_ratio: Optional[Union[int, float]] = None, time_attr: Optional[str] = None, + weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, @@ -233,6 +241,7 @@ def __init__( disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, + weight_attr=weight_attr, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, directed=directed, diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index a5213b3b9882..9cbf0bcbaebe 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -165,6 +165,13 @@ class NeighborLoader(NodeLoader): guaranteed to fulfill temporal constraints, *i.e.* neighbors have an earlier or equal timestamp than the center node. (default: :obj:`None`) + weight_attr (str, optional): The name of the attribute that denotes + edge weights in the graph. + If set, weighted/biased sampling will be used such that neighbors + are more likely to get sampled the higher their edge weights are. + Edge weights do not need to sum to one, but must be non-negative, + finite and have a non-zero sum within local neighborhoods. + (default: :obj:`None`) transform (callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) @@ -204,6 +211,7 @@ def __init__( disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, + weight_attr: Optional[str] = None, transform: Optional[Callable] = None, transform_sampler_output: Optional[Callable] = None, is_sorted: bool = False, @@ -226,6 +234,7 @@ def __init__( disjoint=disjoint, temporal_strategy=temporal_strategy, time_attr=time_attr, + weight_attr=weight_attr, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, directed=directed, diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index f80ce9002476..2a4c3d0aa408 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -43,6 +43,7 @@ def __init__( disjoint: bool = False, temporal_strategy: str = 'uniform', time_attr: Optional[str] = None, + weight_attr: Optional[str] = None, is_sorted: bool = False, share_memory: bool = False, # Deprecated: @@ -65,18 +66,30 @@ def __init__( if self.data_type == DataType.homogeneous: self.num_nodes = data.num_nodes - self.node_time = data[time_attr] if time_attr else None + + self.node_time: Optional[Tensor] = None + if time_attr is not None: + self.node_time = data[time_attr] # Convert the graph data into CSC format for sampling: self.colptr, self.row, self.perm = to_csc( data, device='cpu', share_memory=share_memory, is_sorted=is_sorted, src_node_time=self.node_time) + self.edge_weight: Optional[Tensor] = None + if weight_attr is not None: + self.edge_weight = data[weight_attr] + if self.perm is not None: + self.edge_weight = self.edge_weight[self.perm] + elif self.data_type == DataType.heterogeneous: self.node_types, self.edge_types = data.metadata() self.num_nodes = {k: data[k].num_nodes for k in self.node_types} - self.node_time = data.collect(time_attr) if time_attr else None + + self.node_time: Optional[Dict[NodeType, Tensor]] = None + if time_attr is not None: + self.node_time = data.collect(time_attr) # Conversion to/from C++ string type: Since C++ cannot take # dictionaries with tuples as key as input, edge type triplets need @@ -91,6 +104,16 @@ def __init__( self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) + self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None + if weight_attr is not None: + self.edge_weight = data.collect(weight_attr) + for edge_type, edge_weight in self.edge_weight.items(): + if self.perm.get(edge_type, None) is not None: + edge_weight = edge_weight[self.perm[edge_type]] + self.edge_weight[edge_type] = edge_weight + self.edge_weight = remap_keys(self.edge_weight, + self.to_rel_type) + else: # self.data_type == DataType.remote feature_store, graph_store = data @@ -106,7 +129,7 @@ def __init__( for node_type in self.node_types } - self.node_time: Optional[Dict[str, Tensor]] = None + self.node_time: Optional[Dict[NodeType, Tensor]] = None if time_attr is not None: # If the `time_attr` is present, we expect that `GraphStore` # holds all edges sorted by destination, and within local @@ -136,6 +159,13 @@ def __init__( for time_attr, time_tensor in zip(time_attrs, time_tensors) } + self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None + if weight_attr is not None: + raise NotImplementedError( + f"'weight_attr' argument not yet supported within " + f"'{self.__class__.__name__}' for " + f"'(FeatureStore, GraphStore)' inputs") + # Conversion to/from C++ string type (see above): self.to_rel_type = {k: '__'.join(k) for k in self.edge_types} self.to_edge_type = {v: k for k, v in self.to_rel_type.items()} @@ -145,6 +175,11 @@ def __init__( self.row_dict = remap_keys(row_dict, self.to_rel_type) self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type) + if (self.edge_weight is not None + and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE): + raise ImportError("Weighted neighbor sampling requires " + "'pyg-lib>=0.3.0'") + self.num_neighbors = num_neighbors self.replace = replace self.subgraph_type = SubgraphType(subgraph_type) @@ -233,7 +268,7 @@ def _sample( seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: - args += (None, ) + args += (self.edge_weight, ) args += ( True, # csc self.replace, @@ -313,7 +348,7 @@ def _sample( seed_time, ) if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE: - args += (None, ) + args += (self.edge_weight, ) args += ( True, # csc self.replace,