-
-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Out of place QuadratureAdjoint for Working with StaticArrays #680
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add a test set.
To format the files, you can run
using JuliaFormatter, SciMLSensitivity
format(joinpath(dirname(pathof(SciMLSensitivity)), ".."))
This doesn't need a whole new dispatch. You just need to handle the mutating cases with a branch, not copy paste the entire code and change like 30 lines. |
This is much better than before. You need to have an alternative |
src/concrete_solve.jl
Outdated
end | ||
end | ||
|
||
function df(u, p, t, i;outtype=nothing) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to make this work, you probably need to change the dgdu accmulation code in adjoint_common inside of the affect!
definition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, so its inferred as a redefinition of the first df, which is causing the tests using the first one to fail
test/adjoint.jl
Outdated
####Fully oop Adjoint | ||
|
||
u0 = @SArray Float32[2.0, 0.0] | ||
datasize = 30 | ||
tspan = (0.0f0, 1.5f0) | ||
tsteps = range(tspan[1], tspan[2], length = datasize) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a StaticArray/SimpleChain regression test for optimization, not an adjoint test. Please add a proper adjoint test. It should probably be its own file adjoint_oop.jl
where it similarly tests the direct and AD interfaces for gradient correctness.
This test should be a separate SimpleChains regression test.
Rebase this onto master? I think this looks generally correct now. It needs to change the tests as commented above, and of course fix tests, but it looks like it's in the right direction. |
We'll need to wait for the SimpleChains branch linked at the top to be merged. The tests are passing on it. SimpleChains master doesn't return SArray in the reverse pass so the reverse ode throws the type not constant error, it's fixed in the above branch |
test/adjoint_oop.jl
Outdated
n_du0 = ForwardDiff.gradient(G_u, u0) | ||
|
||
@test n_du0≈du0 rtol=1e-3 | ||
@test_broken n_dp≈dp' rtol=1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's not good?
Can you add the requested accuracy tests? |
For Lotka Volterra the issue is that making a function like function lotka(u, p, t)
du1 = p[1]*u[1] - p[2]*u[1]*u[2]
du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
SVector(du1, du2)
end doesn't go well with Zygote when calling du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc,
sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) on the solve as constructing an SVector as a return doesn't have an adjoint and throws
But, it has to be made an SVector to keep the type of du same as that of u. So, I had to make a function which returns an SVector as output as a result of the computation u0 = @SVector [1.0f0, 1.0f0]
p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0]
function f(u, p, t)
p * u
end |
Add a fix for that in ChainRules? |
The adjoint should construct an SVector of the same size? |
Yes |
Is this being finished? What's left? |
The StaticArrays AD stuff is almost done here JuliaArrays/StaticArrays.jl#1068 , I hope it will get merged. |
Tested on PumasAI/SimpleChains.jl/pull/97 , so should be merged after it.
NeuralODE example