Skip to content

Commit

Permalink
Merge pull request #201 from mplemay/fix-flux-restructure
Browse files Browse the repository at this point in the history
fix GNNChain restructure bug
  • Loading branch information
CarloLucibello authored Jul 22, 2022
2 parents 401ae1a + c4d562c commit 93b6fa2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ end
Base.iterate, Base.lastindex, Base.keys

Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...)

# input from graph
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
Expand Down
6 changes: 6 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,11 @@
wg = WithGraph(model, g, traingraph=true)
@test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g))
end

@testset "Flux restructure" begin
chain = GNNChain(GraphConv(2=>2))
params, restructure = Flux.destructure(chain)
@test restructure(params) isa GNNChain
end
end

0 comments on commit 93b6fa2

Please sign in to comment.