diff --git a/src/destructure.jl b/src/destructure.jl
index 75e1876a..6b7b6932 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -55,11 +55,11 @@ Base.length(re::Restructure) = re.length
 
 # This flattens a model, and returns a web of offsets for later use:
 function _flatten(x)
-  isnumeric(x) && return vcat(vec(x)), 0, length(x)  # trivial case
+  isnumeric(x) && return vcat(_vec(x)), 0, length(x)  # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
   off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
-    push!(arrays, vec(y))
+    push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
     o
@@ -67,9 +67,12 @@ function _flatten(x)
   reduce(vcat, arrays), off, len[]
 end
 
+_vec(x::Number) = LinRange(x,x,1)
+_vec(x::AbstractArray) = vec(x)
+
 function ChainRulesCore.rrule(::typeof(_flatten), x)
   flat, off, len = _flatten(x)
-  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
+  _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
   (flat, off, len), _flatten_back
 end
 
@@ -92,7 +95,7 @@ function _trainable_biwalk(f, x, aux)
 end
 
 function _trainmap(f, ch, tr, aux)
-  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
+  map(ch, tr, aux) do c, t, a  # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
     isnothing(t) ? c : f(t, a)
   end
 end
@@ -121,7 +124,7 @@ ChainRulesCore.@non_differentiable _zero(x)
 # This is the gradient of model reconstruction, accumulating duplicates:
 function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
-  dx′, _ = functor(typeof(x), dx)
+  dx′, _ = functor(typeof(x), base(dx))
   off′, _ = functor(typeof(x), off)
   foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
   flat
@@ -134,7 +137,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx
 _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx  # ambiguity
 
 function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
-  println("grad! fwd ", length(flat))
   _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
   _grad!(x, dx, off, flat), _grad_back
 end
diff --git a/src/interface.jl b/src/interface.jl
index 1116b90a..8df73066 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -1,6 +1,7 @@
 using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
 base(dx::Tangent) = backing(canonicalize(dx))
+base(dx::Tangent{<:Tangent}) = backing(dx).backing  # might be needed for gradient(gradient(destructure))
 base(dx) = dx
 
 const Zero = Union{Nothing, AbstractZero}  # Union{Zygote, Diffractor}
 
diff --git a/test/destructure.jl b/test/destructure.jl
index 6de5a6af..954854ae 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -2,7 +2,7 @@
 m1 = collect(1:3.0)
 m2 = (collect(1:3.0), collect(4:6.0))
 m3 = (x = m1, y = sin, z = collect(4:6.0))
-m4 = (x = m1, y = m1, z = collect(4:6.0))
+m4 = (x = m1, y = m1, z = collect(4:6.0))  # tied
 m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
 m6 = (a = m1, b = [4.0 + im], c = m1)
 m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
@@ -75,10 +75,18 @@ end
     @test_broken gradient([1,2,3.0]) do v
       sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
     end[1] ≈ [8,16,24]
+    # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
+    # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
+    # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing
 
-    @test_skip gradient([1,2,3.0]) do v
-      sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
-    end
+    @test_broken gradient([1,2,3.0]) do v
+      sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
+    end[1] == [378, 378, 378]
+
+    @test_broken gradient([1,2,3.0]) do v
+      sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
+    end[1] ≈ [8,16,24]
+    # Diffractor error in perform_optic_transform
   end
 end
@@ -109,15 +117,17 @@ end
   @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
 
   @testset "second derivative" begin
-    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
     @test_broken gradient(collect(1:6.0)) do y
       sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
     end[1] ≈ [8,16,24,0,0,0]
-    # This fixes it!
+    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
+    # with Zygote, which can be fixed by:
     # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
-    @test_skip gradient(collect(1:6.0)) do y
+
+    @test_broken gradient(collect(1:6.0)) do y
       sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
-    end[1]
+    end[1] ≈ [0,0,0,32,40,48]
+    # Not fixed by this:
     # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
   end
 end
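
The interface.jl change adds a second `base` method that peels a doubly nested Tangent, which the test comments say is what Diffractor hands to the non-leaf `_grad!` in the second-derivative case. Below is a minimal standalone sketch of that unwrapping, not part of the patch: it assumes only ChainRulesCore, and the names `inner`, `outer` and the example values are illustrative.

    using ChainRulesCore: Tangent, backing, canonicalize

    # Structural gradient of a tuple, as ChainRulesCore represents it:
    inner = Tangent{Tuple{Vector{Float64}, Vector{Float64}}}([1.0, 2.0], [3.0, 4.0])

    # Second-order AD may hand back the gradient *of* that Tangent, i.e. a Tangent
    # whose primal type is itself a Tangent, carrying a `backing` field:
    outer = Tangent{typeof(inner)}(; backing = backing(inner))

    base(dx::Tangent) = backing(canonicalize(dx))
    base(dx::Tangent{<:Tangent}) = backing(dx).backing  # the new method: peel both layers
    base(dx) = dx

    base(inner)  # ([1.0, 2.0], [3.0, 4.0])
    base(outer)  # also ([1.0, 2.0], [3.0, 4.0]), so _grad! sees a plain tuple either way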