Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reopen #390: Update callback #440

Merged
merged 16 commits into from
May 6, 2024
Merged
2 changes: 1 addition & 1 deletion src/TrixiParticles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export InitialCondition
export WeaklyCompressibleSPHSystem, EntropicallyDampedSPHSystem, TotalLagrangianSPHSystem,
BoundarySPHSystem, DEMSystem, BoundaryDEMSystem
export InfoCallback, SolutionSavingCallback, DensityReinitializationCallback,
PostprocessCallback, StepsizeCallback
PostprocessCallback, StepsizeCallback, UpdateCallback
export ContinuityDensity, SummationDensity
export PenaltyForceGanzenmueller
export SchoenbergCubicSplineKernel, SchoenbergQuarticSplineKernel,
Expand Down
1 change: 1 addition & 0 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ include("solution_saving.jl")
include("density_reinit.jl")
include("post_process.jl")
include("stepsize.jl")
include("update.jl")
130 changes: 130 additions & 0 deletions src/callbacks/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
struct UpdateCallback{I}
interval::I
end

"""
UpdateCallback(; interval::Integer, dt=0.0)

Callback to update quantities either at the end of every `interval` time steps or
in intervals of `dt` in terms of integration time by adding additional `tstops`
(note that this may change the solution).

# Keywords
- `interval=1`: Update quantities at the end of every `interval` time steps.
- `dt`: Update quantities in regular intervals of `dt` in terms of integration time
by adding additional `tstops` (note that this may change the solution).
"""
function UpdateCallback(; interval::Integer=-1, dt=0.0)
if dt > 0 && interval !== -1
throw(ArgumentError("Setting both interval and dt is not supported!"))
end

# Update in intervals in terms of simulation time
if dt > 0
interval = Float64(dt)

# Update every time step (default)
elseif interval == -1
interval = 1
end

update_callback! = UpdateCallback(interval)

if dt > 0
# Add a `tstop` every `dt`, and save the final solution.
return PeriodicCallback(update_callback!, dt,
initialize=initial_update!,
save_positions=(false, false))
else
# The first one is the `condition`, the second the `affect!`
return DiscreteCallback(update_callback!, update_callback!,
initialize=initial_update!,
save_positions=(false, false))
end
end

# `initialize`
function initial_update!(cb, u, t, integrator)
# The `UpdateCallback` is either `cb.affect!` (with `DiscreteCallback`)
# or `cb.affect!.affect!` (with `PeriodicCallback`).
# Let recursive dispatch handle this.

initial_update!(cb.affect!, u, t, integrator)
end

initial_update!(cb::UpdateCallback, u, t, integrator) = cb(integrator)

# `condition`
function (update_callback!::UpdateCallback)(u, t, integrator)
(; interval) = update_callback!

return condition_integrator_interval(integrator, interval)
end

# `affect!`
function (update_callback!::UpdateCallback)(integrator)
t = integrator.t
semi = integrator.p
v_ode, u_ode = integrator.u.x

# Update quantities that are stored in the systems. These quantities (e.g. pressure)
# still have the values from the last stage of the previous step if not updated here.
update_systems_and_nhs(v_ode, u_ode, semi, t)

# Other updates might be added here later (e.g. Transport Velocity Formulation).
# @trixi_timeit timer() "update open boundary" foreach_system(semi) do system
# update_open_boundary_eachstep!(system, v_ode, u_ode, semi, t)
# end
#
# @trixi_timeit timer() "update TVF" foreach_system(semi) do system
# update_transport_velocity_eachstep!(system, v_ode, u_ode, semi, t)
# end

# Tell OrdinaryDiffEq that `u` has been modified
u_modified!(integrator, true)

return integrator
end

function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:UpdateCallback})
@nospecialize cb # reduce precompilation time
print(io, "UpdateCallback(interval=", cb.affect!.interval, ")")
end

function Base.show(io::IO,
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:UpdateCallback}})
@nospecialize cb # reduce precompilation time
print(io, "UpdateCallback(dt=", cb.affect!.affect!.interval, ")")
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any, <:UpdateCallback})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
update_cb = cb.affect!
setup = [
"interval" => update_cb.interval,
]
summary_box(io, "UpdateCallback", setup)
end
end

function Base.show(io::IO, ::MIME"text/plain",
cb::DiscreteCallback{<:Any,
<:PeriodicCallbackAffect{<:UpdateCallback}})
@nospecialize cb # reduce precompilation time

if get(io, :compact, false)
show(io, cb)
else
update_cb = cb.affect!.affect!
setup = [
"dt" => update_cb.interval,
]
summary_box(io, "UpdateCallback", setup)
end
end
1 change: 1 addition & 0 deletions test/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
include("info.jl")
include("stepsize.jl")
include("postprocess.jl")
include("update.jl")
include("solution_saving.jl")
end
48 changes: 48 additions & 0 deletions test/callbacks/update.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
@testset verbose=true "UpdateCallback" begin
@testset verbose=true "show" begin
# Default
callback0 = UpdateCallback()

show_compact = "UpdateCallback(interval=1)"
@test repr(callback0) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ interval: ……………………………………………………… 1 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback0) == show_box

callback1 = UpdateCallback(interval=11)

show_compact = "UpdateCallback(interval=11)"
@test repr(callback1) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ interval: ……………………………………………………… 11 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback1) == show_box

callback2 = UpdateCallback(dt=1.2)

show_compact = "UpdateCallback(dt=1.2)"
@test repr(callback2) == show_compact

show_box = """
┌──────────────────────────────────────────────────────────────────────────────────────────────────┐
│ UpdateCallback │
│ ══════════════ │
│ dt: ……………………………………………………………………… 1.2 │
└──────────────────────────────────────────────────────────────────────────────────────────────────┘"""
@test repr("text/plain", callback2) == show_box
end

@testset "Illegal Input" begin
error_str = "Setting both interval and dt is not supported!"
@test_throws ArgumentError(error_str) UpdateCallback(dt=0.1, interval=1)
end
end
Loading