Skip to content

Commit

Permalink
Save all input arguments in NodeLoader so that PyTorch Lightning ca…
Browse files Browse the repository at this point in the history
…n correctly reconstruct it (pyg-team#8809)
  • Loading branch information
rusty1s authored Jan 22, 2024
1 parent 81fdeaf commit eb10fc2
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions torch_geometric/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,22 @@ def __init__(
if filter_per_worker is None:
filter_per_worker = infer_filter_per_worker(data)

# Remove for PyTorch Lightning:
kwargs.pop('dataset', None)
kwargs.pop('collate_fn', None)

# Get node type (or `None` for homogeneous graphs):
input_type, input_nodes, input_id = get_input_nodes(
data, input_nodes, input_id)

self.data = data
self.node_sampler = node_sampler
self.input_nodes = input_nodes
self.input_time = input_time
self.transform = transform
self.transform_sampler_output = transform_sampler_output
self.filter_per_worker = filter_per_worker
self.custom_cls = custom_cls
self.input_id = input_id

kwargs.pop('dataset', None)
kwargs.pop('collate_fn', None)

# Get node type (or `None` for homogeneous graphs):
input_type, input_nodes, input_id = get_input_nodes(
data, input_nodes, input_id)

self.input_data = NodeSamplerInput(
input_id=input_id,
Expand Down

0 comments on commit eb10fc2

Please sign in to comment.