Add trainables_nt #175

Open · wants to merge 3 commits into master
Conversation

@CarloLucibello (Member)

This is a proposal for an alternative to destructure that doesn't completely flatten the parameters but instead returns a nested named tuple. The associated reconstructor can also be used on ComponentArrays.
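For concreteness, here is a minimal sketch of the intended workflow; the exact outputs are illustrative and taken from the examples later in this thread.

using Optimisers, ComponentArrays

m = ([1.0, 2.0], [3.0, 4.0])   # any nested structure of trainable arrays
ps, re = trainables_nt(m)      # ps is a nested NamedTuple, e.g. (_1 = [1.0, 2.0], _2 = [3.0, 4.0])
v = ComponentVector(ps)        # flat vector [1.0, 2.0, 3.0, 4.0] with named axes
m2 = re(v)                     # rebuild the original structure from the flat vector, without copying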

@darsnack (Member)

Setting differentiability aside, is fmapstructure not sufficient because of how vectors are handled (e.g. layers in Chain)?

@CarloLucibello (Member, Author)

Exactly. And we need a return made only of nested NamedTuples in order to be compatible with ComponentArrays.
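To illustrate the point (a sketch, not code from this PR): Functors.fmapstructure keeps tuples as tuples, while ComponentVector needs NamedTuple keys to name its axes.

using Functors, ComponentArrays

nested = ([1.0, 2.0], (a = [3.0],))
fmapstructure(identity, nested)   # ([1.0, 2.0], (a = [3.0],)) -- the outer Tuple survives
ComponentVector((_1 = [1.0, 2.0], _2 = (a = [3.0],)))   # NamedTuple keys become the named axes
# ComponentVector(nested)         # a plain Tuple has no field names to use as axes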

@darsnack (Member)

What about replacing destructure with this code plus the ComponentArrays construction, as opposed to adding this as a separate function? It would move a lot of the tricky stuff to ComponentArrays.
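A hypothetical sketch of that suggestion (not part of this PR), assuming trainables_nt and its reconstructor behave as described above:

using Optimisers, ComponentArrays

function destructure_ca(model)      # hypothetical replacement for destructure
    ps, re = trainables_nt(model)   # nested NamedTuple of trainable arrays + reconstructor
    v = ComponentVector(ps)         # ComponentArrays handles the flattening and indexing
    return v, re                    # re accepts a ComponentVector directly
end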

@mcabbott (Member) commented Apr 17, 2024

Wait, there are two big differences from fmapstructure / Flux.state:

  • this is only trainable parameters, and
  • tuples & vectors become NamedTuples with made-up field names.

ComponentArrays has no notion of shared parameters. That's a large part of what makes everything touching Functors tricky. (In fact the replacement of a vector with a NamedTuple opens the door to weirdness here, before you get to ComponentArrays, as you replace a mutable thing with an immutable one. Probably not in a way that matters for Flux models.)

Example with this:

julia> sh = [1f0, 2f0];

julia> ps, re = trainables_nt((sh, sh, [3,4.]))
((_1 = Float32[1.0, 2.0], _2 = Float32[1.0, 2.0], _3 = [3.0, 4.0]), Optimisers.RestructureFromNT{Tuple{Vector{Float32}, Vector{Float32}, Vector{Float64}}}((Float32[1.0, 2.0], Float32[1.0, 2.0], [3.0, 4.0])))

julia> ps._1 === ps._2
true

julia> v = ComponentVector(ps);

julia> getfield(v, :data) |> println
[1.0, 2.0, 1.0, 2.0, 3.0, 4.0]

julia> v[3] = 99;

julia> re(v)  # sharing is broken
([1.0, 2.0], [99.0, 2.0], [3.0, 4.0])

And unrelated to sharing:

julia> re(v)[1] |> eltype  # accidental promotion is back
Float64

julia> re(v)[1]   # no copy on reconstruction, but will view(::CuArray) work everywhere?
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 1.0
 2.0

cf. destructure:

julia> v2, re2 = destructure((sh, sh, [3,4.]))
([1.0, 2.0, 3.0, 4.0], Restructure(Tuple, ..., 4))

julia> v2[2] = 999;

julia> re2(v2)
(Float32[1.0, 999.0], Float32[1.0, 999.0], [3.0, 4.0])

When I last looked, ComponentArrays also made more whole copies in the gradient.

More broadly, what's this for? Why do we care about ComponentArrays?

@CarloLucibello (Member, Author) commented Apr 19, 2024

More broadly, what's this for? Why do we care about ComponentArrays?

I would like to have something in the v, re = destructure(model) style, but where reconstruction is copyless and also compatible with ComponentArrays. This seems to be quite needed, see FluxML/Flux.jl#2413 (comment).
I think we can provide it and see whether it gets used.

@CarloLucibello (Member, Author) commented Apr 19, 2024

I need help with the rrule for the reconstructor. It works for named tuples but not for ComponentArrays:

using Zygote, Optimisers, ComponentArrays, Test
m = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(m)
Zygote.refresh()
gps = gradient(x -> re(x)[1][2], ps)[1]
@test gps == (_1 = [0.0, 1.0, 0.0], _2 = nothing)  # ok

v = ComponentVector(ps)
gv = gradient(x -> re(x)[1][2], v)[1] # this is `nothing`!!!!

The relevant rule is

function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
    model = restructure_from_nt(x, ps)
    proj_ps = ProjectTo(ps)

    function restructure_from_nt_back(Δmodel_raw)
        Δmodel = unthunk(Δmodel_raw)
        walk = RestructureFromNamedTupleBackWalk()
        function exclude(x)
            @show "exclude" x isnumeric(x)
            # i += 1
            # return i > 1
            return isnumeric(x)
        end
        Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ
                    @show "fmap" Δ p

                    return Δ
                end
        Δpst = Tangent{typeof(Δps)}(; Δps...)
        @show "rrule" Δmodel x ps Δps Δpst         #here  Δp = (_1 = [0.0, 1.0, 0.0], _2 = ChainRulesCore.ZeroTangent())
        @show  typeof(Δmodel) typeof(ps) typeof(Δps)
        return (NoTangent(), NoTangent(), Δps)
        # return (NoTangent(), NoTangent(), proj_ps(Δpst))
    end
    return model, restructure_from_nt_back
end

struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
    @show 1 typeof(Δmodel) typeof(ps)
    Δm = make_named_tuple(Δmodel)
    @show 2 typeof(Δm) ps Δm
    Δm === nothing && return nothing
    Δm === ZeroTangent() && return ZeroTangent()
    y = mapvalue(recurse, ps, Δm)
    @show 3 typeof(Δmodel) typeof(Δm) typeof(y)
    return y
end

Why do I get a `nothing` gradient? Am I doing something wrong with the projection?
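One way to narrow this down (a debugging sketch, not a fix): call the pullback directly on the ComponentVector input and inspect the raw cotangent, to check whether the custom rrule is reached at all and what tangent type Zygote hands back:

y, back = Zygote.pullback(x -> re(x)[1][2], v)
back(1.0)   # inspect the cotangent before Zygote maps it back onto `v`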
