ForwardDiff + destructure is different from Zygote, on a model with BatchNorm #2122
Comments
Can reproduce. I note that without the `BatchNorm` layer the two agree, and that forcing train mode (e.g. inserting a `trainmode!(model)` call) instead fails with:

```
ERROR: MethodError: no method matching Float32(::ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12})
Stacktrace:
...
[12] _track_stats!(bn::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, μ::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, σ²::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}, reduce_dims::Vector{Int64})
@ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:278
[13] _norm_layer_forward(l::BatchNorm{typeof(relu), Vector{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}}, Float32, Vector{Float32}}, x::Array{ForwardDiff.Dual{ForwardDiff.Tag{var"#19#20", Float32}, Float32, 12}, 4}; reduce_dims::Vector{Int64}, affine_shape::NTuple{4, Int64})
@ Flux ~/.julia/packages/Flux/nJ0IB/src/layers/normalise.jl:253
...
```
For this last error, we'd need some way to pull the value out of the `Dual` numbers, so that `_track_stats!` stores plain `Float32`s.
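A minimal sketch of such an extraction (the helper name here is hypothetical; `ForwardDiff.value` is the real accessor for the primal part):

```julia
using ForwardDiff: Dual, value

# Hypothetical helper: drop the derivative part before writing into the
# running statistics, which are state rather than something to differentiate.
_strip_dual(x::Real) = x
_strip_dual(x::Dual) = value(x)
_strip_dual(x::AbstractArray{<:Dual}) = map(value, x)
```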
I think we can harmlessly just insert such a value extraction there. For automatic train-mode, if we do something like FluxML/NNlib.jl#434 then we can have a method for `AbstractArray{<:Dual}`. But I don't know what package ought to own it.
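Roughly, the mechanism could look like this (a sketch assuming FluxML/NNlib.jl#434 adds a `within_gradient`-style query; all names here are illustrative):

```julia
using ForwardDiff: Dual
import ChainRulesCore
using ChainRulesCore: NoTangent

# Returns `false` in ordinary code, so layers default to test mode:
within_gradient(x) = false

# Reverse-mode ADs that consume ChainRules see `true` via this rrule:
function ChainRulesCore.rrule(::typeof(within_gradient), x)
    pullback(_) = (NoTangent(), NoTangent())
    return true, pullback
end

# The method discussed above: detect ForwardDiff by its Dual element type.
within_gradient(x::AbstractArray{<:Dual}) = true
```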
So one fly in the ointment is that I was hoping to move …
One path would be to use Requires in NNlib to get non-ChainRules ADs to conform to FluxML/NNlib.jl#434. Another would be adding it to AbstractDifferentiation.jl, which already uses Requires for ForwardDiff, ReverseDiff and Tracker. Any other ideas I had (e.g. splitting off Dual numbers from ForwardDiff and having NNlib define methods on them) feel too far off to be feasible.
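For the first path, the Requires route might look roughly like this inside NNlib (a sketch; the UUID is ForwardDiff's registered package UUID, the rest is illustrative):

```julia
# In NNlib's top-level module:
using Requires

function __init__()
    @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
        # Only defined once the user also loads ForwardDiff:
        within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
    end
end
```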
I had to check, but NNlib is much lighter than ForwardDiff, even if the latter moves to StaticArraysCore. But it does load Requires, so that might be fine.
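For reference, such load-time comparisons can be made in fresh Julia sessions along these lines (timings are machine-dependent and omitted here):

```julia
# Each in a fresh `julia` session:
@time using NNlib        # comparatively light
@time using ForwardDiff  # heavier, even after its StaticArraysCore move
@time using Requires     # small, and NNlib pulls it in already
```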
AbstractDiff is pretty similar. Let me file an issue over there and see how it goes. We can always look into the NNlib option in parallel.
Besides detecting whether you are within AD (also an issue for dropout), the problem with BatchNorm is that ForwardDiff runs the forward pass several times (chunked mode, for any large array). I can't think of a good way to detect that. Perhaps we should make it a clear error instead?
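A self-contained illustration of that chunking behaviour (counts depend on ForwardDiff's chunk-size heuristic, not on this issue's model):

```julia
using ForwardDiff

const ncalls = Ref(0)
f(v) = (ncalls[] += 1; sum(abs2, v))

ForwardDiff.gradient(f, rand(Float32, 30))

ncalls[]  # > 1: one forward evaluation per chunk of the input, so any
          # mutation inside f (like BatchNorm's stats) happens repeatedly
```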
Package Version
Optimisers v0.2.10, ForwardDiff v0.10.33, Flux v0.13.7
Julia Version
1.8
OS / Environment
Describe the bug
Running the example from the Optimisers.jl docs with ForwardDiff unfortunately gives a different gradient than Zygote does.
Steps to Reproduce
Run the `destructure` example from the Optimisers.jl docs on a model containing `BatchNorm`, taking the gradient first with ForwardDiff and then with Zygote, and compare.
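A minimal sketch of such a comparison (the model, input shape and loss are illustrative assumptions consistent with the stack trace above, not the reporter's exact code):

```julia
using Flux, Optimisers, ForwardDiff, Zygote

# Illustrative model and data: a small Conv + BatchNorm chain on a WHCN batch.
model = Chain(Conv((3, 3), 1 => 2), BatchNorm(2, relu))
x = rand(Float32, 8, 8, 1, 4)

flat, re = Optimisers.destructure(model)
loss(v) = sum(abs2, re(v)(x))

g_fd = ForwardDiff.gradient(loss, flat)  # forward mode, chunked passes
g_zy = Zygote.gradient(loss, flat)[1]    # reverse mode, one pass

g_fd ≈ g_zy  # expected true; observed false on a model with BatchNorm
```

Under Zygote the `BatchNorm` automatically switches to train mode, while ForwardDiff's forward pass is not detected as being inside AD, so the layer runs in test mode; this matches the train-mode discussion in the comments above.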
Expected Results
The gradients from ForwardDiff and Zygote agree.
Observed Results
The gradients differ.
Relevant log output
No response