Commit b62e0a2
more... the dimensionmismatch bug is not here
mcabbott committed Feb 11, 2022
1 parent 756b450 commit b62e0a2
Showing 3 changed files with 27 additions and 14 deletions.
14 changes: 8 additions & 6 deletions src/destructure.jl
@@ -55,21 +55,24 @@ Base.length(re::Restructure) = re.length

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
push!(arrays, vec(y))
push!(arrays, _vec(y))
o = len[]
len[] = o + length(y)
o
end
reduce(vcat, arrays), off, len[]
end

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)

function ChainRulesCore.rrule(::typeof(_flatten), x)
flat, off, len = _flatten(x)
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
(flat, off, len), _flatten_back
end

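For orientation, here is roughly what _flatten produces for a small two-array model like the tuples used in test/destructure.jl below. This is an illustrative sketch written for this note, not output copied from a run, and it assumes the surrounding package definitions of isnumeric and _trainable; the (0, 3) offsets match the shape mentioned in the test comments.

flat, off, len = _flatten(([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))
# flat == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   (all numeric leaves concatenated)
# off  == (0, 3)                           (per-leaf offsets into flat, mirroring the model's structure)
# len  == 6
# The new _vec helper is plain vec on arrays and turns a lone Number into a 1-element LinRange,
# so whatever reaches these push!/vcat calls can always be concatenated by reduce(vcat, arrays).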
@@ -92,7 +95,7 @@ function _trainable_biwalk(f, x, aux)
end

function _trainmap(f, ch, tr, aux)
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)??
map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
isnothing(t) ? c : f(t, a)
end
end
@@ -121,7 +124,7 @@ ChainRulesCore.@non_differentiable _zero(x)
# This is the gradient of model reconstruction, accumulating duplicates:
function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
dx′, _ = functor(typeof(x), dx)
dx′, _ = functor(typeof(x), base(dx))
off′, _ = functor(typeof(x), off)
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
flat
@@ -134,7 +137,6 @@ _grad!(x, dx::Zero, off, flat::AbstractVector) = dx
_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity

function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
println("grad! fwd ", length(flat))
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
_grad!(x, dx, off, flat), _grad_back
end
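The switch to functor(typeof(x), base(dx)) in _grad! above matters when the incoming gradient is a ChainRulesCore Tangent rather than a plain NamedTuple or Tuple: base (defined in src/interface.jl below) strips the wrapper down to its backing, so the foreach over children lines up with x′ and off′. A hand-made illustration of what base does, not a gradient captured from Zygote or Diffractor:

using ChainRulesCore: Tangent

x  = (a = [1.0, 2.0], b = [3.0, 4.0])
dx = Tangent{typeof(x)}(a = [0.1, 0.2], b = [0.3, 0.4])
base(dx)  # (a = [0.1, 0.2], b = [0.3, 0.4]), a plain NamedTuple which functor(typeof(x), _) can walk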
1 change: 1 addition & 0 deletions src/interface.jl
@@ -1,6 +1,7 @@

using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
base(dx::Tangent) = backing(canonicalize(dx))
base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure))
base(dx) = dx
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}
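The new Tangent{<:Tangent} method is the "explicit double-unwrap" mentioned in the test comments below: second derivatives through Diffractor can hand _grad! a Tangent whose primal type is itself a Tangent. A minimal sketch of the unwrap, with the nested Tangent built by hand rather than captured from Diffractor:

using ChainRulesCore: Tangent, backing

inner = Tangent{Tuple{Vector{Float64}, Vector{Float64}}}([1.0, 2.0], [3.0, 4.0])
outer = Tangent{typeof(inner)}(backing = backing(inner))  # a gradient whose primal is itself a Tangent
base(outer)  # ([1.0, 2.0], [3.0, 4.0]), i.e. backing(outer).backing, unwrapped all the way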

26 changes: 18 additions & 8 deletions test/destructure.jl
@@ -2,7 +2,7 @@
m1 = collect(1:3.0)
m2 = (collect(1:3.0), collect(4:6.0))
m3 = (x = m1, y = sin, z = collect(4:6.0))
m4 = (x = m1, y = m1, z = collect(4:6.0))
m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
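The "# tied" note is what the "accumulating duplicates" path of _grad! in src/destructure.jl is about: m4.x and m4.y are the same array, so it should occupy a single slice of the flat vector, and gradients reaching x and y get added into that one slice. A rough sketch of the expectation, written for this note rather than copied from the test suite:

v4, re4 = destructure(m4)
length(v4)  # expected to be 6 rather than 9, since the shared array m1 is stored only once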
@@ -75,10 +75,18 @@ end
@test_broken gradient([1,2,3.0]) do v
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
end[1] ≈ [8,16,24]
# With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx:
# off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ...
# until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing

@test_skip gradient([1,2,3.0]) do v
sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
end
@test_broken gradient([1,2,3.0]) do v
sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
end[1] == [378, 378, 378]

@test_broken gradient([1,2,3.0]) do v
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
end[1] ≈ [8,16,24]
# Diffractor error in perform_optic_transform
end
end

@@ -109,15 +117,17 @@ end
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]

@testset "second derivative" begin
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
@test_broken gradient(collect(1:6.0)) do y
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
end[1] ≈ [8,16,24,0,0,0]
# This fixes it!
# ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
# with Zygote, which can be fixed by:
# Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)
@test_skip gradient(collect(1:6.0)) do y

@test_broken gradient(collect(1:6.0)) do y
sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
end[1]
end[1] ≈ [0,0,0,32,40,48]
# Not fixed by this:
# Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
end
end