From eb10fc2dfbea5bbd3ba0970fab2ec049a7e4edbd Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 22 Jan 2024 12:34:33 +0100 Subject: [PATCH] Save all input arguments in `NodeLoader` so that PyTorch Lightning can correctly reconstruct it (#8809) --- torch_geometric/loader/node_loader.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index 46449c33e13a..3a16bf618ab2 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -103,20 +103,22 @@ def __init__( if filter_per_worker is None: filter_per_worker = infer_filter_per_worker(data) - # Remove for PyTorch Lightning: - kwargs.pop('dataset', None) - kwargs.pop('collate_fn', None) - - # Get node type (or `None` for homogeneous graphs): - input_type, input_nodes, input_id = get_input_nodes( - data, input_nodes, input_id) - self.data = data self.node_sampler = node_sampler + self.input_nodes = input_nodes + self.input_time = input_time self.transform = transform self.transform_sampler_output = transform_sampler_output self.filter_per_worker = filter_per_worker self.custom_cls = custom_cls + self.input_id = input_id + + kwargs.pop('dataset', None) + kwargs.pop('collate_fn', None) + + # Get node type (or `None` for homogeneous graphs): + input_type, input_nodes, input_id = get_input_nodes( + data, input_nodes, input_id) self.input_data = NodeSamplerInput( input_id=input_id,