Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make neurons from Decision Making tutorial much faster with GraphDynamics #484

Merged
merged 8 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ DataFrames = "1.3"
Distributions = "0.25.102"
ExponentialUtilities = "1"
ForwardDiff = "0.10"
GraphDynamics = "0.1.5"
GraphDynamics = "0.2"
Graphs = "1"
Interpolations = "0.14, 0.15"
MetaGraphs = "0.7"
Expand Down
24 changes: 6 additions & 18 deletions src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ using GraphDynamics:
StateIndex,
ParamIndex,
event_times,
calculate_inputs,
connection_index
calculate_inputs

using Random:
Random,
Expand Down Expand Up @@ -256,12 +255,6 @@ function populate_flatgraph(h, g, blox, v, g_i, h_i)
if length(components(blox)) == 1 && only(components(blox)) == blox
h_i += 1
add_subsystem!(h, to_subsystem(blox), Neuroblox.namespaced_nameof(blox))
# add_vertices!(h, 1)
# subsystem = to_subsystem(blox)
# name = Neuroblox.namespaced_nameof(blox)

# set_subsystem!(h, to_subsystem(blox), h_i)
# set_name!(h, name, h_i)
if v isa Dict
@assert !haskey(v, g_i)
v[g_i] = h_i
Expand Down Expand Up @@ -341,14 +334,9 @@ situations where tiny matrices like (e.g. 5x5) get stored as sparse arrays rathe
function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_length_cutoff=0)
check_all_supported_blox(_g)
g = flat_graph(_g)

total_eltype = mapreduce(promote_type, vertices(g)) do i
eltype(get_subsystem(g, i))
end
fix_eltype(s::Subsystem{Name}) where {Name} = convert(Subsystem{Name, total_eltype}, s)


subsystems_and_names_flat = map(vertices(g)) do i
(subsystem = fix_eltype(get_subsystem(g, i)), name = get_name(g, i))
(subsystem = get_subsystem(g, i), name = get_name(g, i))
end
names_flat = map(last, subsystems_and_names_flat)
subsystems_flat = map(first, subsystems_and_names_flat)
Expand Down Expand Up @@ -433,9 +421,9 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_
end
end
end
states_partitioned = map(v -> map(get_states, v), subsystems)
states_partitioned = map(v -> map(get_states, v), subsystems)
params_partitioned = map(v -> map(get_params, v), subsystems)
names_partitioned = map(v -> map(last, v), subsystems_and_names)
names_partitioned = map(v -> map(last, v), subsystems_and_names)

