diff --git a/src/destructure.jl b/src/destructure.jl index 8e6b75dd..75e1876a 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -102,7 +102,12 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT au, _ = functor(typeof(x), aux) y = _trainmap(f, ch, _trainable(x), au) y isa Tuple{} && return NoT - Tangent{typeof(x), typeof(y)}(y) + p = ProjectTo(x) + if p isa ProjectTo # e.g. Array, NamedTuple + p(y) + else # p === identity for unknown structs + Tangent{typeof(x), typeof(y)}(y) + end end function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...) diff --git a/src/interface.jl b/src/interface.jl index 4864ae16..1116b90a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -71,6 +71,7 @@ trainable(x) = functor(x)[1] _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) _trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr +_trainable(ch::AbstractArray, tr::AbstractArray) = tr function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" map(c -> c in tr ? c : nothing, ch) diff --git a/test/destructure.jl b/test/destructure.jl index 55ab37df..6de5a6af 100644 --- a/test/destructure.jl +++ b/test/destructure.jl @@ -6,6 +6,7 @@ m4 = (x = m1, y = m1, z = collect(4:6.0)) m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) m6 = (a = m1, b = [4.0 + im], c = m1) m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) +m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] @testset "flatten & rebuild" begin @test destructure(m1)[1] isa Vector{Float64} @@ -31,12 +32,20 @@ m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0 @test m6′.a === m6′.c @test m6′.b == [7 + 4im] + # struct, trainable @test destructure(m7)[1] == 1:3 m7′ = destructure(m7)[2]([10,20,30]) @test m7′.a == (sin, [10,20,30]) @test m7′.b == (cos, [4,5,6]) @test m7′.c == (tan, [7,8,9]) + @test destructure(m8)[1] == 1:5 + m8′ = destructure(m8)[2](1:5) + @test m8′[1].x === m8′[1].y + @test m8′[2].b.y === false + @test m8′[3][1] == [5.0] + + # errors @test_throws Exception destructure(m7)[2]([10,20]) @test_throws Exception destructure(m7)[2]([10,20,30,40]) end @@ -57,6 +66,11 @@ end @test g6.a isa Vector{Float64} @test g6.b == [0+im] + g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] + @test g8[1].x == [2,4,6] + @test g8[2].b.x == [8] + @test g8[3] == [[10.0]] + @testset "second derivative" begin @test_broken gradient([1,2,3.0]) do v sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) @@ -90,6 +104,10 @@ end @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + v8, re8 = destructure(m8) + @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + @testset "second derivative" begin # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} @test_broken gradient(collect(1:6.0)) do y