Skip to content

Commit

Permalink
bounds check
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Mar 17, 2024
1 parent d87b5c4 commit 5cc79cc
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,15 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
edata = g.edata
ndata = g.ndata

#fix
edges_to_remove_s = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, s)
edges_to_remove_t = findall(x -> nodes_to_remove[searchsortedlast(nodes_to_remove, x)] == x, t)
function find_edges_to_remove(nodes, nodes_to_remove)
return findall(node_id -> begin
idx = searchsortedlast(nodes_to_remove, node_id)
idx >= 1 && idx <= length(nodes_to_remove) && nodes_to_remove[idx] == node_id
end, nodes)
end

edges_to_remove_s = find_edges_to_remove(s, nodes_to_remove)
edges_to_remove_t = find_edges_to_remove(t, nodes_to_remove)
edges_to_remove = union(edges_to_remove_s, edges_to_remove_t)

mask_edges_to_keep = trues(length(s))
Expand Down

0 comments on commit 5cc79cc

Please sign in to comment.