From 3f0019f4d3738dfb109ed15e19b5098de1afa08d Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 8 Apr 2022 18:44:22 +0000 Subject: [PATCH] emergency fix --- torch_geometric/loader/neighbor_loader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index b184147aae08..7425001c2718 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -269,6 +269,8 @@ def get_input_nodes(data: Union[Data, HeteroData], if isinstance(data, Data): if input_nodes is None: return None, range(data.num_nodes) + if input_nodes.dtype == torch.bool: + input_nodes = input_nodes.nonzero(as_tuple=False).view(-1) return None, input_nodes assert input_nodes is not None @@ -283,4 +285,7 @@ def get_input_nodes(data: Union[Data, HeteroData], if input_nodes[1] is None: return input_nodes[0], range(data[input_nodes[0]].num_nodes) - return input_nodes + node_type, input_nodes = input_nodes + if input_nodes.dtype == torch.bool: + input_nodes = input_nodes.nonzero(as_tuple=False).view(-1) + return node_type, input_nodes