Skip to content

Commit

Permalink
Merge pull request #176 from staticfloat/sf/independently_linearized
Browse files Browse the repository at this point in the history
Rework `LinearizingSavingCallback`
  • Loading branch information
ChrisRackauckas authored Dec 18, 2023
2 parents 4abfc7f + 98db71d commit e434aae
Show file tree
Hide file tree
Showing 7 changed files with 658 additions and 183 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ LinearAlgebra = "1"
Markdown = "1"
NLsolve = "4.2"
ODEProblemLibrary = "0.1"
OrdinaryDiffEq = "6.14"
OrdinaryDiffEq = "6.63"
Parameters = "0.12"
QuadGK = "2"
RecipesBase = "0.7, 0.8, 1.0"
Expand Down
1 change: 1 addition & 0 deletions src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include("manifold.jl")
include("domain.jl")
include("stepsizelimiters.jl")
include("function_caller.jl")
include("independentlylinearizedutils.jl")
include("saving.jl")
include("integrating.jl")
include("integrating_sum.jl")
Expand Down
311 changes: 311 additions & 0 deletions src/independentlylinearizedutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
using SciMLBase

export IndependentlyLinearizedSolution

"""
IndependentlyLinearizedSolutionChunks
When constructing an `IndependentlyLinearizedSolution` via the `IndependentlyLinearizingCallback`,
we use this indermediate structure to reduce allocations and collect the unknown number of timesteps
that the solve will generate.
"""
mutable struct IndependentlyLinearizedSolutionChunks{T, S}
t_chunks::Vector{Vector{T}}
u_chunks::Vector{Vector{Vector{S}}}
time_masks::Vector{BitMatrix}

chunk_size::Int

# Index of next write into the last chunk
u_offsets::Vector{Int}
t_offset::Int

function IndependentlyLinearizedSolutionChunks{T, S}(num_us::Int,
chunk_size::Int = 100) where {T, S}
return new([Vector{T}(undef, chunk_size)],
[[Vector{S}(undef, chunk_size)] for _ in 1:num_us],
[BitMatrix(undef, chunk_size, num_us)],
chunk_size,
[1 for _ in 1:num_us],
1,
)
end
end

function Base.isempty(ilsc::IndependentlyLinearizedSolutionChunks)
return length(ilsc.t_chunks) == 1 && ilsc.t_offset == 1
end

function get_chunks(ilsc::IndependentlyLinearizedSolutionChunks{T, S}) where {T, S}
# Check if we need to allocate new `t` chunk
if ilsc.t_offset > ilsc.chunk_size
push!(ilsc.t_chunks, Vector{T}(undef, ilsc.chunk_size))
push!(ilsc.time_masks, BitMatrix(undef, ilsc.chunk_size, length(ilsc.u_offsets)))
ilsc.t_offset = 1
end

# Check if we need to allocate any new `u` chunks (but only for those with `u_mask`)
for (u_idx, u_chunks) in enumerate(ilsc.u_chunks)
if ilsc.u_offsets[u_idx] > ilsc.chunk_size
push!(u_chunks, Vector{S}(undef, ilsc.chunk_size))
ilsc.u_offsets[u_idx] = 1
end
end

# return the last chunk for each
return ilsc.t_chunks[end],
ilsc.time_masks[end],
[u_chunks[end] for u_chunks in ilsc.u_chunks]
end

function get_prev_t(ilsc::IndependentlyLinearizedSolutionChunks)
# Are we set to write into the beginning of a new chunk?
if ilsc.t_offset == 1
# Try to reach back to the previous chunk
if length(ilsc.t_chunks) == 1
# If this is the absolute first timepoint, just return -Inf
return -Inf
end
# Otherwise, return the last element of the previous chunk
return ilsc.t_chunks[end-1][end]
end
# Otherwise return the previous element of the current chunk
return ilsc.t_chunks[end][ilsc.t_offset-1]
end

"""
store!(ilsc::IndependentlyLinearizedSolutionChunks, t, u, u_mask)
Store a new `u` vector into our `ilsc`, but only the values identified by the
given `u_mask`.
"""
function store!(ilsc::IndependentlyLinearizedSolutionChunks{T, S},
t::T,
u::AbstractVector{S},
u_mask::BitVector) where {T, S}
# If `t` has been stored before, drop it.
# We don't store duplicate timepoints, even though the solver sometimes does.
if get_prev_t(ilsc) == t
return
end

