Skip to content

Commit

Permalink
emergency fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 8, 2022
1 parent 9bcb946 commit 3f0019f
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ def get_input_nodes(data: Union[Data, HeteroData],
if isinstance(data, Data):
if input_nodes is None:
return None, range(data.num_nodes)
if input_nodes.dtype == torch.bool:
input_nodes = input_nodes.nonzero(as_tuple=False).view(-1)
return None, input_nodes

assert input_nodes is not None
Expand All @@ -283,4 +285,7 @@ def get_input_nodes(data: Union[Data, HeteroData],
if input_nodes[1] is None:
return input_nodes[0], range(data[input_nodes[0]].num_nodes)

return input_nodes
node_type, input_nodes = input_nodes
if input_nodes.dtype == torch.bool:
input_nodes = input_nodes.nonzero(as_tuple=False).view(-1)
return node_type, input_nodes

0 comments on commit 3f0019f

Please sign in to comment.