train! using Metal and stateful optimizers fails #2310

Closed
sotlampr opened this issue Aug 10, 2023 · 5 comments · Fixed by #2318

sotlampr commented Aug 10, 2023

Using the "Metal" GPU backend, Flux.Optimise.train! fails with anything other than Flux.Optimise.Descent (I tried Nesterov and Adam). It seems that stateful optimizers implicitly work with Float64. I'm not sure whether this is a bug or expected behaviour given the experimental nature of the Metal backend.

Thanks!

Minimal example to reproduce:

using Flux
using Metal

@assert Flux.GPU_BACKEND == "Metal"

data = [(rand(Float32, 1024), rand(Float32.(0:1), 1)) for _=1:10] .|> gpu 
m = Dense(1024, 1) |> gpu 

loss(m, x, y) = Flux.Losses.logitcrossentropy(m(x), y)
opt = Flux.Train.setup(Flux.Optimise.Nesterov(), m)  # plain `Descent` works fine

Flux.Optimise.train!(loss, m, data, opt)
Backtrace when using `Nesterov`
ERROR: Metal does not support Float64 values, try using Float32 instead
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] MtlMatrix{Float64, Metal.MTL.MTLResourceStorageModePrivate}(#unused#::UndefInitializer, dims::Tuple{Int64, Int64})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/array.jl:45
  [3] (MtlMatrix{Float64})(#unused#::UndefInitializer, dims::Tuple{Int64, Int64})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/array.jl:94
  [4] (MtlArray{Float64})(#unused#::UndefInitializer, dims::Tuple{Int64, Int64})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/array.jl:106
  [5] similar(#unused#::Type{MtlArray{Float64}}, dims::Tuple{Int64, Int64})
    @ Base ./abstractarray.jl:882
  [6] similar(#unused#::Type{MtlArray{Float64}}, shape::Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}})
    @ Base ./abstractarray.jl:881
  [7] similar(bc::Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Float64, Base.RefValue{Val{2}}}}}}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(+), Tuple{Int64, Float64}}, Float64, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}}}, #unused#::Type{Float64})
    @ Metal ~/.julia/packages/Metal/qeZqc/src/broadcast.jl:11
  [8] copy
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:37 [inlined]
  [9] materialize
    @ ./broadcast.jl:873 [inlined]
 [10] apply!(o::Optimisers.Nesterov{Float64}, state::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, dx::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/rules.jl:80
 [11] _update!(ℓ::Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}; grads::IdDict{Optimisers.Leaf, Any}, params::IdDict{Any, Any})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:92
 [12] _update!
    @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:88 [inlined]
 [13] #8
    @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:81 [inlined]
 [14] map(f::Optimisers.var"#8#9"{IdDict{Optimisers.Leaf, Any}, IdDict{Any, Any}}, t::Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}, s::Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, typeof(identity)})
    @ Base ./tuple.jl:302
 [15] map(f::Function, nt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}}, nts::NamedTuple{(:weight, :bias, :σ), Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, typeof(identity)}})
    @ Base ./namedtuple.jl:219
 [16] valuemap
    @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:178 [inlined]
 [17] _update!(tree::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}}, x::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}; grads::IdDict{Optimisers.Leaf, Any}, params::IdDict{Any, Any})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:81
 [18] _update!
    @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:77 [inlined]
 [19] update!(::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}}, ::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Nothing}})
    @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:73
 [20] macro expansion
    @ ~/.julia/packages/Flux/pR3k3/src/train.jl:111 [inlined]
 [21] macro expansion
    @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
 [22] train!(loss::Function, model::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, data::Vector{Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}, opt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}}; cb::Nothing)
    @ Flux.Train ~/.julia/packages/Flux/pR3k3/src/train.jl:105
 [23] train!(loss::Function, model::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, data::Vector{Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}, opt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Optimisers.Leaf{Optimisers.Nesterov{Float64}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, Tuple{}}})
    @ Flux.Train ~/.julia/packages/Flux/pR3k3/src/train.jl:102
 [24] top-level scope
    @ ~/MLPluto/script.jl:17
 [25] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [26] top-level scope
    @ none:1
 [27] eval
    @ ./boot.jl:370 [inlined]
 [28] exec_options(opts::Base.JLOptions)
    @ Base ./client.jl:282
 [29] _start()
    @ Base ./client.jl:522
Backtrace when using `Adam`
ERROR: InvalidIRError: compiling MethodInstance for (::GPUArrays.var"#broadcast_kernel#26")(::Metal.mtlKernelContext, ::MtlDeviceMatrix{Float32, 1}, ::Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Float64, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float64}}, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}}}, ::Int64) resulted in invalid LLVM IR
Reason: unsupported unsupported use of double value
Reason: unsupported unsupported use of double value
Reason: unsupported unsupported use of double value
Reason: unsupported unsupported use of double value
Reason: unsupported unsupported use of double value
Stacktrace:
[1] Float64
  @ ./float.jl:261
[2] convert
  @ ./number.jl:7
[3] _promote
  @ ./promotion.jl:358
[4] promote
  @ ./promotion.jl:381
[5] *
  @ ./promotion.jl:411
[6] _broadcast_getindex_evalf
  @ ./broadcast.jl:683
[7] _broadcast_getindex
  @ ./broadcast.jl:656
[8] _getindex
  @ ./broadcast.jl:679
[9] _broadcast_getindex
  @ ./broadcast.jl:655
[10] getindex
  @ ./broadcast.jl:610
[11] broadcast_kernel
  @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Reason: unsupported unsupported use of double value
Stacktrace:
[1] *
 @ ./float.jl:410
[2] *
 @ ./promotion.jl:411
[3] _broadcast_getindex_evalf
 @ ./broadcast.jl:683
[4] _broadcast_getindex
 @ ./broadcast.jl:656
[5] _getindex
 @ ./broadcast.jl:679
[6] _broadcast_getindex
 @ ./broadcast.jl:655
[7] getindex
 @ ./broadcast.jl:610
[8] broadcast_kernel
 @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Reason: unsupported unsupported use of double value
Stacktrace:
[1] Float64
  @ ./float.jl:261
[2] convert
  @ ./number.jl:7
[3] _promote
  @ ./promotion.jl:358
[4] promote
  @ ./promotion.jl:381
[5] *
  @ ./promotion.jl:411
[6] _broadcast_getindex_evalf
  @ ./broadcast.jl:683
[7] _broadcast_getindex
  @ ./broadcast.jl:656
[8] _getindex
  @ ./broadcast.jl:680
[9] _getindex
  @ ./broadcast.jl:679
[10] _broadcast_getindex
  @ ./broadcast.jl:655
[11] getindex
  @ ./broadcast.jl:610
[12] broadcast_kernel
  @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Reason: unsupported unsupported use of double value
Stacktrace:
[1] *
 @ ./float.jl:410
[2] *
 @ ./promotion.jl:411
[3] _broadcast_getindex_evalf
 @ ./broadcast.jl:683
[4] _broadcast_getindex
 @ ./broadcast.jl:656
[5] _getindex
 @ ./broadcast.jl:680
[6] _getindex
 @ ./broadcast.jl:679
[7] _broadcast_getindex
 @ ./broadcast.jl:655
[8] getindex
 @ ./broadcast.jl:610
[9] broadcast_kernel
 @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Reason: unsupported unsupported use of double value
Stacktrace:
[1] +
 @ ./float.jl:408
[2] _broadcast_getindex_evalf
 @ ./broadcast.jl:683
[3] _broadcast_getindex
 @ ./broadcast.jl:656
[4] getindex
 @ ./broadcast.jl:610
[5] broadcast_kernel
 @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Reason: unsupported unsupported use of double value
Stacktrace:
[1] Float32
 @ ./float.jl:258
[2] convert
 @ ./number.jl:7
[3] setindex!
 @ ~/.julia/packages/Metal/qeZqc/src/device/array.jl:105
[4] setindex!
 @ ~/.julia/packages/Metal/qeZqc/src/device/array.jl:118
[5] broadcast_kernel
 @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:59
Hint: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code with Cthulhu.jl
Stacktrace:
[1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, args::LLVM.Module)
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/validation.jl:149
[2] macro expansion
  @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:415 [inlined]
[3] macro expansion
  @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
[4] macro expansion
  @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:414 [inlined]
[5] emit_llvm(job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, only_entry::Bool, validate::Bool)
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:89
[6] emit_llvm
  @ ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:83 [inlined]
[7] codegen(output::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:129
[8] codegen
  @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:110 [inlined]
[9] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:106
[10] compile
  @ ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:98 [inlined]
[11] #51
  @ ~/.julia/packages/Metal/qeZqc/src/compiler/compilation.jl:57 [inlined]
[12] JuliaContext(f::Metal.var"#51#52"{GPUCompiler.CompilerJob{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}})
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
[13] compile(job::GPUCompiler.CompilerJob)
  @ Metal ~/.julia/packages/Metal/qeZqc/src/compiler/compilation.jl:56
[14] actual_compilation(cache::Dict{Any, Any}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, compiler::typeof(Metal.compile), linker::typeof(Metal.link))
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/execution.jl:125
[15] cached_compilation(cache::Dict{Any, Any}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{GPUCompiler.MetalCompilerTarget, Metal.MetalCompilerParams}, compiler::Function, linker::Function)
  @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/execution.jl:103
[16] macro expansion
  @ ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:162 [inlined]
[17] macro expansion
  @ ./lock.jl:267 [inlined]
[18] mtlfunction(f::GPUArrays.var"#broadcast_kernel#26", tt::Type{Tuple{Metal.mtlKernelContext, MtlDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Float64, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float64}}, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
  @ Metal ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:157
[19] mtlfunction(f::GPUArrays.var"#broadcast_kernel#26", tt::Type{Tuple{Metal.mtlKernelContext, MtlDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Float64, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Base.Broadcast.Broadcasted{Metal.MtlArrayStyle{2}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Int64, Float64}}, Base.Broadcast.Extruded{MtlDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}}}, Int64}})
  @ Metal ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:155
[20] macro expansion
  @ ~/.julia/packages/Metal/qeZqc/src/compiler/execution.jl:77 [inlined]
[21] #launch_heuristic#98
  @ ~/.julia/packages/Metal/qeZqc/src/gpuarrays.jl:14 [inlined]
[22] launch_heuristic
  @ ~/.julia/packages/Metal/qeZqc/src/gpuarrays.jl:12 [inlined]
[23] _copyto!
  @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:65 [inlined]
[24] materialize!
  @ ~/.julia/packages/GPUArrays/5XhED/src/host/broadcast.jl:41 [inlined]
[25] materialize!
  @ ./broadcast.jl:881 [inlined]
[26] macro expansion
  @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:201 [inlined]
[27] apply!(o::Optimisers.Adam{Float64}, state::Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, dx::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate})
  @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/rules.jl:213
[28] _update!(ℓ::Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, x::MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}; grads::IdDict{Optimisers.Leaf, Any}, params::IdDict{Any, Any})
  @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:92
[29] _update!
  @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:88 [inlined]
[30] #8
  @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:81 [inlined]
[31] map(f::Optimisers.var"#8#9"{IdDict{Optimisers.Leaf, Any}, IdDict{Any, Any}}, t::Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}, s::Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, typeof(identity)})
  @ Base ./tuple.jl:302
[32] map(f::Function, nt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}}, nts::NamedTuple{(:weight, :bias, :σ), Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, typeof(identity)}})
  @ Base ./namedtuple.jl:219
[33] valuemap
  @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:178 [inlined]
[34] _update!(tree::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}}, x::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}; grads::IdDict{Optimisers.Leaf, Any}, params::IdDict{Any, Any})
  @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:81
[35] _update!
  @ ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:77 [inlined]
[36] update!(::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}}, ::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Nothing}})
  @ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/interface.jl:73
[37] macro expansion
  @ ~/.julia/packages/Flux/pR3k3/src/train.jl:111 [inlined]
[38] macro expansion
  @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
[39] train!(loss::Function, model::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, data::Vector{Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}, opt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}}; cb::Nothing)
  @ Flux.Train ~/.julia/packages/Flux/pR3k3/src/train.jl:105
[40] train!(loss::Function, model::Dense{typeof(identity), MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}, data::Vector{Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}}}, opt::NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlMatrix{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, MtlVector{Float32, Metal.MTL.MTLResourceStorageModePrivate}, Tuple{Float64, Float64}}}, Tuple{}}})
  @ Flux.Train ~/.julia/packages/Flux/pR3k3/src/train.jl:102
[41] top-level scope
  @ ~/MLPluto/script.jl:29
[42] include(fname::String)
  @ Base.MainInclude ./client.jl:478
[43] top-level scope
  @ none:1
[44] eval
  @ ./boot.jl:370 [inlined]
[45] exec_options(opts::Base.JLOptions)
  @ Base ./client.jl:282
[46] _start()
  @ Base ./client.jl:522
Julia version: 1.9.2
Julia Version 1.9.2
Commit e4ee485e90 (2023-07-05 09:39 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin22.5.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores

Flux version: 0.14.2
Metal version: 0.5.0

ToucheSir transferred this issue from FluxML/Flux.jl Aug 10, 2023
@ToucheSir (Member)

We should be doing calculations in Float32, so either some intermediate ones aren't or Float64 gradients are somehow being passed to the optimizers. Hence I've moved this to Optimisers.jl. I would recommend first checking the gradients to make sure all are Float32. If you've confirmed that, then it should be straightforward to create a MWE using a single array and Optimisers.jl (no Flux), which should make it easier for us to investigate.
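
For reference, a Flux-free reproduction along these lines might look like the sketch below (assuming Metal.jl and Optimisers.jl are loaded; array names and sizes are only illustrative, not taken from the original report):

using Metal
using Optimisers

# A single Float32 parameter array and a stand-in Float32 gradient, both on the GPU,
# updated with a bare Optimisers.jl rule -- no Flux involved.
x  = MtlArray(rand(Float32, 1024))
dx = MtlArray(rand(Float32, 1024))

rule  = Optimisers.Nesterov()               # Nesterov{Float32}(0.001f0, 0.9f0)
state = Optimisers.setup(rule, x)
state, x = Optimisers.update(state, x, dx)  # should stay in Float32 throughout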

sotlampr (Author) commented Aug 10, 2023

@ToucheSir Thanks for your reply.
Replacing

> Flux.Optimise.Nesterov()
Nesterov(0.001, 0.9, IdDict{Any, Any}())

with

> Optimisers.Nesterov()
Nesterov{Float32}(0.001f0, 0.9f0)

fixes the problem.
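
Applied to the minimal example above, that means handing setup an Optimisers.jl rule directly (a sketch reusing the same m, data and loss; only the rule construction changes):

using Flux, Optimisers

opt = Flux.Train.setup(Optimisers.Nesterov(), m)  # Nesterov{Float32}, no Float64 hyper-parameters
Flux.Optimise.train!(loss, m, data, opt)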

It seems in Optimisers, the hyper-parameters are templated, while in Flux they are explicitly typed as Float64.

So it's a Flux bug, but I guess you're going to move to Optimisers at some point? The Flux documentation also seems to contradict this:

The new version of Flux's training code was written as an independent package, Optimisers.jl. Only the function train! belongs to Flux itself.

ToucheSir (Member) commented Aug 10, 2023

We already did. If you use Flux per the tutorials, everything should run on Optimisers.jl under the hood by default. https://fluxml.ai/Flux.jl/stable/training/optimisers/#man-optimisers touches on this. You should never have to touch Flux.Optimise now.

@sotlampr (Author)

Great! One should, however, use Optimisers explicitly, since e.g. using Flux; Nesterov() returns the old interface, which is not templated and will fail on GPUs (at least Metal). From the documentation:

In Flux 0.14, Flux.Adam() returns the old one, with supertype Flux.Optimise.AbstractOptimiser, but setup will silently translate it to its new counterpart.

Also,

using Flux, Optimisers
Adam()

warns WARNING: both Optimisers and Flux export "Adam"; uses of it in module Main must be qualified
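
For what it's worth, the export clash can be sidestepped by importing Optimisers without using it (a sketch; with plain import, Optimisers' exports are not brought into Main, so Adam unambiguously refers to Flux's and the new rules are reached via the module prefix):

using Flux
import Optimisers

opt = Flux.setup(Optimisers.Adam(), m)   # no "both export Adam" warning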

You may want to consider making it clearer that one should explicitly use Optimisers.

Thanks!

ToucheSir transferred this issue from FluxML/Optimisers.jl Aug 10, 2023
ToucheSir reopened this Aug 10, 2023
@ToucheSir (Member)

You shouldn't have to use it explicitly; the translation layer should take care of it. In this case it appears the translation layer isn't converting hyperparameters from Float64 to Float32, so this remains an issue. It should be fixed by FluxML/Optimisers.jl#151, so leaving this open until that's released.
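
Until that lands, one way to see the mismatch directly is to inspect the rule type in the translated state (a sketch against the model m from the original example; the field layout follows the Dense state shown in the backtraces above):

state = Flux.setup(Flux.Optimise.Adam(), m)
typeof(state.weight.rule)   # Optimisers.Adam{Float64} -- hyper-parameters not converted

state = Flux.setup(Optimisers.Adam(), m)
typeof(state.weight.rule)   # Optimisers.Adam{Float32} -- Float32 hyper-parameters, like the Nesterov workaround above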