# Otherwise, let's get new chunks to deal with
ts, time_mask, us = get_chunks(ilsc)

# Store into the chunks, gated by `u_mask`
for u_idx in 1:length(us)
if u_mask[u_idx]
us[u_idx][ilsc.u_offsets[u_idx]] = u[u_idx]
ilsc.u_offsets[u_idx] += 1
end
end
ts[ilsc.t_offset] = t
time_mask[ilsc.t_offset, :] .= u_mask
ilsc.t_offset += 1
end



"""
IndependentlyLinearizedSolution
Efficient datastructure that holds a set of independently linearized solutions
(obtained via the `LinearizingSavingCallback`) with related, but slightly
different time vectors. Stores a single time vector with a packed `BitMatrix`
denoting which `u` vectors are sampled at which timepoints. Provides an
efficient `iterate()` method that can be used to reconstruct coherent views
of the state variables at all timepoints, as well as an efficient `sample!()`
method that can sample at arbitrary timesteps.
"""
mutable struct IndependentlyLinearizedSolution{T, S}
# All timepoints, shared by all `us`
ts::Vector{T}

# Ragged matrix of `us`
us::Vector{Vector{S}}

# Bitmatrix denoting which time indices are used for which us.
time_mask::BitMatrix

# Temporary object used during construction, will be set to `nothing` at the end.
ilsc::Union{Nothing,IndependentlyLinearizedSolutionChunks{T,S}}
end
# Helper function to create an ILS wrapped around an in-progress ILSC
function IndependentlyLinearizedSolution(ilsc::IndependentlyLinearizedSolutionChunks{T,S}) where {T,S}
ils = IndependentlyLinearizedSolution(
T[],
Vector{S}[],
BitMatrix(undef, 0,0),
ilsc,
)
return ils
end
# Automatically create an ILS wrapped around an ILSC from a `prob`
function IndependentlyLinearizedSolution(prob::SciMLBase.AbstractDEProblem)
return IndependentlyLinearizedSolution(
IndependentlyLinearizedSolutionChunks{eltype(prob.tspan),eltype(prob.u0)}(length(prob.u0))
)
end

Base.size(ils::IndependentlyLinearizedSolution) = size(ils.time_mask)
Base.length(ils::IndependentlyLinearizedSolution) = length(ils.ts)

function finish!(ils::IndependentlyLinearizedSolution)
function trim_chunk(chunks::Vector, offset)
chunks = [chunk for chunk in chunks]
if eltype(chunks) <: Vector
chunks[end] = chunks[end][1:(offset - 1)]
elseif eltype(chunks) <: BitMatrix
chunks[end] = chunks[end][1:(offset - 1), :]
else
error(eltype(chunks))
end
if isempty(chunks[end])
pop!(chunks)
end
return chunks
end

ilsc = ils.ilsc::IndependentlyLinearizedSolutionChunks
ts = vcat(trim_chunk(ilsc.t_chunks, ilsc.t_offset)...)
time_mask = vcat(trim_chunk(ilsc.time_masks, ilsc.t_offset)...)
us = [vcat(trim_chunk(ilsc.u_chunks[u_idx], ilsc.u_offsets[u_idx])...)
for u_idx in 1:length(ilsc.u_chunks)]

# Sanity-check lengths
if length(ts) != size(time_mask, 1)
throw(ArgumentError("`length(ts)` must equal `size(time_mask, 1)`!"))
end

# All time masks must start and end with `1`:
if !all(@view time_mask[1, :]) || !all(@view time_mask[end, :])
throw(ArgumentError("Time mask must start and end with 1s!"))
end

# Length of time mask enable vectors must equal the lengths of our `us`:
time_mask_lens = vec(sum(time_mask; dims = 1))
if !all(time_mask_lens .== length.(us))
throw(ArgumentError("Time mask must indicate same length as `us` ($(time_mask_lens) != $(length.(us)))"))
end

# Update our struct, release the `ilsc`
ils.ilsc = nothing
ils.ts = ts
ils.us = us
ils.time_mask = time_mask
return ils
end

struct ILSStateCursor
# The index into `us` that identifies this state
u_idx::Int

idx_u₀::Int
# idx_u₁ is by definition idx_u₀ + 1, so we don't store it

