Skip to content

Commit

Permalink
Add drop_nodes transform (#426)
Browse files Browse the repository at this point in the history
* drop node

* tests

* Update transform.jl

* Update transform.jl

* added to gnngraphs

* error in test?

* Update src/GNNGraphs/transform.jl

float32 args

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>

* Update src/GNNGraphs/transform.jl

arg fix

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>

---------

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
  • Loading branch information
rbSparky and CarloLucibello authored Jun 27, 2024
1 parent 36e8373 commit bcce0cf
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export add_nodes,
to_unidirected,
random_walk_pe,
remove_nodes,
drop_nodes,
# from Flux
batch,
unbatch,
Expand Down
32 changes: 32 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,38 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
ndata, edata, g.gdata)
end

"""
drop_nodes(g::GNNGraph{<:COO_T}, p)
Randomly drop nodes (and their associated edges) from a GNNGraph based on a given probability.
Dropping nodes is a technique that can be used for graph data augmentation, refering paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).
# Arguments
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
- `p`: The probability of dropping each node. Default value is `0.5`.
# Returns
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
# Example
```julia
using GraphNeuralNetworks
# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
# Drop nodes with a probability of 0.5
g_new = drop_node(g, 0.5)
println(g_new)
```
"""
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
num_nodes = g.num_nodes
nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)

new_g = remove_nodes(g, nodes_to_remove)

return new_g
end

"""
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
add_edges(g::GNNGraph, (s, t); [edata])
Expand Down
18 changes: 18 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,24 @@ end
@test edata_new == edatatest
end end

@testset "drop_nodes" begin
if GRAPH_T == :coo
Random.seed!(42)
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
g = GNNGraph(s, t, graph_type = GRAPH_T)

gnew = drop_nodes(g, Float32(0.5))
@test gnew.num_nodes == 3

gnew = drop_nodes(g, Float32(1.0))
@test gnew.num_nodes == 0

gnew = drop_nodes(g, Float32(0.0))
@test gnew.num_nodes == 5
end
end

@testset "add_nodes" begin if GRAPH_T == :coo
g = rand_graph(6, 4, ndata = rand(2, 6), graph_type = GRAPH_T)
gnew = add_nodes(g, 5, ndata = ones(2, 5))
Expand Down

0 comments on commit bcce0cf

Please sign in to comment.