Skip to content

Commit

Permalink
Support ForeachConnectedSubsystem (#9)
Browse files Browse the repository at this point in the history
* remove unneeded `connection_index` function

* add SubsystemStatesView for the 0D view of a single SubsystemStates

* add ForeachConnectedSubsystem for effects modifying downstream Subsystems

* remove comment

* bump version
  • Loading branch information
MasonProtter authored Nov 4, 2024
1 parent f3e3d30 commit c696789
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GraphDynamics"
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
version = "0.1.5"
version = "0.2.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
20 changes: 2 additions & 18 deletions src/GraphDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
isstochastic,

event_times,
connection_index
ForeachConnectedSubsystem
)

export
Expand Down Expand Up @@ -232,6 +232,7 @@ add methods to this function if a subsystem or connection type has a discrete ev
event_times(::Any) = ()

abstract type ConnectionRule end
Base.zero(::T) where {T <: ConnectionRule} = zero(T)
struct NotConnected <: ConnectionRule end
(::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r)))
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}}
Expand All @@ -245,23 +246,6 @@ Base.getindex(m::ConnectionMatrices, i) = m.matrices[i]
Base.length(m::ConnectionMatrices) = length(m.matrices)
Base.size(m::ConnectionMatrix{N}) where {N} = (N, N)

"""
connection_index(ConnType, M::ConnectionMatrices)
give the first index `n` such that `M[n]` is a `ConnectionMatrix{N, ConnType} where {N}`, or throw an error if no such index exists.
"""
connection_index(::Type{ConnType}, M::ConnectionMatrices) where {ConnType} = _conn_index(ConnType, M.matrices, 1)
function _conn_index(::Type{ConnType}, tup::Tuple, i) where {ConnType}
if first(tup) isa ConnectionMatrix{N, ConnType} where {N}
return i
else
_conn_index(ConnType, Base.tail(tup), i+1)
end
end
@noinline _conn_index(::Type{ConnType}, ::Tuple{}, _) where {ConnType} =
error("ConnectionMatrices did not contain a ConnectionMatrix with connection type ", ConnType)


abstract type GraphSystem end

@kwdef struct ODEGraphSystem{CM <: ConnectionMatrices, S, P, EVT, CDEP, CCEP, Ns, SNM, PNM} <: GraphSystem
Expand Down
103 changes: 92 additions & 11 deletions src/graph_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ function _continuous_affect!(integrator,
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if continuous_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_continuous_event!(integrator, sview, pview, sys, input)
apply_continuous_event!(integrator, sview, pview, sys, F, input)
else
apply_continuous_event!(integrator, sview, pview, sys)
apply_continuous_event!(integrator, sview, pview, sys, F)
end
end
offset += N
Expand Down Expand Up @@ -326,34 +327,37 @@ end
t) where {Len, NConn}
quote
@nexprs $Len i -> begin
# First we apply events to the states
if has_discrete_events(eltype(states_partitioned[i]))
for j eachindex(states_partitioned[i])
sys_dst = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview_dst = @view states_partitioned[i][j]
pview_dst = @view params_partitioned[i][j]
if discrete_event_condition(sys_dst, t)
if discrete_events_require_inputs(sys_dst)
@inbounds for j eachindex(states_partitioned[i])
sys = Subsystem(states_partitioned[i][j], params_partitioned[i][j])
sview = @view states_partitioned[i][j]
pview = @view params_partitioned[i][j]
if discrete_event_condition(sys, t)
F = ForeachConnectedSubsystem{i}(j, states_partitioned, params_partitioned, connection_matrices)
if discrete_events_require_inputs(sys)
input = calculate_inputs(Val(i), j, states_partitioned, params_partitioned, connection_matrices)
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst, input)
apply_discrete_event!(integrator, sview, pview, sys, F, input)
else
apply_discrete_event!(integrator, sview_dst, pview_dst, sys_dst)
apply_discrete_event!(integrator, sview, pview, sys, F)
end
end
end
end
# Then we do the connection events
@nexprs $NConn nc -> begin
@nexprs $Len k -> begin
f = _discrete_connection_affect!(Val(i), Val(k), Val(nc), t,
states_partitioned, params_partitioned, connection_matrices,
integrator)
foreach(f, eachindex(states_partitioned[i]))

end
end
end
end
end


