Per-leaf freezing #49
Conversation
This is the mental model I had in my mind, but I forgot that we have […]. One difference between this approach and separate trees for auxiliary information is that the latter is extensible. For example, if we didn't have freezing built-in, a separate package could define a […]. I'm just using freezing as a hypothetical here to illustrate that a Functors.jl solution with multiple trees could serve both Optimisers.jl use cases and external use cases. Of course, it's trickier and more complex than what we have here.
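To make the separate-tree idea concrete, here is a minimal sketch: a Bool-valued mask tree mirroring a flat model, with frozen leaves getting their gradients dropped before an ordinary `Optimisers.update`. Every name and convention here is hypothetical (nothing in it is this PR's API), and it assumes `update` treats a `nothing` gradient as "leave this leaf unchanged":

```julia
using Optimisers

# Hypothetical sketch: freezing kept in a separate mask tree, the kind of
# thing an external package could define without built-in support.
model = (weight = [1.0 2.0; 3.0 4.0], bias = [5.0, 6.0])
mask  = (weight = true, bias = false)   # true = frozen (assumed convention)
grad  = (weight = [3.0 2.0; 0.0 0.0], bias = [1.0, 0.0])

state = Optimisers.setup(Optimisers.Descent(0.1), model)

# Drop gradients wherever the mask says "frozen"; `update` is assumed to
# treat a `nothing` gradient as "no change" for that leaf:
maskedgrad = map((g, frozen) -> frozen ? nothing : g, grad, mask)
state, model = Optimisers.update(state, model, maskedgrad)
# model.weight is untouched; model.bias has taken a Descent step.
```

This flat version hides the hard part: for nested models the mask tree has to be walked in lockstep with the model and gradient, which is where a multi-tree Functors.jl walk would come in.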
This is true. I'd worry a little bit that understanding the API for an extensible multi-tree walk might be harder than writing it yourself.

One extensibility thought: instead of building this into Leaf, it could be a separate Freeze which wraps it, and provides an […]. Would such a mechanism work for other things one may want to hook on? What are some examples?

One more question to think about: unlike #42, this exposes the "address" tuple as something the user is supposed to provide. Do we like or hate it?
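Since the recursion of `update` is internal, one hedged way to approximate the wrapper idea through the documented extension points (`AbstractRule`, `init`, `apply!`) is a rule wrapper that returns a zero step. This is a sketch of a related technique, not the Freeze-around-Leaf design discussed above, and `FrozenRule` is a hypothetical name:

```julia
using Optimisers

# Sketch: freezing expressed at the *rule* level via init/apply!,
# an approximation of the Freeze idea, not this PR's design.
struct FrozenRule{R<:Optimisers.AbstractRule} <: Optimisers.AbstractRule
    rule::R
end

# Delegate state initialisation to the wrapped rule:
Optimisers.init(f::FrozenRule, x) = Optimisers.init(f.rule, x)

# Returning a zero step means `update` subtracts nothing from x:
Optimisers.apply!(f::FrozenRule, state, x, dx) = state, zero(dx)
```

Used as `Optimisers.setup(FrozenRule(Descent(0.1)), model)` this freezes everything; placing the wrapper at chosen leaves only is exactly the addressing question above.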
Related, Flux at present has this distinction:

```julia
julia> Flux.functor(Chain(x=sin, y=cos))
((x = sin, y = cos), Flux.var"#154#155"())

julia> Flux.functor(Parallel(vcat, (x=sin, y=cos)))
((connection = vcat, layers = (x = sin, y = cos)), Flux.var"#178#179"())
```

So to freeze a named branch of Parallel, you'd have to say […].
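That is, the same named branch sits at a different depth in the two containers, so its address tuple differs. A hypothetical spelling, assuming a `freeze(state, address)` helper along this PR's lines (both the helper and the tuples are illustration only):

```julia
# Hypothetical freeze helper and address tuples, for illustration only:
freeze(state, (:x,))           # branch :x of Chain(x = sin, y = cos)
freeze(state, (:layers, :x))   # same branch inside Parallel(vcat, (x = sin, y = cos))
```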
Given that […]:

```julia
julia> func, re = Flux.functor(Chain(x=sin, y=cos))
((x = sin, y = cos), Flux.var"#154#155"())

julia> re(func)
Chain(sin, cos)

julia> re(func).layers
(sin, cos)
```
Oh, that's bad. We should either not hide this […]
There's a splat which loses the names. But the […]:

```julia
julia> m = Chain(Dense([1 2; 3 4.0], [5,6], relu), identity);

julia> g = gradient(m -> m([3,2])[1], m)[1]
(layers = ((weight = [3.0 2.0; 0.0 0.0], bias = [1.0, 0.0], σ = nothing), nothing),)

julia> s = Optimisers.setup(Optimisers.Descent(pi/10), m)
((weight = Leaf(Descent{Float64}(0.314159), nothing), bias = Leaf(Descent{Float64}(0.314159), nothing), σ = nothing), nothing)

julia> s2, m2 = Optimisers.update(s, m, g);

julia> m2.layers[1].weight
2×2 Matrix{Float64}:
 0.0575222  1.37168
 3.0        4.0
```

Notice, as an aside, this way to screw up, which runs without error: […]
Thoughts on reviving this without the addressing functionality, so we can defer that decision? Users can always use Accessors.jl in the interim (and possibly in the long term, if we want to support that) for fine-grained manipulation.
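For reference, the Accessors.jl route might look like this: reaching into the state tree that `setup` returns and zeroing one leaf's learning rate. The field path (`rule`, `eta`) follows the `Leaf(Descent(...), ...)` display above, but treat the exact spelling as an assumption:

```julia
using Optimisers, Accessors

model = (weight = [1.0 2.0; 3.0 4.0], bias = [5.0, 6.0])
state = Optimisers.setup(Optimisers.Descent(0.1), model)

# Freeze `weight` in effect by setting its leaf's learning rate to zero;
# @set rebuilds the immutable tree with just that field changed:
state = @set state.weight.rule.eta = 0.0
```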
This is one way we could handle freezing of certain nodes, by altering the state tree.
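A minimal sketch of that state-tree-altering approach, assuming (as the printed `setup` above suggests for the non-trainable `σ`) that `nothing` in the state tree marks a node that `update` leaves alone; `freezeleaves` and its predicate are hypothetical names:

```julia
using Functors, Optimisers

# Hypothetical helper: walk the state tree from `setup` and replace the
# Leaf nodes selected by `pred` with `nothing`, assuming `nothing` state
# means "leave this parameter alone", as for the non-trainable σ above.
function freezeleaves(pred, statetree)
    fmap(statetree; exclude = leaf -> leaf isa Optimisers.Leaf) do leaf
        pred(leaf) ? nothing : leaf
    end
end

# e.g. freeze every leaf whose rule is plain Descent:
# s2 = freezeleaves(leaf -> leaf.rule isa Optimisers.Descent, s)
```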