# Time index of u₀
idx_t₀::Int
# Time index of u₁
idx_t₁::Int
end
# Helper to construct a state cursor off of an ILS, at a particular time index
function ILSStateCursor(ils::IndependentlyLinearizedSolution, u_idx::Int, t_idx::Int = 1)
cursor = ILSStateCursor(u_idx,
1,
1,
findfirst(@view ils.time_mask[2:end, u_idx]) + 1)
return seek_forward(ils, cursor, ils.ts[t_idx])
end
function interpolate(ils::IndependentlyLinearizedSolution{T},
cursor::ILSStateCursor,
t::T) where {T}
u₀ = ils.us[cursor.u_idx][cursor.idx_u₀]
u₁ = ils.us[cursor.u_idx][cursor.idx_u₀ + 1]
t₀ = ils.ts[cursor.idx_t₀]
t₁ = ils.ts[cursor.idx_t₁]
return (u₁ - u₀) / (t₁ - t₀) * (t - t₀) + u₀
end

"""
seek_forward(ils::IndependentlyLinearizedSolution, cursor::ILSStateCursor, t_target)
Seek the given `cursor` forward until it contains `t_target`. Does not seek backward, use `seek()`
for the more general formulation, this form is optimized for the inner loop of `iterate()`.
"""
function seek_forward(ils::IndependentlyLinearizedSolution{T},
cursor::ILSStateCursor,
t_target::T) where {T}
# We do not test `t_start` because we don't support seeking backward here
while ils.ts[cursor.idx_t₁] < t_target
next_t = findfirst(@view ils.time_mask[(cursor.idx_t₁ + 1):end, cursor.u_idx])
cursor = ILSStateCursor(cursor.u_idx,
cursor.idx_u₀ + 1,
cursor.idx_t₁,
next_t + cursor.idx_t₁)
end
return cursor
end

function Base.seek(ils::IndependentlyLinearizedSolution{T},
cursor::ILSStateCursor,
t_target::T) where {T}
# If we need to rewind, just start from the beginning
if t_target < ils.ts[cursor.idx_t₀]
cursor = ILSStateCursor(ils, cursor.u_idx)
end
return seek_forward(ils, cursor, t_target)
end

function Base.seek(ils::IndependentlyLinearizedSolution, t_idx = 1)
return [ILSStateCursor(ils, u_idx, t_idx) for u_idx in 1:length(ils.us)]
end
function iteration_state(ils::IndependentlyLinearizedSolution{T, S}, t_idx=1) where {T, S}
# Nice little hack so we don't have to allocate `u` over and over again
u = S[S(0) for _ in ils.us]
cursors = seek(ils, t_idx)
return (t_idx, u, cursors)
end
function Base.iterate(ils::IndependentlyLinearizedSolution{T, S},
(t_idx, u, cursors) = iteration_state(ils)) where {T, S}
if t_idx > length(ils.ts)
return nothing
end

# We iteratively inch `offsets` forward, efficiently reconstructing a full set of `u`'s
t = ils.ts[t_idx]
for u_idx in 1:length(u)
cursors[u_idx] = seek_forward(ils, cursors[u_idx], t)
u[u_idx] = interpolate(ils, cursors[u_idx], t)
end

return (t, u), (t_idx + 1, u, cursors)
end

"""
sample!(out::Matrix{S}, ils::IndependentlyLinearizedSolution, ts::Vector{T})
Batch-sample `ils` at the given timepoints, storing into `out`.
"""
function sample!(out::Matrix{S},
ils::IndependentlyLinearizedSolution{T, S},
ts::AbstractVector{T}) where {T, S}
sampled_size = (length(ts), length(ils.us))
if size(out) != sampled_size
throw(ArgumentError("Output size ($(size(out))) != sampled size ($(sampled_size))"))
end

# We don't make use of `iterate` here because we're sampling at arbitrary timepoints
cursors = seek(ils)
for (t_idx, t) in enumerate(ts)
for u_idx in 1:length(ils.us)
cursors[u_idx] = seek_forward(ils, cursors[u_idx], t)
out[t_idx, u_idx] = interpolate(ils, cursors[u_idx], t)
end
end
return out
end
function sample(ils::IndependentlyLinearizedSolution{T, S},
ts::AbstractVector{T}) where {T, S}
out = Matrix{S}(undef, length(ts), length(ils.us))
return sample!(out, ils, ts)
end
Loading

0 comments on commit e434aae

Please sign in to comment.