
collect OneElement when used with implicit Params #989

Closed
wants to merge 2 commits

Conversation

mcabbott
Member

@mcabbott mcabbott commented Jun 9, 2021

This is to address SciML/Surrogates.jl#279, by ensuring that when using implicit parameters, the arrays in Grads are mutable ones. Current behaviour:

julia> W = rand(3);

julia> gradient(() -> W[1], Params([W]))
Grads(...)

julia> ans[W]  # perhaps surprising? Changed by this PR
3-element Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}:
 1.0
 0.0
 0.0

julia> gradient(() -> W[1] + W[2], Params([W]))  # with two calls, accumulation does already work:
Grads(...)

julia> ans[W]
3-element Vector{Float64}:
 1.0
 1.0
 0.0

Perhaps deserves a bit more thought before merging. Do we insist that gradients in Grads are mutable?
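For readers unfamiliar with it, OneElement is Zygote's internal immutable AbstractArray representing a gradient that is nonzero at exactly one index. A minimal stand-in (illustrative only — the real struct is parameterized over dimensionality and axes, and OneHotLike is a made-up name) behaves roughly like:

```julia
# Illustrative stand-in for Zygote.OneElement: an immutable vector that
# is `val` at position `ind` and zero elsewhere. It defines no setindex!,
# which is exactly why in-place optimiser code fails on it.
struct OneHotLike{T} <: AbstractVector{T}
    val::T
    ind::Int
    len::Int
end

Base.size(o::OneHotLike) = (o.len,)
Base.getindex(o::OneHotLike, i::Int) = i == o.ind ? o.val : zero(o.val)
```

Reading works through the AbstractArray interface, but any attempt to write, as in `o[2] = 5.0`, hits Base's "setindex! not defined" fallback — the same error as in the stack trace below.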

The stack trace from Flux looks like this; shouldn't it be updating x::Vector{Float32} from xs::Zygote.Params, rather than mutating x̄::Zygote.OneElement?

ERROR: LoadError: setindex! not defined for Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}
...
 [10] apply!
    @ ~/.julia/dev/Flux/src/optimise/optimisers.jl:42 [inlined]
 [11] update!(opt::Descent, x::Vector{Float32}, x̄::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Flux.Optimise ~/.julia/dev/Flux/src/optimise/train.jl:23
 [12] update!(opt::Descent, xs::Zygote.Params, gs::Zygote.Grads)
    @ Flux.Optimise ~/.julia/dev/Flux/src/optimise/train.jl:29

But before updating x, Flux scales the gradient, to apply the learning rate from the optimiser. That's a slightly strange feature, maybe it shouldn't do that?

function apply!(o::Descent, x, Δ)
  Δ .*= o.eta
end

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jun 9, 2021

It seems to me that there must have been a bug in accum over using the OneElement? It's pretty standard to be truthful about the element types we deal with, and yes, we would not want to rule out immutable arrays (or any other such data type).

We should explore what the flaw with accum (before OneElement) was.
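For reference, accumulation over arrays is essentially broadcast addition, which is why two one-hot gradients already materialize into a plain Vector in the two-call example above. A sketch of the generic array case (not Zygote's exact definition):

```julia
# Sketch of generic gradient accumulation for arrays: broadcasting `+`
# materializes the result into a plain Array, even when both inputs
# are immutable array types. Not Zygote's exact code.
accum_like(a::AbstractArray, b::AbstractArray) = a .+ b
```

So accum itself is fine with immutable inputs; the problem only appears when a single unaccumulated OneElement is handed to code that mutates.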

@DhairyaLGandhi
Member

Why does OneElement attempt to modify the return type or the type of output gradient at all? That doesn't seem right.

@ChrisRackauckas
Member

Why does OneElement attempt to modify the return type or the type of output gradient at all? That doesn't seem right.

Zygote already doesn't enforce that your returned gradient type matches the input parameters at all. I'm not saying that this is a good thing, but:

using Zygote
Zygote.gradient(sum, rand(4))[1]  # 4-element Fill{Float64}: entries equal to 1.0

it's just generally a problem with Zygote. As a band-aid this is fine. What really should be happening is that one calls ArrayInterface.restructure(input, grad), an interface function which says "take this thing grad and make it of the form input". For Arrays this is the same as reshape, but it will, for example, take a Vector and turn it back into a ComponentArray, or re-GPU it if the user thought it should be on the GPU, etc. If you really want to enforce that constraint (which I think you should, because returning a random type that doesn't act the same as the user's representation of the parameters is bound to fail often), then it should be enforced fully.
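For plain CPU arrays, a simplified sketch of what such a restructure-style normalization does (the real ArrayInterface.restructure additionally handles ComponentArrays, GPU arrays, etc.; restructure_like here is a hypothetical name for illustration):

```julia
# Hypothetical simplified restructure for plain CPU arrays only:
# materialize `grad` into a mutable Array shaped like `input`.
restructure_like(input::AbstractArray, grad) = reshape(collect(grad), size(input))
```

Applied to a Fill or a OneElement gradient, this returns an ordinary mutable Array matching the parameter's shape, which is what Flux's in-place optimisers expect.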

@mcabbott
Member Author

Good example. I think this Fill could produce the same error as seen here -- since Flux's optimisers seem to assume that the gradient is a mutable array.

I'd argue that the ideal fix isn't to satisfy Flux by unnecessarily materialising things: Flux should at least check before blindly mutating, and ideally update via x .+= Δ .* o.eta without caring.
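Concretely, such an update writes only to the parameter array and never to the gradient, so an immutable Δ is fine. A sketch with illustrative names, using the usual descent sign convention:

```julia
# Sketch of an update that mutates only the parameters, never the
# gradient, so Δ may be any immutable AbstractArray. Names illustrative.
function descent_update!(x::AbstractArray, Δ, eta)
    x .-= eta .* Δ   # fused broadcast; Δ is only read, never written
    return x
end
```

This folds the learning rate into the parameter update itself, instead of scaling Δ in place first as apply! does above.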

@ChrisRackauckas
Member

I'd argue that the ideal fix isn't to satisfy Flux by unnecessarily materialising things: Flux should at least check before blindly mutating, and ideally update via x .+= Δ .* o.eta without caring.

I think that change should be done regardless. Flux tries to make sure you aren't mutating, but then requires that the types can mutate, which seems like a combination that would generally cause trouble.

@ToucheSir
Member

I can't think of any areas off the top of my head where Flux tries to prevent mutation beyond what Zygote complains about. Certainly there is a laundry list of pain points with the current optimizers, and that's what Optimisers.jl is trying to address. See in particular this PR, which brings in ArrayInterface to avoid mutating immutable parameter types.

Now, there are still a couple of roadblocks. Foremost is that not every parameter type in Flux is a proper (Abstract)Array. Dhairya and I talked about that yesterday and it hopefully shouldn't be a problem for too much longer. The second, more fundamental issue is that not being able to mutate wreaks all sorts of havoc when using implicit params. I've been ruminating over writing a "Taking Explicit Params Seriously" issue for a while now, but need to figure out how to structure it to avoid too much scorched earth ;)
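A fully functional step of the kind Optimisers.jl moves toward returns new parameters rather than mutating, so neither the parameters nor the gradient needs to be mutable. An illustrative sketch only, not Optimisers.jl's actual interface:

```julia
# Illustrative functional descent rule: no mutation of x or Δ, so both
# may be immutable array types. Not Optimisers.jl's real API.
struct DescentRule
    eta::Float64
end

apply_rule(o::DescentRule, x, Δ) = x .- o.eta .* Δ
```

With this shape, an immutable gradient like a Fill or OneElement poses no problem; broadcasting simply materializes a fresh parameter array.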

@CarloLucibello
Member

Related to FluxML/Flux.jl#1510. I agree with @ChrisRackauckas: we should add ArrayInterface.restructure(input, grad) to Flux.update! or Flux.apply!.

@mcabbott
Member Author

Closing in favour of fixing this in Flux, then. Both a quick band-aid, and ultimately a better design.

@mcabbott mcabbott closed this Jun 10, 2021
bors bot added a commit to FluxML/Flux.jl that referenced this pull request Jun 17, 2021
1613: use ArrayInterface.restructure in update! r=CarloLucibello a=CarloLucibello

Suggestion coming from @ChrisRackauckas in FluxML/Zygote.jl#989. 
Now `update!` handles basically any gradient Zygote emits, e.g. FillArrays and Zygote.OneElement. 

Fix #1510 


Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>