composite_continuous_events_partitioned = let
if isempty(g.composite_continuous_events_builder)
Expand Down Expand Up @@ -479,5 +467,5 @@ function graphsystem_from_graph(_g::MetaDiGraph; sparsity_heuristic=1.0, sparse_
end
end


end#module GraphDynamicsInterop

217 changes: 73 additions & 144 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ end
struct BasicConnection <: ConnectionRule
weight::Float64
end
Base.zero(::BasicConnection) = Base.zero(BasicConnection)
Base.zero(::Type{<:BasicConnection}) = BasicConnection(0.0)
function (c::BasicConnection)(blox_src, blox_dst)
(; jcn = c.weight * output(blox_src))
Expand Down Expand Up @@ -177,7 +178,7 @@ function get_connection(
(;conn, names)
end

struct HHConnection_GAP
struct HHConnection_GAP <: ConnectionRule
w::Float64
w_gap::Float64
w_gap_rev::Float64
Expand Down Expand Up @@ -211,7 +212,7 @@ end


#----------------------------------------------
# Kuramoto
# Kuramoto
function get_connection(src::KuramotoOscillator, dst::KuramotoOscillator, kwargs)
(;w_val, name) = generate_weight_param(src, dst, kwargs)
(;conn=BasicConnection(w_val), names=[name])
Expand All @@ -227,37 +228,6 @@ end
#----------------------------------------------
# LIFExci / LIFInh

function blox_wiring_rule!(h, blox::Union{LIFExciNeuron, LIFInhNeuron}, v, kwargs)
evbs = h.composite_discrete_events_builder
i = only(v)
push!(evbs, SpikeAffectEventBuilder(i, Int[], Int[]))
end


function blox_wiring_rule!(h,
blox_src::Union{LIFExciNeuron, LIFInhNeuron},
blox_dst::Union{LIFExciNeuron, LIFInhNeuron},
v_src, v_dst, kwargs)
#this is the fallback method for non-composite blox, hence vi and vj should have only one element
i, j = only(v_src), only(v_dst)
(; w_val, name) = generate_weight_param(blox_src, blox_dst, kwargs)
conn = BasicConnection(w_val)

let evbs = h.composite_discrete_events_builder
idx = findfirst(evb -> (evb isa SpikeAffectEventBuilder) && (evb.idx_src == i), evbs)
if isnothing(idx)
error("SpikeAffectEventBuilder for neuron not found, this indicates its blox wiring rule never ran.")
else
if blox_dst isa LIFExciNeuron
push!(evbs[idx].idx_dsts_exci, j)
elseif blox_dst isa LIFInhNeuron
push!(evbs[idx].idx_dsts_inh, j)
end
end
end
add_edge!(h, i, j, Dict(:conn => conn, :names => [name]))
end

function (c::BasicConnection)(sys_src::Subsystem{LIFExciNeuron},
sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
w = c.weight
Expand All @@ -271,111 +241,65 @@ function (c::BasicConnection)(::Subsystem{LIFInhNeuron},
(; jcn = 0.0)
end

struct SpikeAffectEventBuilder
idx_src::Int
idx_dsts_inh::Vector{Int}
idx_dsts_exci::Vector{Int}
const LIFExciInhNeuron = Union{LIFExciNeuron, LIFInhNeuron}
GraphDynamics.has_discrete_events(::Type{LIFExciNeuron}) = true
GraphDynamics.has_discrete_events(::Type{LIFInhNeuron}) = true
function GraphDynamics.discrete_event_condition((; t_refract_end, V, θ)::Subsystem{LIF}, t, _) where {LIF <: LIFExciInhNeuron}
# Triggers when either a refractory period is ending, or the neuron spiked (voltage exceeds threshold θ)
(V > θ) || (t_refract_end == t)
end
function GraphDynamics.apply_discrete_event!(integrator,
states_view_src, params_view_src,
neuron_src::Subsystem{LIF},
foreach_connected_neuron) where {LIF <: LIFExciInhNeuron}
t = integrator.t
if t == neuron_src.t_refract_end # Refreactory period is over
params = params_view_src[]
params_view_src[] = @set params.is_refractory = 0
else # Neuron fired
# Begin refractory period
params_src = params_view_src[]
@reset params_src.t_refract_end = t + params_src.t_refract_duration
@reset params_src.is_refractory = 1

add_tstop!(integrator, params_src.t_refract_end)
params_view_src[] = params_src

struct SpikeAffectEvent{i_src, i_LIFInh, i_LIFExci}
j_src::Int
j_dsts_inh::Vector{Int}
j_dsts_exci::Vector{Int}
end
# Reset the neuron voltage
states_view_src[:V] = params_src.V_reset

function (ev::SpikeAffectEventBuilder)(index_map)
(i_src, j_src) = index_map[ev.idx_src]
i_inh, j_dsts_inh = let v = ev.idx_dsts_inh
if isempty(v)
nothing, Int[]
else
index_map[first(v)][1], map(idx -> index_map[idx][2], v)
# Now apply a function to each connected dst neuron
foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
lif_exci_inh_update_connected_neuron(neuron_src, states_view_src, conn, neuron_dst, states_view_dst)
end
end
i_exci, j_dsts_exci = let v = ev.idx_dsts_exci
if isempty(v)
nothing, Int[]
else
index_map[first(v)][1], map(idx -> index_map[idx][2], v)
end
end
function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFExciNeuron},
states_view_src,
conn::BasicConnection,
neuron_dst::Subsystem{<:LIFExciInhNeuron},
states_view_dst)
w = conn.weight
# check if the neuron is connected to itself
if states_view_src === states_view_dst
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
states_view_dst[:x] += w
end
SpikeAffectEvent{i_src, i_inh, i_exci}(j_src, j_dsts_inh, j_dsts_exci)
states_view_dst[:S_AMPA] += w
nothing
end

function GraphDynamics.discrete_event_condition(states,
params,
connection_matrices,
ev::SpikeAffectEvent{i_src, i_dst_inh, i_dsts_exci},
t) where {i_src, i_dst_inh, i_dsts_exci}
(; j_src) = ev
neuron_src = Subsystem(states[i_src][j_src], params[i_src][j_src])
neuron_src.V > neuron_src.θ
function lif_exci_inh_update_connected_neuron(neuron_src::Subsystem{LIFInhNeuron},
states_view_src,
conn::BasicConnection,
neuron_dst::Subsystem{<:LIFExciInhNeuron},
states_view_dst)
w = conn.weight
states_view_dst[:S_GABA] += w
nothing
end




function GraphDynamics.apply_discrete_event!(integrator,
states::NTuple{Len, Any},
params::NTuple{Len, Any},
connection_matrices,
t,
ev::SpikeAffectEvent{i_src, i_dst_inh, i_dst_exci}
) where {i_src, i_dst_inh, i_dst_exci, Len}
(; j_src, j_dsts_inh, j_dsts_exci) = ev

nc = connection_index(BasicConnection, connection_matrices)

params_src = params[i_src][j_src]
@reset params_src.t_refract_end = t + params_src.t_refract_duration
@reset params_src.is_refractory = 1

params[i_src][j_src] = params_src
add_tstop!(integrator, params_src.t_refract_end)

states_src = states[i_src][j_src]
states[i_src][:V, j_src] = params_src.V_reset
if (states_src isa SubsystemStates{LIFExciNeuron}) && (j_src ∈ j_dsts_exci)
# x is the rise variable for NMDA synapses and it only applies to self-recurrent connections
w = connection_matrices[nc][i_src, i_src][j_src, j_src].weight
states[i_src][:x, j_src] += w
end

if states_src isa SubsystemStates{LIFExciNeuron}
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst ∈ j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_AMPA, j_dst] += w
end
end
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst ∈ j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_AMPA, j_dst] += w
end
end
elseif states_src isa SubsystemStates{LIFInhNeuron}
if !isnothing(i_dst_inh)
M = connection_matrices[nc][i_src, i_dst_inh]
for j_dst ∈ j_dsts_inh
w = M[j_src, j_dst].weight
states[i_dst_inh][:S_GABA, j_dst] += w
end
end
if !isnothing(i_dst_exci)
M = connection_matrices[nc][i_src, i_dst_exci]
for j_dst ∈ j_dsts_exci
w = M[j_src, j_dst].weight
states[i_dst_exci][:S_GABA, j_dst] += w
end
end
else
error("this should be unreachable")
end
end

