
Applying Grads fails #76

Closed
femtomc opened this issue May 13, 2022 · 6 comments · Fixed by FluxML/Flux.jl#1976

femtomc commented May 13, 2022

Simple MWE:

module GradsApplyMWE

using Optimisers
using Flux

model = Chain(Dense(50, 5))

state = Optimisers.setup(Optimisers.ADAM(), model)

data = rand(Float32, (50, 1))

# implicit Params/Grads-style gradient
grads = Flux.gradient(Flux.params(model)) do
    sum(model(data))
end

# fails: Optimisers.update expects a gradient tree matching `model`
state, model = Optimisers.update(state, model, grads)

end # module

Throws:

ERROR: LoadError: type Grads has no field layers
Stacktrace:
  [1] getproperty(x::Zygote.Grads, f::Symbol)
    @ Base ./Base.jl:33
  [2] functor(#unused#::Type{Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, x::Zygote.Grads)
    @ Flux ~/.julia/packages/Functors/qBIlC/src/functor.jl:23
  [3] (::Optimisers.var"#4#6"{Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})(x̄::Zygote.Grads)
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/interface.jl:37
  [4] map
    @ ./tuple.jl:213 [inlined]
  [5] update!(tree::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.ADAM{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.ADAM{Float32}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Nothing}}}}}, x::Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x̄s::Zygote.Grads)
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/interface.jl:37
  [6] update(tree::NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.ADAM{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.ADAM{Float32}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}, Nothing}}}}}, x::Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, x̄s::Zygote.Grads)
    @ Optimisers ~/.julia/packages/Optimisers/pCISx/src/interface.jl:46
  [7] top-level scope
    @ ~/scratch/OptimisersScratch/scratch/serialization.jl:16
  [8] include(fname::String)
    @ Base.MainInclude ./client.jl:444
  [9] top-level scope
    @ REPL[4]:1
 [10] top-level scope
    @ ~/.julia/packages/CUDA/qAl31/src/initialization.jl:52
in expression starting at /Users/mccoybecker/scratch/OptimisersScratch/scratch/serialization.jl:1

julia> versioninfo()
Julia Version 1.6.5
Commit 9058264a69 (2021-12-19 12:30 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, westmere)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_DYLIB_PATH = /usr/local/Caskroom/julia/1.6.0/Julia-1.6.app/Contents/Resources/julia/lib
@mcabbott (Member)

Optimisers does not work with the "implicit" Params / Grads mode of Zygote. It needs the "explicit" form, so that grads is a nested structure matching the model:

grads = Flux.gradient(model) do m
    sum(m(data))
end[1]
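
For the MWE above, the rest then works unchanged. A minimal sketch, reusing the names from the report:

grads = Flux.gradient(m -> sum(m(data)), model)[1]
state, model = Optimisers.update(state, model, grads)

Here grads is a NamedTuple mirroring the Chain's structure, which is exactly the shape Optimisers.update expects.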

femtomc commented May 13, 2022

Got it, thank you!

@findmyway

> Optimisers does not work with the "implicit" Params / Grads mode of Zygote. It needs the "explicit" form, so that grads is a nested structure matching the model:
>
> grads = Flux.gradient(model) do m
>     sum(m(data))
> end[1]

Can this be added to the docs?

I also encountered the same error just now.

@mcabbott (Member)

See what you think of #80.

It would be nice if this could produce a friendly error message, too. This package can't depend on Zygote, so it can't see the type Grads, but perhaps it could still notice the problem somehow?

@ToucheSir (Member)

I have some ideas with various levels of hackiness. Checking for the existence of each of Grads's fields should have a low false-positive rate, doubly so if we can also assert the types of those fields. Or maybe we can just use nameof + parentmodule and match on symbols?
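
A minimal sketch of that reflection route, assuming we match on the type's name, defining module, and field layout (the helper name is hypothetical):

# Detect Zygote.Grads without depending on Zygote, purely by reflection:
# match the type's name, its defining module, and its field names.
function _looks_like_implicit_grads(x̄)
    T = typeof(x̄)
    nameof(T) === :Grads &&
        nameof(parentmodule(T)) === :Zygote &&
        fieldnames(T) == (:grads, :params)
end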

A less reflection-heavy option would be to pirate Optimisers.update! for Grads in Flux and throw. The blast radius for that seems pretty small, but it is non-zero.

@mcabbott (Member)

I think the simplest pirate option is Optimisers.base(dx::Zygote.Grads) = error(), since update! has quite a few methods and you'd have to be careful about ambiguities.
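
A sketch of that one-method pirate, which would live in Flux since Flux already depends on both packages (the error text is illustrative):

# Piracy: Flux owns neither Optimisers.base nor Zygote.Grads, but the
# blast radius of a single method like this is small.
Optimisers.base(dx::Zygote.Grads) = error(
    "Optimisers.jl cannot apply Zygote's implicit Grads; take an explicit ",
    "gradient instead, e.g. Flux.gradient(m -> loss(m, data), model)[1]")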

We could certainly identify Grads quite accurately from its field names and their types. Making such a check fire only on a mismatch is tricky, though: update! calls functor(typeof(x), base(x̄)) and fails inside that call. An unconditional check on every update! is possible, and maybe it'll compile away. For reference, the layout to match:

julia> dump(g3; maxdepth=2)
Zygote.Grads
  grads: IdDict{Any, Any}
    ht: Array{Any}((32,))
    count: Int64 4
    ndel: Int64 0
  params: Params{Zygote.Buffer{Float64, Vector{Float64}}}
    order: Zygote.Buffer{Float64, Vector{Float64}}
    params: Base.IdSet{Any}
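
An unconditional guard built on that layout might look like this. A sketch, reusing the hypothetical helper from the earlier comment; whether the branch actually constant-folds is untested:

# Called at the top of update!; for any x̄ that is not a Grads the
# condition is decidable from typeof(x̄) alone, so it may compile away.
function _assert_explicit(x̄)
    _looks_like_implicit_grads(x̄) && error(
        "update! expects an explicit gradient tree matching the model, ",
        "not Zygote's implicit Grads")
    return x̄
end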
