Skip to content

Commit

Permalink
Weighted sampling in NeighborLoader and LinkNeighborLoader (#8038)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Sep 15, 2023
1 parent eb15f68 commit 080a6e9
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
- Added the `MixHopConv` layer and an corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))
Expand Down
70 changes: 69 additions & 1 deletion test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
withCUDA,
withPackage,
)
from torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.typing import (
WITH_PYG_LIB,
WITH_TORCH_SPARSE,
WITH_WEIGHTED_NEIGHBOR_SAMPLE,
)
from torch_geometric.utils import (
is_undirected,
sort_edge_index,
Expand Down Expand Up @@ -714,3 +718,67 @@ def test_neighbor_loader_mapping():
batch.n_id[batch.edge_index],
data.edge_index[:, batch.e_id],
)


@pytest.mark.skipif(
not WITH_WEIGHTED_NEIGHBOR_SAMPLE,
reason="'pyg-lib' does not support weighted neighbor sampling",
)
def test_weighted_homo_neighbor_loader():
edge_index = torch.tensor([
[1, 3, 0, 4],
[2, 2, 1, 3],
])
edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])

data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)

loader = NeighborLoader(
data,
input_nodes=torch.tensor([2]),
num_neighbors=[1] * 2,
batch_size=1,
weight_attr='edge_weight',
)
assert len(loader) == 1

batch = next(iter(loader))

assert batch.num_nodes == 3
assert batch.n_id.tolist() == [2, 3, 4]
assert batch.num_edges == 2
assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]]


@pytest.mark.skipif(
not WITH_WEIGHTED_NEIGHBOR_SAMPLE,
reason="'pyg-lib' does not support weighted neighbor sampling",
)
def test_weighted_hetero_neighbor_loader():
edge_index = torch.tensor([
[1, 3, 0, 4],
[2, 2, 1, 3],
])
edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])

data = HeteroData()
data['paper'].num_nodes = 5
data['paper', 'to', 'paper'].edge_index = edge_index
data['paper', 'to', 'paper'].edge_weight = edge_weight

loader = NeighborLoader(
data,
input_nodes=('paper', torch.tensor([2])),
num_neighbors=[1] * 2,
batch_size=1,
weight_attr='edge_weight',
)
assert len(loader) == 1

batch = next(iter(loader))