function blox_wiring_rule!(h,
stim::PoissonSpikeTrain,
blox_dst::Union{LIFExciNeuron, LIFInhNeuron},
Expand All @@ -385,7 +309,7 @@ function blox_wiring_rule!(h,
conn = PoissonSpikeConn(w_val, Set(Neuroblox.generate_spike_times(stim)))
add_edge!(h, i, j, Dict(:conn => conn, :names => [name]))
end
struct PoissonSpikeConn
struct PoissonSpikeConn <: ConnectionRule
w::Float64
t_spikes::Set{Float64}
end
Expand All @@ -394,23 +318,29 @@ function ((;w)::PoissonSpikeConn)(stim::Subsystem{PoissonSpikeTrain},
blox_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
(; jcn = 0.0)
end
GraphDynamics.has_discrete_events(::PoissonSpikeConn) = true
GraphDynamics.has_discrete_events(::Type{PoissonSpikeConn}) = true
GraphDynamics.event_times((;t_spikes)::PoissonSpikeConn) = (t_spikes)
GraphDynamics.discrete_event_condition((;t_spikes)::PoissonSpikeConn, t) = (t ∈ t_spikes)

GraphDynamics.has_discrete_events(::Type{PoissonSpikeTrain}) = true
function GraphDynamics.discrete_event_condition(p::Subsystem{PoissonSpikeTrain}, t, foreach_connected_neuron::F) where {F}
# check if any of the downstream connections from p spike at time t.
cond = mapreduce(|, foreach_connected_neuron; init=false) do conn, _, _, _
t ∈ conn.t_spikes
end
end
function GraphDynamics.apply_discrete_event!(integrator,
_, _,
vstates_dst, _,
_::PoissonSpikeConn,
_::Subsystem{PoissonSpikeTrain},
_::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}})
states = vstates_dst[]
states = @set states.S_AMPA_ext += 1
vstates_dst[] = states
nothing
states_view_src, params_view_src,
neuron_src::Subsystem{PoissonSpikeTrain},
foreach_connected_neuron::F) where {F}
t = integrator.t
foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
# Check each downstream connection, if it's time to spike, increment the downstream neuron's S_AMPA_ext
if t ∈ conn.t_spikes
states_view_dst[:S_AMPA_ext] += 1
end
end
end


components(blox::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = blox.parts

issupported(::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}) = true
Expand Down Expand Up @@ -441,7 +371,6 @@ function blox_wiring_rule!(h,
blox_dst::Union{LIFExciCircuitBlox, LIFInhCircuitBlox},
v_src, v_dst, kwargs)
neurons_dst = components(blox_dst)

for (j, neuron_dst) ∈ enumerate(neurons_dst)
blox_wiring_rule!(h, stim, neuron_dst, only(v_src), v_dst[j], kwargs)
end
Expand Down Expand Up @@ -752,7 +681,7 @@ function get_connection(discr_src::Matrisome, discr_dst::Matrisome, kwargs)
MMConn(t_event)
end

struct MMConn{T}
struct MMConn{T} <: ConnectionRule
t_event::T
end

Expand Down Expand Up @@ -828,7 +757,7 @@ function get_connection(discr_src::TAN, discr_dst::Matrisome, kwargs)
(; conn = TAN_M_Conn(w_val, t_event), names=[name])
end

struct TAN_M_Conn
struct TAN_M_Conn <: ConnectionRule
w::Float64
t_event::Float64
end
Expand Down Expand Up @@ -1022,7 +951,7 @@ end

# #-------------------------
# PING Network
struct PINGConnection
struct PINGConnection <: ConnectionRule
w::Float64
V_E::Float64
V_I::Float64
Expand Down
Loading
Loading