Out of place QuadratureAdjoint for Working with StaticArrays #680

Merged: 33 commits, Aug 13, 2022
Conversation

ba2tripleO (Contributor)

Tested against PumasAI/SimpleChains.jl#97, so this should be merged after that PR.
NeuralODE example

using SimpleChains, StaticArrays, OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationFlux, Plots

u0 = @SArray Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# True dynamics used to generate the training data
function trueODE(u, p, t)
    true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1]
    ((u.^3)'true_A)'
end

prob = ODEProblem(trueODE, u0, tspan)
data = Array(solve(prob, Tsit5(), saveat = tsteps))

# Small SimpleChains network with static sizes
sc = SimpleChain(
                static(2),
                Activation(x -> x.^3),
                TurboDense{true}(tanh, static(50)),
                TurboDense{true}(identity, static(2))
            )

p_nn = SimpleChains.init_params(sc)

f(u, p, t) = sc(u, p)

prob_nn = ODEProblem(f, u0, tspan)

function predict_neuralode(p)
    # Out-of-place QuadratureAdjoint with a Zygote VJP for the SVector state
    Array(solve(prob_nn, Tsit5(); p = p, saveat = tsteps,
                sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, data .- pred)
    return loss, pred
end

callback = function (p, l, pred; doplot = true)
    display(l)
    plt = scatter(tsteps, data[1,:],label="data")
    scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(plot(plt))
    end
    return false
end

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, p_nn)

res = Optimization.solve(optprob, ADAM(0.05), callback = callback, maxiters = 300)

@frankschae (Member) left a comment:


Please also add a test set.

To format the files, you can run

using JuliaFormatter, SciMLSensitivity
format(joinpath(dirname(pathof(SciMLSensitivity)), ".."))

Resolved review threads (outdated): src/concrete_solve.jl (3), src/derivative_wrappers.jl
@ChrisRackauckas (Member)

This doesn't need a whole new dispatch. You just need to handle the mutating cases with a branch, not copy-paste the entire code and change some 30 lines.
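
A hedged sketch of the suggested pattern (a simplified toy, not the actual concrete_solve code): branch on whether the problem is in-place instead of duplicating the whole method.

using SciMLBase

# Toy only: one code path with a branch on isinplace, rather than a
# second near-identical dispatch for the immutable case.
function apply_rhs(prob, u, p, t)
    if SciMLBase.isinplace(prob)
        du = similar(u)
        prob.f(du, u, p, t)      # mutating form f!(du, u, p, t)
        return du
    else
        return prob.f(u, p, t)   # out-of-place form f(u, p, t)
    end
end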

@ChrisRackauckas (Member)

This is much better than before. You need an alternative dgdu definition for the out-of-place version, one that doesn't take an argument to mutate (out). You should create tests for this using the direct adjoint interface: don't even try the concrete solve / sensealg handling pieces until the direct adjoint definitions work.
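
For reference, a hedged sketch of the two conventions being contrasted (names hypothetical; the real signatures live in the adjoint code):

# In-place convention for mutable states: write the gradient into `out`.
function dgdu_discrete!(out, u, p, t, i)
    out .= 2 .* u   # e.g. gradient of sum(abs2, u)
    return nothing
end

# Out-of-place convention for immutable states such as SVector: no `out`
# argument, the gradient is returned as a new value.
dgdu_discrete(u, p, t, i) = 2 .* u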

end
end

function df(u, p, t, i; outtype = nothing)
(Member)

To make this work, you probably need to change the dgdu accumulation code in adjoint_common inside of the affect! definition.
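
A self-contained toy of the kind of accumulation branch meant here (names hypothetical; the real code is the dg_val accumulation in adjoint_common's affect!):

using StaticArrays

# Mutable path: accumulate into the existing buffer.
accum_grad!(dg_val::AbstractVector, g) = (dg_val .+= g; dg_val)
# Immutable path: no mutation is possible, so return a new value
# and let the caller rebind it.
accum_grad!(dg_val::SVector, g) = dg_val + g

dg_mut = zeros(2)
accum_grad!(dg_mut, [1.0, 2.0])                     # mutates dg_mut in place

dg_imm = @SVector zeros(2)
dg_imm = accum_grad!(dg_imm, @SVector([1.0, 2.0]))  # rebinds dg_imm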

ba2tripleO (Contributor, Author)

Oh, so it's inferred as a redefinition of the first df, which is causing the tests that use the first one to fail.
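
A minimal demonstration of the clash (toy df, not the test's actual one): keyword arguments don't participate in dispatch, so adding one silently overwrites the existing method.

df(u, p, t, i) = u .+ i                      # first definition
df(u, p, t, i; outtype = nothing) = u .- i   # same positional signature,
                                             # so this replaces the method above

df([1.0], nothing, 0.0, 1)   # returns [0.0], not [2.0]: the first df is gone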

Resolved review thread (outdated): src/concrete_solve.jl

test/adjoint.jl (outdated), comment on lines 870 to 876:
#### Fully oop Adjoint

u0 = @SArray Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

(Member)

This is a StaticArrays/SimpleChains regression test for optimization, not an adjoint test. Please add a proper adjoint test. It should probably be its own file, adjoint_oop.jl, which similarly tests the direct and AD interfaces for gradient correctness.

This test should be moved to a separate SimpleChains regression test.
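
A hedged sketch of what such a direct-interface correctness test might look like (tolerances and helper names illustrative; the adjoint is checked against ForwardDiff rather than run through the optimizer loop):

using SciMLSensitivity, OrdinaryDiffEq, StaticArrays, ForwardDiff, Test

u0 = @SVector [1.0f0, 1.0f0]
p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0]
f(u, p, t) = p * u

tsteps = 0.0f0:0.1f0:1.0f0
prob = ODEProblem(f, u0, (0.0f0, 1.0f0), p)
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)

# Gradient of the discrete loss sum(abs2, u(t_i)) with respect to u at each t_i
dg_disc(u, p, t, i) = 2u

du0, dp = adjoint_sensitivities(sol, Tsit5(); t = collect(tsteps),
                                dgdu_discrete = dg_disc,
                                sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))

# Forward-mode reference for the same loss
G(u0) = sum(sum(abs2, u) for u in solve(remake(prob, u0 = u0), Tsit5();
                                        abstol = 1e-10, reltol = 1e-10,
                                        saveat = tsteps).u)
@test ForwardDiff.gradient(G, u0) ≈ du0 rtol = 1e-3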

@ChrisRackauckas (Member)

Rebase this onto master?

I think this looks generally correct now. It needs the test changes commented above, and of course the tests have to pass, but it's headed in the right direction.

@ba2tripleO (Contributor, Author)

We'll need to wait for the SimpleChains branch linked at the top to be merged; the tests pass on it. SimpleChains master doesn't return an SArray in the reverse pass, so the reverse ODE throws the "type not constant" error; that's fixed in the branch above.

n_du0 = ForwardDiff.gradient(G_u, u0)

@test n_du0 ≈ du0 rtol = 1e-3
@test_broken n_dp ≈ dp' rtol = 1e-3
(Member)

That's not good?

@ChrisRackauckas (Member)

Can you add the requested accuracy tests?

@ba2tripleO (Contributor, Author) commented on Jul 20, 2022

For Lotka-Volterra the issue is that writing a function like

function lotka(u, p, t)
    du1 = p[1]*u[1] - p[2]*u[1]*u[2]
    du2 = -p[3]*u[2] + p[4]*u[1]*u[2]
    SVector(du1, du2)
end

doesn't play well with Zygote when calling

du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc,
                                sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))

on the solve, since constructing an SVector as the return value doesn't have an adjoint and throws

ERROR: Need an adjoint for constructor SVector{2, Float32}. Gradient is of type SVector{2, Float32}

But it has to be an SVector to keep the type of du the same as that of u. So I had to write a function whose SVector output falls out of the computation itself:

u0 = @SVector [1.0f0, 1.0f0]
p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0]

function f(u, p, t)
    p * u
end

@ChrisRackauckas (Member)

ERROR: Need an adjoint for constructor SVector{2, Float32}. Gradient is of type SVector{2, Float32}

Add a fix for that in ChainRules?

@ba2tripleO (Contributor, Author)

The adjoint should construct an SVector of the same size?

@ChrisRackauckas (Member)

Yes
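
A rough sketch of what such a ChainRules rule could look like (hypothetical, not the rule that ultimately landed via StaticArrays.jl#1068): the pullback splats the SVector cotangent back into per-argument scalar cotangents.

using ChainRulesCore, StaticArrays

function ChainRulesCore.rrule(::Type{SVector{N, T}}, xs::Vararg{Number, N}) where {N, T}
    y = SVector{N, T}(xs...)
    SVector_pullback(ȳ) = (NoTangent(), Tuple(ȳ)...)
    return y, SVector_pullback
end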

@ChrisRackauckas (Member)

Is this being finished? What's left?

@ba2tripleO (Contributor, Author) commented on Aug 6, 2022

The StaticArrays AD work is almost done in JuliaArrays/StaticArrays.jl#1068; I hope it will get merged.

@ChrisRackauckas (Member)

That error is #707, which is hopefully fixed by #706. I'll rebase after that fix is merged and merge this if tests pass.

@ChrisRackauckas ChrisRackauckas merged commit b4fd1ce into SciML:master Aug 13, 2022
@ba2tripleO ba2tripleO deleted the sarray branch December 26, 2022 05:48