Skip to content

Commit

Permalink
Fix array save_idxs
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 22, 2021
1 parent 3a6cad5 commit 7b22040
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,
elseif _save_idxs isa Colon
vec(_out) .= -vec(adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i]))
else
vec(@view(_out[_save_idxs])) .= -vec(adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[_save_idxs, i]))
vec(@view(_out[_save_idxs])) .= -vec(adapt(DiffEqBase.parameterless_type(u0),reshape(Δ, prod(size(Δ)[1:end-1]), size(Δ)[end])[:, i]))
end
end
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream"
@time @safetestset "Concrete Solve Derivatives" begin include("concrete_solve_derivatives.jl") end
@time @safetestset "Branching Derivatives" begin include("branching_derivatives.jl") end
@time @safetestset "Derivative Shapes" begin include("derivative_shapes.jl") end
@time @safetestset "save_idxs" begin include("save_idxs.jl") end
@time @safetestset "ArrayPartitions" begin include("array_partitions.jl") end
@time @safetestset "Complex Adjoints" begin include("complex_adjoints.jl") end
@time @safetestset "Forward Remake" begin include("forward_remake.jl") end
Expand Down
31 changes: 31 additions & 0 deletions test/save_idxs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using OrdinaryDiffEq, DiffEqSensitivity, Zygote, ForwardDiff, Test

function lotka_volterra!(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end

# Initial condition
u0 = [1.0, 1.0]

# Simulation interval and intermediary points
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0

# LV equation parameter. p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]

# Setup the ODE problem, then solve
prob = ODEProblem(lotka_volterra!, u0, tspan, p)

function loss(p)
sol = solve(prob, Tsit5(), p=p, save_idxs=[2], saveat = tsteps, abstol=1e-14, reltol=1e-14)
loss = sum(abs2, sol.-1)
return loss
end

grad1 = Zygote.gradient(loss,p)[1]
grad2 = ForwardDiff.gradient(loss,p)
@test grad1 grad2

0 comments on commit 7b22040

Please sign in to comment.