function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
Expand Down Expand Up @@ -397,3 +401,80 @@ function _discrete_connection_affect!(::Val{i}, ::Val{k}, ::Val{nc}, t,
end
end
end


#-----------------------------------------------------------------------

"""
ForeachConnectedSubsystem
This is a callable struct which takes in a function, and then calls that function on each subsystem which has a connection leading to it
from some previously specified subsystem.
That is, writing
```julia
F = ForeachConnectedSubsystem{k}(l, states_partitioned, params_partitioned, connection_matrices)
F() do conn, sys_dst, states_view_dst, params_view_dst
[...]
end
```
is like a type stable version of writing
```
for i in eachindex(states_partitioned)
for nc in eachindex(connection_matrices)
M = connection_matrices[nc][i, k]
for j in eachindex(states_partitioned[k])
conn = M[l, j]
if !iszero(conn)
states_view_dst = @view states_partitioned[i][j]
params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
[...] # <------- User code here
ends
end
end
end
```
"""
struct ForeachConnectedSubsystem{k, Len, NConn, S, P, CMs}
l::Int
states_partitioned::S
params_partitioned::P
connection_matrices::CMs
function ForeachConnectedSubsystem{k}(l,
states_partitioned::NTuple{Len, Any},
params_partitioned::NTuple{Len, Any},
connection_matrices::ConnectionMatrices{NConn}) where {k, Len, NConn}
S = typeof(states_partitioned)
P = typeof(params_partitioned)
CMs = typeof(connection_matrices)
new{k, Len, NConn, S, P, CMs}(l, states_partitioned, params_partitioned, connection_matrices)
end
end

@generated function ((;l,
states_partitioned,
params_partitioned,
connection_matrices)::ForeachConnectedSubsystem{k, Len, NConn})(f::F) where {k, Len, NConn, F}
quote
@nexprs $Len i -> begin
@nexprs $NConn nc -> begin
M = connection_matrices[nc][k, i]
if M isa NotConnected
nothing
else
for j eachindex(states_partitioned[i])
@inbounds conn = M[l, j]
if !iszero(conn)
@inbounds states_view_dst = @view states_partitioned[i][j]
@inbounds params_view_dst = @view params_partitioned[i][j]
sys_dst = Subsystem(states_view_dst[], params_view_dst[])
f(conn, sys_dst, states_view_dst, params_view_dst)
end
end
end
end
end
end
end
55 changes: 47 additions & 8 deletions src/subsystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,15 @@ Base.eltype(::Type{<:Subsystem{<:Any, T}}) where {T} = T

#-------------------------------------------------------------------------

@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes}
state_types = StateTypes.parameters
Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i 1:Len)...)
end

struct VectorOfSubsystemStates{States, Mat <: AbstractMatrix} <: AbstractVector{States}
data::Mat
end
function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(
v::AbstractMatrix{U}
) where {Name, T, U, snames, Tup}
function VectorOfSubsystemStates{SubsystemStates{Name, T, NamedTuple{snames, Tup}}}(v::AbstractMatrix{U}) where {Name, T, U, snames, Tup}
V = promote_type(T,U)
States = SubsystemStates{Name, V, NamedTuple{snames, NTuple{length(snames), V}}}
VectorOfSubsystemStates{States, typeof(v)}(v)
Expand All @@ -217,8 +220,8 @@ Base.size(v::VectorOfSubsystemStates{States}) where {States} = (size(v.data, 2),
@inbounds States(view(v.data, 1:l, idx))
end

@noinline function sym_not_found_error(::Type{SubsystemStates{Name, T, NamedTuple{names}}}, s::Symbol) where {Name, T, names}
error("SubsystemStates{$Name} does not have a field $s, valid fields are $names")
@noinline function sym_not_found_error(::Type{S}, s::Symbol) where {S<:SubsystemStates}
error("$S does not have a field $s")
end

@propagate_inbounds function Base.getindex(v::VectorOfSubsystemStates{States}, s::Symbol, idx::Integer) where {States <: SubsystemStates}
Expand Down Expand Up @@ -247,7 +250,43 @@ end
v.data[i, idx] = val
end

@generated function to_vec_o_states(state_data::NTuple{Len, Any}, ::Val{StateTypes}) where {Len, StateTypes}
state_types = StateTypes.parameters
Expr(:tuple, (:(VectorOfSubsystemStates{$(state_types[i])}(state_data[$i])) for i 1:Len)...)


#-------------------------------------------------------------------------
struct SubsystemStatesView{States, Mat <: AbstractMatrix} <: AbstractArray{States, 0}
data::Mat
idx::Int
end
@propagate_inbounds function Base.view(v::VectorOfSubsystemStates{States, Mat}, idx::Int) where {States, Mat}
l = length(States)
@boundscheck checkbounds(v.data, 1:l, idx)
SubsystemStatesView{States, Mat}(v.data, idx)
end
Base.size(::SubsystemStatesView) = ()
function Base.getindex(v::SubsystemStatesView{States}) where {States <: SubsystemStates}
l = length(States)
@inbounds States(view(v.data, 1:l, v.idx))
end
@propagate_inbounds function Base.getindex(v::SubsystemStatesView{States}, s::Symbol) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
sym_not_found_error(States, s)
end
@inbounds v.data[i, v.idx]
end

@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, state::States) where {States <: SubsystemStates}
l = length(States)
idx = v.idx
@inbounds v.data[1:l, idx] .= Tuple(state)
v
end

@propagate_inbounds function Base.setindex!(v::SubsystemStatesView{States}, val, s::Symbol) where {States <: SubsystemStates}
i = state_ind(States, s)
if isnothing(i)
sym_not_found_error(States, s)
end
@inbounds v.data[i, v.idx] = val
v
end
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ valueof(x) = x

# this just makes it so that I can easily replace all uses of `@inbounds ex` with just `ex`.
macro inbounds(ex)
# ex
#esc(ex)
esc(:($Base.@inbounds $ex))
end

Expand Down

2 comments on commit c696789

@MasonProtter
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118674

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" c696789b6c53e650e5cf7d17cec396b85ecf5777
git push origin v0.2.0

Please sign in to comment.