
WIP: Make optimize work on structs #1073

Closed · wants to merge 8 commits

Conversation

@Roger-luo (Contributor) commented Mar 6, 2020

So I tried to implement what was proposed in #637.
I guess this is what @MikeInnes wanted? But I didn't see why the isimmutable trait is needed. With this PR, the following works at least, and is a bit faster since it's type-stable now.

using Flux
using Flux.Optimise: update!, ADAM

m = Chain(Dense(100, 100, tanh), Dense(100, 100, tanh))
opt = ADAM()
x = rand(Float32, 100)
loss(m, x) = sum(m(x))
∇m, _ = gradient(loss, m, x)
update!(opt, m, ∇m)

@DhairyaLGandhi (Member)

I think this is pretty close to what we were thinking of: just getting update! to work with regular structs.

@Roger-luo (Contributor, Author) commented Mar 6, 2020

So I think this should work on regular structs that use AbstractArrays as their parameters; at least it works on whatever params works on. But there could be things like prelu, which holds a single scalar parameter (not sure if Flux has it somewhere); a sketch of that case is below. I think the way torch deals with training is just letting users define it, at least in its early stages. This might also relate to #666, so this could be fine: if we have a step function, users could just feed the gradients to their model there (or even just use update!).
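
For illustration, a minimal sketch of that scalar-parameter case (this PReLU definition is hypothetical, not something from Flux):

# Hypothetical PReLU-style layer with a single scalar parameter α.
struct PReLU{T<:Real}
    α::T
end

(p::PReLU)(x) = max.(x, 0) .+ p.α .* min.(x, 0)

# Array parameters can be updated in place by update!, but a scalar α is
# immutable, so updating it means rebuilding the layer, e.g. PReLU(p.α - δ),
# rather than mutating p.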

@MikeInnes (Member)

The goal of #637 is not to support updating structs, but to support functional-style training loops; that's why the API proposed in that issue takes and returns a model and optimiser state rather than updating in-place. (That's also why we need something like an isimmutable trait: we want to do things in place as an optimisation, but need an automatic way to tell when that's possible, which it isn't for some array types.)
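
For concreteness, here is a self-contained sketch of that API shape; the names (init, apply) and details are illustrative, not an actual Flux interface:

# Sketch: the optimiser state is explicit, and each step returns the update
# together with the new state instead of mutating an internal IdDict.
struct Momentum
    eta::Float64
    rho::Float64
end

init(o::Momentum, x::AbstractArray) = zero(x)

function apply(o::Momentum, state, x, dx)
    v = o.rho .* state .+ o.eta .* dx
    return v, v                      # (update, new optimiser state)
end

# The caller threads the state through the training loop explicitly:
W   = randn(Float32, 2, 2)
dW  = ones(Float32, 2, 2)
opt = Momentum(0.01, 0.9)
st  = init(opt, W)
Δ, st = apply(opt, st, W, dW)
W = W .- Δ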

To meet that goal it's not enough to build on top of the current optimisers, since they use very-not-functional IdDicts internally which loses all of the properties we want.

Also, as a style note, we definitely shouldn't be writing generated functions here. fmap is designed for exactly this kind of thing.

I think it'd be great to use this branch to hack out a design for this stuff, but just to point out that this is a fairly comprehensive rework with a few tricky design challenges involved.

@Roger-luo (Contributor, Author) commented Mar 6, 2020

Ah, @MikeInnes, I see what you mean. So actually this PR is not very related to #637, since it just changes the update! function, while #637 will need a rework of the entire optimiser structure for the new API, IIUC.

I think maybe we could just have this first, to let people try structural gradients and see the feedback. At least this PR won't break the current API.

@Roger-luo (Contributor, Author) commented Mar 6, 2020

fmap will make update! type-unstable IIUC, since the cache is an IdDict{Any, Any}. If it's type-unstable, that will at least make things like AutoPreallocation hard to sort out (resulting in runtime allocations caused by dynamic dispatch). And I mean it's quite straightforward to avoid with a generated function, so why not?

And actually I tried a bit and still don't see how to implement this with fmap either: in order to update the struct with a NamedTuple, I'd need to somehow index into the struct's fields, but fmap only maps the function over a single functor. Roughly, what I have in mind is sketched below.
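
A sketch of the fieldwise approach (illustrative only, not this PR's actual code):

using Flux

# Pair each field of the struct with the matching entry of the gradient
# NamedTuple/Tuple by position, update array leaves in place, and recurse
# into everything else.
function update_struct!(opt, x, ∇x)
    for i in 1:fieldcount(typeof(x))
        p, g = getfield(x, i), getfield(∇x, i)
        g === nothing && continue                  # e.g. the activation has no gradient
        if p isa AbstractArray
            p .-= Flux.Optimise.apply!(opt, p, g)  # mutate array leaves in place
        else
            update_struct!(opt, p, g)              # recurse into sub-structs and tuples
        end
    end
    return x
end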

@CarloLucibello (Member)

How does this interact with parameter sharing? E.g.

l = Dense(100, 100, tanh)
m = Chain(l, l)
....

We would have two ADAM steps, one for each partial derivative, instead of a single ADAM step with the sum of the partial derivatives, right? This is not what we want.

@Roger-luo (Contributor, Author)

I think this simply feeds the optimizer the reference to the Array (unless you are using some immutable structure to keep your parameters).

Zygote will create a corresponding NamedTuple with newly allocated memory, so you can just feed them to the different optimizers as before. I don't see why this wouldn't work.
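
For context, a small check of what I mean by "the reference to the Array": Flux's stateful optimisers key their per-parameter state on the parameter array itself, so passing the same array keeps hitting the same state entry. (The opt.state field access below assumes the current ADAM implementation.)

using Flux
using Flux.Optimise: ADAM, apply!

W   = rand(Float32, 2, 2)
∇W  = ones(Float32, 2, 2)
opt = ADAM()

apply!(opt, W, copy(∇W))       # first step creates state for W...
@assert haskey(opt.state, W)   # ...keyed by the array W itself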

@CarloLucibello (Member)

Let's look at a concrete example:

julia> l = Dense(2,2,tanh)
Dense(2, 2, tanh)

julia> m = Chain(l,l)
Chain(Dense(2, 2, tanh), Dense(2, 2, tanh))

julia> loss(m, x) = sum(m(x))
loss (generic function with 1 method)

julia> x = rand(Float32, 2)
2-element Array{Float32,1}:
 0.6158823
 0.3728192

julia> ∇m, _ = gradient(loss, m, x)
((layers = ((W = Float32[0.94506115 0.57208484; -0.76053166 -0.46038148], b = Float32[1.5344834, -1.2348653], σ = nothing), (W = Float32[0.13190015 0.296789; 0.1380123 0.310542], b = Float32[0.9556952, 0.99998146], σ = nothing)),), Float32[0.4527306, -1.2399194])

julia> ∇m.layers[1].W
2×2 Array{Float32,2}:
  0.945061   0.572085
 -0.760532  -0.460381

julia> ∇m.layers[2].W
2×2 Array{Float32,2}:
 0.1319    0.296789
 0.138012  0.310542

Let's call W the shared weight matrix. According to this PR, a call to update!(opt, m, ∇m) will result in two calls to apply!:

W .-= apply!(opt, W, ∇m.layers[1].W)
W .-= apply!(opt, W, ∇m.layers[2].W)

This is wrong for optimizers with internal state. The correct behavior is:

W .-= apply!(opt, W, ∇m.layers[1].W + ∇m.layers[2].W)

@Roger-luo (Contributor, Author)

True, I see your point. It seems we need to somehow store the mapping between a parameter reference and its corresponding gradient; the structural gradient alone doesn't seem sufficient. I don't have a good idea yet, let me think about it too.
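
One direction, just a sketch rather than a worked-out design: accumulate gradients into an IdDict keyed by the parameter array while walking the struct, so a shared weight ends up with a single summed gradient and hence a single apply! call.

# Collect gradients keyed by parameter identity, summing duplicates.
function accumulate_grads!(acc::IdDict, x, ∇x)
    for i in 1:fieldcount(typeof(x))
        p, g = getfield(x, i), getfield(∇x, i)
        g === nothing && continue
        if p isa AbstractArray
            acc[p] = haskey(acc, p) ? acc[p] .+ g : g   # sum gradients of shared arrays
        else
            accumulate_grads!(acc, p, g)                # recurse into sub-structs and tuples
        end
    end
    return acc
end

# Each parameter would then be updated exactly once, e.g.:
# for (p, g) in accumulate_grads!(IdDict(), m, ∇m)
#     p .-= Flux.Optimise.apply!(opt, p, g)
# end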

@MikeInnes (Member)

This isn't really any more type-stable; it's just that the IdDict is hidden inside the current optimiser objects (which results in the weird interactions Carlo points out). Part of the "why not?" is that fmap handles subtleties like DAG structures in a correct way, and we can pretty easily fix the cache type stability for immutable objects.
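
For example, fmap's IdDict cache already preserves sharing when walking a DAG-shaped model (a small illustration; the leaf transformation here is arbitrary):

using Flux

l = Dense(2, 2, tanh)
m = Chain(l, l)              # the same layer object appears twice

# fmap caches on object identity, so the shared layer is transformed once
# and both positions in the result still refer to the same object.
m2 = Flux.fmap(x -> x isa AbstractArray ? copy(x) : x, m)
@assert m2.layers[1] === m2.layers[2]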

@MikeInnes linked an issue on Mar 17, 2020 that may be closed by this pull request
@Roger-luo changed the title from "Make optimize work on structs" to "WIP: Make optimize work on structs" on Mar 17, 2020
@ToucheSir mentioned this pull request on Nov 27, 2020
@ToucheSir (Member)

@Roger-luo can we close this in favour of Optimisers.jl?

Successfully merging this pull request may close these issues: outdims function doesn't work properly for chained layers