Flux.reset! triggers a BoundsError #2124

Closed
svilupp opened this issue Nov 24, 2022 · 1 comment

svilupp (Contributor) commented Nov 24, 2022

Package Version

0.13.8

Julia Version

1.8

OS / Environment

MacOS (arm64)

Describe the bug

I cannot reproduce the behaviour used in the Model Zoo tutorials for recurrent nets.

Specifically, I cannot reset the hidden states of the whole chain. I would expect calling Flux.reset!(rnn), where rnn is the whole chain, to reset all recurrent layers in it (e.g., example here).

Steps to Reproduce

using Flux
using Flux: chunk

rnn = Chain(RNN(1, 1, identity), Dense(1, 1, identity))
data = ones(Float32, 1, 4)
g = gradient(Flux.params(rnn)) do
    # Flux.reset!(rnn) # does not work
    Flux.reset!(rnn[1]) # does work
    sum(rnn(data)) # mock calculation to return a scalar
end

Expected Results

I expect to be able to call Flux.reset!(rnn) to reset hidden states.

Observed Results

Calling Flux.reset!(rnn) inside gradient leads to an error:

ERROR: BoundsError: attempt to access Tuple{} at index [0]

The error can be avoided by calling reset! only on the recurrent layers.
The error does not appear when calling reset! outside of a pullback/AD context.

I suspect this line might be insufficient in the pullback context, but I'm not sure how to fix it. Any ideas?

Relevant log output

ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:29
[2] last(a::Tuple{})
@ Base ./abstractarray.jl:479
[3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(foldl), op::Base.var"#57#58"{typeof(Flux.reset!)}, x::Tuple{}; init::Nothing)
@ ChainRules ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/mapreduce.jl:448
[4] chain_rrule_kw
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:230 [inlined]
[5] macro expansion
@ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0 [inlined]
[6] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:9
[7] _pullback
@ ./tuple.jl:555 [inlined]
[8] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[9] _pullback
@ ~/.julia/packages/Flux/FKl3M/src/layers/recurrent.jl:180 [inlined]
[10] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[11] _pullback
@ ./abstractarray.jl:2774 [inlined]
[12] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[13] _pullback
@ ~/.julia/packages/Flux/FKl3M/src/layers/recurrent.jl:180 [inlined]
[14] _pullback
@ ./tuple.jl:555 [inlined]
[15] #rrule_via_ad#46
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:255 [inlined]
[16] rrule_via_ad
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:243 [inlined]
[17] #1703
@ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/mapreduce.jl:444 [inlined]
[18] BottomRF
@ ./reduce.jl:81 [inlined]
[19] #836
@ ./accumulate.jl:291 [inlined]
[20] afoldl
@ ./operators.jl:549 [inlined]
[21] #accumulate#835
@ ./accumulate.jl:290 [inlined]
[22] #rrule#1702
@ ~/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/mapreduce.jl:440 [inlined]
[23] chain_rrule_kw
@ ~/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:230 [inlined]
[24] macro expansion
@ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0 [inlined]
[25] _pullback
@ ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:9 [inlined]
[26] _pullback
@ ./tuple.jl:555 [inlined]
[27] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{Flux.Recur{Flux.RNNCell{typeof(identity), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[28] _pullback
@ ~/.julia/packages/Flux/FKl3M/src/layers/recurrent.jl:180 [inlined]
[29] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Tuple{Flux.Recur{Flux.RNNCell{typeof(identity), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[30] _pullback
@ ./abstractarray.jl:2774 [inlined]
[31] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:layers,), Tuple{Tuple{Flux.Recur{Flux.RNNCell{typeof(identity), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[32] _pullback
@ ~/.julia/packages/Flux/FKl3M/src/layers/recurrent.jl:180 [inlined]
[33] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Chain{Tuple{Flux.Recur{Flux.RNNCell{typeof(identity), Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[34] _pullback
@ ~/Documents/Julia-training-range/flux-raw/flux_bug.jl:7 [inlined]
[35] _pullback(::Zygote.Context{true}, ::var"#51#52")
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:0
[36] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:373
[37] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96
[38] top-level scope
@ ~/Documents/Julia-training-range/flux-raw/flux_bug.jl:6

svilupp added the bug label Nov 24, 2022

ToucheSir (Member) commented

Dupe of FluxML/Zygote.jl#1297. See #2057 for some workarounds, though 99% of the time the way to avoid this is just to call reset! outside of gradient/pullback.
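
For illustration, the "call reset! outside of gradient/pullback" suggestion could look roughly like the sketch below. It reuses the toy model and data from the report and assumes the Flux 0.13.x API with implicit parameters (Flux.params), as in the original snippet. It is a minimal sketch of that one workaround, not the fix discussed in the linked issues.

```julia
using Flux

# Same toy model and data as in the report above.
rnn = Chain(RNN(1, 1, identity), Dense(1, 1, identity))
data = ones(Float32, 1, 4)

# Reset the hidden state *before* differentiating, so that Zygote
# never has to trace through Flux.reset! on the whole chain.
Flux.reset!(rnn)

g = gradient(Flux.params(rnn)) do
    sum(rnn(data))  # mock calculation to return a scalar
end
```

In a training loop, the same idea would mean calling Flux.reset!(rnn) once per sequence, immediately before each gradient call rather than inside the do block.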

ToucheSir closed this as not planned Nov 25, 2022