Skip to content

Commit

Permalink
Merge pull request #210 from SciML/preset_reset
Browse files Browse the repository at this point in the history
Reset the indices in the PresetTimeCallback
  • Loading branch information
ChrisRackauckas authored Mar 11, 2024
2 parents 2a46de7 + e9ccff4 commit 10e0c24
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/iterative_and_periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ function add_next_tstop!(integrator, S)
tdir
=#
tdir_tnew = integrator.tdir * tnew
index[] += 1
if tdir_tnew < maximum(tstops.valtree)
index[] += 1
add_tstop!(integrator, tnew)
end
end
Expand Down Expand Up @@ -143,7 +143,8 @@ function PeriodicCallback(f, Δt::Number;
index = Ref(0)

condition = function (u, t, integrator)
t == (t0[] + index[] * Δt) || (final_affect && isfinished(integrator))
fin = isfinished(integrator)
(t == (t0[] + index[] * Δt) && !fin) || (final_affect && fin)
end

# Call f, update tnext, and make sure we stop at the new tnext
Expand Down
6 changes: 6 additions & 0 deletions src/preset_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ function PresetTimeCallback(tstops, user_affect!;
# Initialization: first call to `f` should be *before* any time steps have been taken:
initialize_preset = function (c, u, t, integrator)
initialize(c, u, t, integrator)
if tstops isa AbstractVector
search_start, search_end = firstindex(tstops), lastindex(tstops)
else
search_start, search_end = 0, 0
end

if filter_tstops
tdir = integrator.tdir
_tstops = tstops[@.((tdir * tstops >
Expand Down
29 changes: 29 additions & 0 deletions test/periodic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,32 @@ cb = PeriodicCallback(periodic, 11.0, initial_affect = false)
prob = ODEProblem(fff, u0, tspan, p)
sol = solve(prob, Tsit5(), callback = cb)
@test sol.t[end] == tspan[2]

# Fix indexing repeats
# https://github.com/SciML/ModelingToolkit.jl/issues/2528

function lineardecay(du, u, p, t)
du[1] = -u[1]
end

function bumpaffect!(integ)
integ.u[1] += 10
end

cb = PeriodicCallback(bumpaffect!, 24.0)
prob = ODEProblem(lineardecay, [0.0], (0.0, 130.0))
sol1 = solve(prob, Tsit5(), callback = cb)

@test sol1(0.0) == [0.0]
@test sol1(24.0 + eps(24.0)) [10.0]
@test sol1(48.0 + eps(48.0)) [10.0]
@test sol1(72.0 + eps(72.0)) [10.0]
@test sol1(96.0 + eps(96.0)) [10.0]
@test sol1(120.0 + eps(120.0)) [10.0]
sol2 = solve(prob, Tsit5(), callback = cb)
@test sol2(0.0) == [0.0]
@test sol2(24.0 + eps(24.0)) [10.0]
@test sol2(48.0 + eps(48.0)) [10.0]
@test sol2(72.0 + eps(72.0)) [10.0]
@test sol2(96.0 + eps(96.0)) [10.0]
@test sol2(120.0 + eps(120.0)) [10.0]
30 changes: 30 additions & 0 deletions test/preset_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,33 @@ cb = PresetTimeCallback([-0.2], integrator -> begin
end, filter_tstops = false)
sol = solve(prob, Tsit5(), callback = cb)
@test !notcalled

# Test indexes reset
# https://github.com/SciML/DifferentialEquations.jl/issues/1022

function mod(du, u, p, t)
du[1] = -p[1] * u[1]
end

p = [1.0]
u0 = [10.0]
tspan = (0.0, 72.0)

times1 = 0.0:24.0:tspan[2]
times2 = 24.0:24.0:tspan[2]
affect!(integrator) = integrator.u[1] += 10.0
cb1 = PresetTimeCallback(times1, affect!)
cb2 = PresetTimeCallback(times2, affect!)

prob1 = ODEProblem(mod, u0, tspan, p, callback = cb1)
prob2 = ODEProblem(mod, u0, tspan, p)

sol1 = solve(prob1, Tsit5())
sol2 = solve(prob2, Tsit5(), callback = cb1)

@test sol1(0.0) == [10.0]
@test sol1(24.0 + eps(24.0)) [10.0]
@test sol1(48.0 + eps(48.0)) [10.0]
@test sol2(0.0) == [10.0]
@test sol2(24.0 + eps(24.0)) [10.0]
@test sol2(48.0 + eps(48.0)) [10.0]

0 comments on commit 10e0c24

Please sign in to comment.