WIP: Make Enzyme discrete adjoints work #2282
MWE now works:

```julia
using Enzyme, OrdinaryDiffEq, StaticArrays

Enzyme.EnzymeCore.EnzymeRules.inactive_type(::Type{SciMLBase.DEStats}) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_from_initdt!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.fixed_t_for_floatingpoint_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_accept!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_reject!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(DiffEqBase.fastpow), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.increment_nf_perform_step!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.check_error!), args...) = true
Enzyme.EnzymeCore.EnzymeRules.inactive(::typeof(OrdinaryDiffEq.log_step!), args...) = true

function lorenz!(du, u, p, t)
    du[1] = 10.0(u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

const _saveat = SA[0.0,0.25,0.5,0.75,1.0,1.25,1.5,1.75,2.0,2.25,2.5,2.75,3.0]

function f(y::Array{Float64}, u0::Array{Float64})
    tspan = (0.0, 3.0)
    prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz!, u0, tspan)
    sol = DiffEqBase.solve(prob, Tsit5(), saveat = _saveat, sensealg = DiffEqBase.SensitivityADPassThrough())
    y .= sol[1,:]
    return nothing
end;

u0 = [1.0; 0.0; 0.0]
d_u0 = zeros(3)
y = zeros(13)
dy = zeros(13)

Enzyme.autodiff(Reverse, f, Duplicated(y, dy), Duplicated(u0, d_u0));
```

Core issues to finish this:

1. I shouldn't have to pull all of the logging out to a separate function, but there seems to be a bug in Enzyme with int inactivity: EnzymeAD/Enzyme.jl#1636
2. `saveat` has issues because it uses Julia ranges, which can have a floating point fix issue: EnzymeAD/Enzyme.jl#274
3. Adding the `zero(u), zero(u)` is required because Enzyme does not seem to support non-fully initialized types (@wsmoses is that known?) and segfaults when trying to use the uninitialized memory. So making the inner constructor not use `undef` is an easy fix for that, but it's not memory optimal. It would take a bit of a refactor to make it memory optimal, but it's no big deal and it's probably something that improves the package anyway.
@ChrisRackauckas I don't follow, what's the error with non-fully initialized types?
Presumably not known, so open an issue
Julia allows for a type to be partially defined by having an inner type constructor which only puts a value on a subset of arguments. The others are set to
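For readers unfamiliar with incomplete initialization, here is a minimal sketch of the language feature under discussion. The type names `PartialCache` and `FullCache` are hypothetical and not taken from OrdinaryDiffEq; the second type shows the `zero(u)` style of fix mentioned in the PR description.

```julia
# Hypothetical types for illustration only; not part of OrdinaryDiffEq.
mutable struct PartialCache
    u::Vector{Float64}
    fsalfirst::Vector{Float64}
    # Inner constructor that assigns only the first field.
    PartialCache(u) = new(u)
end

c = PartialCache([1.0, 2.0])
u_defined = isdefined(c, :u)             # the assigned field is defined
fsal_defined = isdefined(c, :fsalfirst)  # the skipped field is not

# Reading the skipped field before assigning it throws an UndefRefError.
threw_undef = try
    c.fsalfirst
    false
catch err
    err isa UndefRefError
end

# Fully initialized variant: every field gets a value at construction time,
# at the cost of one extra allocation.
mutable struct FullCache
    u::Vector{Float64}
    fsalfirst::Vector{Float64}
    FullCache(u) = new(u, zero(u))
end

full = FullCache([1.0, 2.0])
```

The fully initialized variant never exposes an undefined reference, which is the property the PR relies on, though it allocates a vector that may later be replaced.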
@@ -65,7 +65,7 @@ end
 q = inv(qmax)
 else
 expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1)
-qtmp = DiffEqBase.fastpow(EEst, expo) / gamma
+qtmp = ^(EEst, expo) / gamma
Shouldn't these have a `@fastmath` (and Float32)?
Yes, I was just seeing what's required in order to make Enzyme work on this.
This was used as a performance optimization early on: it dropped the construction of those two vectors, since we already construct so much in the caches that we might as well reuse one of the cache pointers. Which cache pointer that must be is also built into some methods. So the integrator is made with `undef`s, and the pointers are set during the initialization phase. However, this is unnecessary and adds some complexity. First, it makes the constructor a bit of a mess. Second, it gives Enzyme issues, as demonstrated in #2282. A better solution is to construct the type correctly from the start. To do this, we simply need to refactor the information about which vectors correspond to the FSAL first and last into a function that is per-cache, and use that function in the integrator construction. That's already done in this PR. All that's required to complete this PR is to ensure this refactor is done on every method.
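A rough sketch of the refactor described above, under stated assumptions: every name here (`ToyFSALCache`, `ToyPlainCache`, `fsal_first_last`, `ToyIntegrator`) is hypothetical and chosen for illustration, and the real caches and integrator carry far more state. The idea is that each cache type reports which of its vectors act as the FSAL first and last, so the integrator can be built fully initialized.

```julia
# All names are hypothetical; this is not the OrdinaryDiffEq implementation.
abstract type AbstractToyCache end

# A cache whose first and last stage vectors double as the FSAL vectors.
struct ToyFSALCache <: AbstractToyCache
    k1::Vector{Float64}
    k7::Vector{Float64}
end

# A cache with no natural FSAL vectors.
struct ToyPlainCache <: AbstractToyCache
    tmp::Vector{Float64}
end

# Per-cache function: return the vectors to use as (fsalfirst, fsallast).
fsal_first_last(cache::ToyFSALCache, u) = (cache.k1, cache.k7)
# Fallback that allocates, mirroring the `zero(u), zero(u)` fix.
fsal_first_last(cache::ToyPlainCache, u) = (zero(u), zero(u))

struct ToyIntegrator{C <: AbstractToyCache}
    u::Vector{Float64}
    cache::C
    fsalfirst::Vector{Float64}
    fsallast::Vector{Float64}
end

# Every field is assigned at construction, so no `undef` fields are needed.
function ToyIntegrator(u::Vector{Float64}, cache::AbstractToyCache)
    fsalfirst, fsallast = fsal_first_last(cache, u)
    return ToyIntegrator(u, cache, fsalfirst, fsallast)
end

u = [1.0, 0.0, 0.0]
fsal_cache = ToyFSALCache(zero(u), zero(u))
integ = ToyIntegrator(u, fsal_cache)
plain = ToyIntegrator(u, ToyPlainCache(zero(u)))
```

In this sketch the FSAL cache keeps the memory reuse of the original optimization, because the integrator fields alias the cache's own vectors, while caches without such vectors fall back to fresh allocations.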
All of the pieces are in now. Enzyme now works on the explicit solvers!