Skip to content

Commit

Permalink
test: use symbolic save_idxs in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 21, 2024
1 parent d74ddf7 commit 56159bc
Showing 1 changed file with 5 additions and 23 deletions.
28 changes: 5 additions & 23 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,13 @@ end
ss1.state_map == ss2.state_map
end

ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]

# FIXME: hack for save_idxs
SciMLBase.@reset ode_sol.saved_subsystem = subsys
ode_sol = solve(prob, Tsit5(); save_idxs = [x])

@mtkbuild sys = ODESystem([D(x) ~ x + p * y, 1 ~ sin(y) + cos(x)], t)
xidx = variable_index(sys, x)
prob = DAEProblem(sys, [D(x) => x + p * y, D(y) => 1 / sqrt(1 - (1 - cos(x))^2)],
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0])
dae_sol = solve(prob, DFBDF(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
# FIXME: hack for save_idxs
SciMLBase.@reset dae_sol.saved_subsystem = subsys
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0]; build_initializeprob = false)
dae_sol = solve(prob, DFBDF(); save_idxs = [x])

@brownian a b
@mtkbuild sys = System([D(x) ~ x + p * y + x * a, D(y) ~ 2p + x^2 + y * b], t)
Expand Down Expand Up @@ -256,21 +248,11 @@ end

@test SciMLBase.SavedSubsystem(sys, prob.p, [x, y, q, r, s, u]) === nothing

sol = solve(prob; save_idxs = xidx)
sol = solve(prob; save_idxs = [x, q, r])
xvals = sol[x]
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
@test SciMLBase.get_saved_state_idxs(sol.saved_subsystem) == [xidx]
qvals = sol.ps[q]
rvals = sol.ps[r]
# FIXME: hack for save_idxs
SciMLBase.@reset sol.saved_subsystem = subsys
discq = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(qvals))),
sol.discretes[qpidx.timeseries_idx].t, (1, 1))
discr = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(rvals))),
sol.discretes[rpidx.timeseries_idx].t, (1, 1))
SciMLBase.@reset sol.discretes.collection[qpidx.timeseries_idx] = discq
SciMLBase.@reset sol.discretes.collection[rpidx.timeseries_idx] = discr

@test sol[x] == xvals

@test all(Base.Fix1(is_parameter, sol), [p, q, r, s, u])
Expand Down

0 comments on commit 56159bc

Please sign in to comment.