assert batch['paper'].num_nodes == 3
assert batch['paper'].n_id.tolist() == [2, 3, 4]
assert batch['paper', 'paper'].num_edges == 2
global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index]
assert global_edge_index.tolist() == [[3, 4], [2, 3]]
9 changes: 9 additions & 0 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ class LinkNeighborLoader(LinkLoader):
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the center node.
Only used if :obj:`edge_label_time` is set. (default: :obj:`None`)
weight_attr (str, optional): The name of the attribute that denotes
edge weights in the graph.
If set, weighted/biased sampling will be used such that neighbors
are more likely to get sampled the higher their edge weights are.
Edge weights do not need to sum to one, but must be non-negative,
finite and have a non-zero sum within local neighborhoods.
(default: :obj:`None`)
transform (callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
Expand Down Expand Up @@ -207,6 +214,7 @@ def __init__(
neg_sampling: Optional[NegativeSampling] = None,
neg_sampling_ratio: Optional[Union[int, float]] = None,
time_attr: Optional[str] = None,
weight_attr: Optional[str] = None,
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
is_sorted: bool = False,
Expand All @@ -233,6 +241,7 @@ def __init__(
disjoint=disjoint,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
weight_attr=weight_attr,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
directed=directed,
Expand Down
9 changes: 9 additions & 0 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ class NeighborLoader(NodeLoader):
guaranteed to fulfill temporal constraints, *i.e.* neighbors have
an earlier or equal timestamp than the center node.
(default: :obj:`None`)
weight_attr (str, optional): The name of the attribute that denotes
edge weights in the graph.
If set, weighted/biased sampling will be used such that neighbors
are more likely to get sampled the higher their edge weights are.
Edge weights do not need to sum to one, but must be non-negative,
finite and have a non-zero sum within local neighborhoods.
(default: :obj:`None`)
transform (callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
Expand Down Expand Up @@ -204,6 +211,7 @@ def __init__(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
weight_attr: Optional[str] = None,
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
is_sorted: bool = False,
Expand All @@ -226,6 +234,7 @@ def __init__(
disjoint=disjoint,
temporal_strategy=temporal_strategy,
time_attr=time_attr,
weight_attr=weight_attr,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
directed=directed,
Expand Down
45 changes: 40 additions & 5 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
weight_attr: Optional[str] = None,
is_sorted: bool = False,
share_memory: bool = False,
# Deprecated:
Expand All @@ -65,18 +66,30 @@ def __init__(

if self.data_type == DataType.homogeneous:
self.num_nodes = data.num_nodes
self.node_time = data[time_attr] if time_attr else None

self.node_time: Optional[Tensor] = None
if time_attr is not None:
self.node_time = data[time_attr]

# Convert the graph data into CSC format for sampling:
self.colptr, self.row, self.perm = to_csc(
data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted, src_node_time=self.node_time)

self.edge_weight: Optional[Tensor] = None
if weight_attr is not None:
self.edge_weight = data[weight_attr]
if self.perm is not None:
self.edge_weight = self.edge_weight[self.perm]

elif self.data_type == DataType.heterogeneous:
self.node_types, self.edge_types = data.metadata()

self.num_nodes = {k: data[k].num_nodes for k in self.node_types}
self.node_time = data.collect(time_attr) if time_attr else None

self.node_time: Optional[Dict[NodeType, Tensor]] = None
if time_attr is not None:
self.node_time = data.collect(time_attr)

# Conversion to/from C++ string type: Since C++ cannot take
# dictionaries with tuples as key as input, edge type triplets need
Expand All @@ -91,6 +104,16 @@ def __init__(
self.row_dict = remap_keys(row_dict, self.to_rel_type)
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
if weight_attr is not None:
self.edge_weight = data.collect(weight_attr)
for edge_type, edge_weight in self.edge_weight.items():
if self.perm.get(edge_type, None) is not None:
edge_weight = edge_weight[self.perm[edge_type]]
self.edge_weight[edge_type] = edge_weight
self.edge_weight = remap_keys(self.edge_weight,
self.to_rel_type)

else: # self.data_type == DataType.remote
feature_store, graph_store = data

Expand All @@ -106,7 +129,7 @@ def __init__(
for node_type in self.node_types
}

self.node_time: Optional[Dict[str, Tensor]] = None
self.node_time: Optional[Dict[NodeType, Tensor]] = None
if time_attr is not None:
# If the `time_attr` is present, we expect that `GraphStore`
# holds all edges sorted by destination, and within local
Expand Down Expand Up @@ -136,6 +159,13 @@ def __init__(
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None
if weight_attr is not None:
raise NotImplementedError(
f"'weight_attr' argument not yet supported within "
f"'{self.__class__.__name__}' for "
f"'(FeatureStore, GraphStore)' inputs")

# Conversion to/from C++ string type (see above):
self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}
self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}
Expand All @@ -145,6 +175,11 @@ def __init__(
self.row_dict = remap_keys(row_dict, self.to_rel_type)
self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)

if (self.edge_weight is not None
and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE):
raise ImportError("Weighted neighbor sampling requires "
"'pyg-lib>=0.3.0'")

self.num_neighbors = num_neighbors
self.replace = replace
self.subgraph_type = SubgraphType(subgraph_type)
Expand Down Expand Up @@ -233,7 +268,7 @@ def _sample(
seed_time,
)
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (None, )
args += (self.edge_weight, )
args += (
True, # csc
self.replace,
Expand Down Expand Up @@ -313,7 +348,7 @@ def _sample(
seed_time,
)
if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:
args += (None, )
args += (self.edge_weight, )
args += (
True, # csc
self.replace,
Expand Down

0 comments on commit 080a6e9

Please sign in to comment.