Skip to content

Commit

Permalink
pass foreach_connected_subsystem to the event condition functions
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Nov 4, 2024
1 parent c696789 commit 778a05a
Showing 1 changed file with 44 additions and 12 deletions.
56 changes: 44 additions & 12 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,25 @@ end
#----------------------------------------------------------

function continuous_condition(out, u, t, integrator)
(;params_partitioned, state_types_val) = integrator.p
(;params_partitioned, state_types_val, connection_matrices) = integrator.p
states_partitioned = to_vec_o_states(u.x, state_types_val)
_continuous_condition!(out, states_partitioned, params_partitioned, t)
_continuous_condition!(out, states_partitioned, params_partitioned, connection_matrices, t)
end

function _continuous_condition!(out,
states_partitioned ::NTuple{Len, Any},
params_partitioned ::NTuple{Len, Any},
connection_matrices,
t) where {Len}

idx = 0
@unroll 16 for i 1:Len
if has_continuous_events(eltype(states_partitioned[i]))
for j eachindex(states_partitioned[i])
idx += 1
out[idx] = continuous_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t)
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
out[idx] = continuous_event_condition(sys, t, F)
end
end
end
Expand Down Expand Up @@ -291,7 +294,8 @@ tany(f, coll; kwargs...) = tmapreduce(f, |, coll; kwargs...)
@nexprs $Len i -> begin
if has_discrete_events(eltype(states_partitioned[i]))
for j eachindex(states_partitioned[i])
discrete_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t) && return true
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
discrete_event_condition(Subsystem(states_partitioned[i][j], params_partitioned[i][j]), t, F) && return true
end
end
end
Expand Down Expand Up @@ -333,8 +337,8 @@ end
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
if discrete_event_condition(sys, t)
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_event_condition(sys, t, F)
if discrete_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_discrete_event!(integrator, sview, pview, sys, F, input)
Expand Down Expand Up @@ -453,11 +457,36 @@ struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs}
end
end

@generated function ((;l,
states_partitioned,
params_partitioned,
connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F}
# @generated function ((;l,
# states_partitioned,
# params_partitioned,
# connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F}
# quote
# @nexprs $Len i -> begin
# @nexprs $NConn nc -> begin
# M = connection_matrices[nc][k, i]
# if M isa NotConnected
# nothing
# else
# for j ∈ eachindex(states_partitioned[i])
# @inbounds conn = M[l, j]
# if !iszero(conn)
# @inbounds states_view_dst = @view states_partitioned[i][j]
# @inbounds params_view_dst = @view params_partitioned[i][j]
# sys_dst = Subsystem(states_view_dst[], params_view_dst[])
# f(conn, sys_dst, states_view_dst, params_view_dst)
# end
# end
# end
# end
# end
# end
# end

@generated function Base.mapreduce(f::F, op::Op, FCS::ForeachConnectedSubsystem{k, Len, NConn}; init) where {k, Len, NConn, F, Op}
quote
(;l, states_partitioned, params_partitioned, connection_matrices) = FCS
state = init
@nexprs $Len i -> begin
@nexprs $NConn nc -> begin
M = connection_matrices[nc][k, i]
Expand All @@ -470,11 +499,14 @@ end
@inbounds states_view_dst = @view states_partitioned[i][j]
@inbounds params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
f(conn, sys_dst, states_view_dst, params_view_dst)
res = f(conn, sys_dst, states_view_dst, params_view_dst)
state = op(state, res)
end
end
end
end
end
end
state
end
end
(FCS::ForeachConnectedSubsystem)(f::F) where {F} = mapreduce(f, (_, _) -> nothing, FCS; init=nothing)

0 comments on commit 778a05a

Please sign in to comment.