Skip to content

Commit

Permalink
NeighborLoader: Optionally use argument data without to_csc (#4620)
Browse files Browse the repository at this point in the history
* update

* update

* update

* changelog

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
ZenoTan and rusty1s authored May 11, 2022
1 parent 6f120ff commit e3ba9d3
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added faster initialization of `NeighborLoader` in case edge indices are already sorted by column (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620))
- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
Expand Down
1 change: 1 addition & 0 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
directed=kwargs.get('directed', True),
input_type=get_input_nodes(data, input_train_nodes)[0],
time_attr=kwargs.get('time_attr', None),
is_sorted=kwargs.get('is_sorted', False),
)
self.input_train_nodes = input_train_nodes
self.input_val_nodes = input_val_nodes
Expand Down
23 changes: 17 additions & 6 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,6 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
transform (Callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
neg_sampling_ratio (float, optional): The ratio of sampled negative
edges to the number of positive edges.
If :obj:`edge_label` does not exist, it will be automatically
Expand All @@ -219,6 +216,13 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
:meth:`F.binary_cross_entropy`) and of type
:obj:`torch.long` for multi-class classification (to facilitate the
ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`).
transform (Callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, assumes that
:obj:`edge_index` is sorted by column. This avoids internal
re-sorting of the data and can improve runtime and memory
efficiency. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
Expand All @@ -231,9 +235,10 @@ def __init__(
edge_label: OptTensor = None,
replace: bool = False,
directed: bool = True,
neg_sampling_ratio: float = 0.0,
transform: Callable = None,
is_sorted: bool = False,
neighbor_sampler: Optional[LinkNeighborSampler] = None,
neg_sampling_ratio: float = 0.0,
**kwargs,
):
# Remove for PyTorch Lightning:
Expand All @@ -259,9 +264,15 @@ def __init__(

if neighbor_sampler is None:
self.neighbor_sampler = LinkNeighborSampler(
data, num_neighbors, replace, directed, edge_type,
data,
num_neighbors,
replace,
directed,
input_type=edge_type,
is_sorted=is_sorted,
neg_sampling_ratio=self.neg_sampling_ratio,
share_memory=kwargs.get('num_workers', 0) > 0,
neg_sampling_ratio=self.neg_sampling_ratio)
)

super().__init__(Dataset(edge_label_index, edge_label),
collate_fn=self.neighbor_sampler, **kwargs)
Expand Down
24 changes: 19 additions & 5 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def __init__(
replace: bool = False,
directed: bool = True,
input_type: Optional[Any] = None,
share_memory: bool = False,
time_attr: Optional[str] = None,
is_sorted: bool = False,
share_memory: bool = False,
):
self.data_cls = data.__class__
self.num_neighbors = num_neighbors
Expand All @@ -41,7 +42,8 @@ def __init__(
f"'{data.__class__.__name__}' object")

# Convert the graph data into a suitable format for sampling.
out = to_csc(data, device='cpu', share_memory=share_memory)
out = to_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
self.colptr, self.row, self.perm = out
assert isinstance(num_neighbors, (list, tuple))

Expand All @@ -54,7 +56,8 @@ def __init__(
# Convert the graph data into a suitable format for sampling.
# NOTE: Since C++ cannot take dictionaries with tuples as key as
# input, edge type triplets are converted into single strings.
out = to_hetero_csc(data, device='cpu', share_memory=share_memory)
out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
self.colptr_dict, self.row_dict, self.perm_dict = out

self.node_types, self.edge_types = data.metadata()
Expand Down Expand Up @@ -245,6 +248,10 @@ class NeighborLoader(torch.utils.data.DataLoader):
transform (Callable, optional): A function/transform that takes in
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, assumes that
:obj:`edge_index` is sorted by column. This avoids internal
re-sorting of the data and can improve runtime and memory
efficiency. (default: :obj:`False`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
Expand All @@ -258,6 +265,7 @@ def __init__(
directed: bool = True,
time_attr: Optional[str] = None,
transform: Callable = None,
is_sorted: bool = False,
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
):
Expand All @@ -281,9 +289,15 @@ def __init__(

if neighbor_sampler is None:
self.neighbor_sampler = NeighborSampler(
data, num_neighbors, replace, directed, node_type,
data,
num_neighbors,
replace,
directed,
input_type=node_type,
time_attr=time_attr,
share_memory=kwargs.get('num_workers', 0) > 0)
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
)

super().__init__(input_nodes, collate_fn=self.neighbor_sampler,
**kwargs)
Expand Down
13 changes: 8 additions & 5 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def to_csc(
data: Union[Data, EdgeStorage],
device: Optional[torch.device] = None,
share_memory: bool = False,
is_sorted: bool = False,
) -> Tuple[Tensor, Tensor, OptTensor]:
# Convert the graph data into a suitable format for sampling (CSC format).
# Returns the `colptr` and `row` indices of the graph, as well as an
Expand All @@ -47,17 +48,18 @@ def to_csc(

elif hasattr(data, 'edge_index'):
(row, col) = data.edge_index
size = data.size()
perm = (col * size[0]).add_(row).argsort()
if not is_sorted:
size = data.size()
perm = (col * size[0]).add_(row).argsort()
row = row[perm]
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
row = row[perm]
else:
raise AttributeError("Data object does not contain attributes "
"'adj_t' or 'edge_index'")

colptr = colptr.to(device)
row = row.to(device)
perm = perm if perm is not None else perm.to(device)
perm = perm.to(device) if perm is not None else None

if not colptr.is_cuda and share_memory:
colptr.share_memory_()
Expand All @@ -72,6 +74,7 @@ def to_hetero_csc(
data: HeteroData,
device: Optional[torch.device] = None,
share_memory: bool = False,
is_sorted: bool = False,
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
# Convert the heterogeneous graph data into a suitable format for sampling
# (CSC format).
Expand All @@ -83,7 +86,7 @@ def to_hetero_csc(

for store in data.edge_stores:
key = edge_type_to_str(store._key)
out = to_csc(store, device, share_memory)
out = to_csc(store, device, share_memory, is_sorted)
colptr_dict[key], row_dict[key], perm_dict[key] = out

return colptr_dict, row_dict, perm_dict
Expand Down

0 comments on commit e3ba9d3

Please sign in to comment.