Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkilic committed Nov 22, 2024
1 parent 43eb7b4 commit b63047b
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 187 deletions.
200 changes: 44 additions & 156 deletions bluepyemodel/emodel_pipeline/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import matplotlib.font_manager
import matplotlib.pyplot as plt
import numpy
from bluepyopt.ephys.objectives import SingletonWeightObjective
from bluepyopt.ephys.protocols import SweepProtocol
from bluepyopt.ephys.recordings import CompRecording
from bluepyopt.ephys.stimuli import NrnSquarePulse
Expand All @@ -44,22 +43,23 @@
from bluepyemodel.emodel_pipeline.plotting_utils import get_experimental_FI_curve_for_plotting
from bluepyemodel.emodel_pipeline.plotting_utils import get_impedance
from bluepyemodel.emodel_pipeline.plotting_utils import get_ordered_currentscape_keys
from bluepyemodel.emodel_pipeline.plotting_utils import get_original_protocol_name
from bluepyemodel.emodel_pipeline.plotting_utils import get_recording_names
from bluepyemodel.emodel_pipeline.plotting_utils import get_simulated_FI_curve_for_plotting
from bluepyemodel.emodel_pipeline.plotting_utils import get_sinespec_evaluator
from bluepyemodel.emodel_pipeline.plotting_utils import get_title
from bluepyemodel.emodel_pipeline.plotting_utils import get_traces_names_and_float_responses
from bluepyemodel.emodel_pipeline.plotting_utils import get_traces_ylabel
from bluepyemodel.emodel_pipeline.plotting_utils import get_voltage_currents_from_files
from bluepyemodel.emodel_pipeline.plotting_utils import plot_fi_curves
from bluepyemodel.emodel_pipeline.plotting_utils import rel_to_abs_amplitude
from bluepyemodel.evaluation.efel_feature_bpem import eFELFeatureBPEM
from bluepyemodel.emodel_pipeline.plotting_utils import save_fig
from bluepyemodel.emodel_pipeline.plotting_utils import update_evaluator
from bluepyemodel.evaluation.evaluation import compute_responses
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.evaluation.evaluator import PRE_PROTOCOLS
from bluepyemodel.evaluation.evaluator import add_recordings_to_evaluator
from bluepyemodel.evaluation.evaluator import define_protocol
from bluepyemodel.evaluation.evaluator import soma_loc
from bluepyemodel.evaluation.protocol_configuration import ProtocolConfiguration
from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol
from bluepyemodel.evaluation.utils import define_bAP_feature
from bluepyemodel.evaluation.utils import define_bAP_protocol
Expand Down Expand Up @@ -89,14 +89,6 @@
}


def save_fig(figures_dir, figure_name, dpi=100):
"""Save a matplotlib figure"""
p = Path(figures_dir) / figure_name
plt.savefig(str(p), dpi=dpi, bbox_inches="tight")
plt.close("all")
plt.clf()


def optimisation(
optimiser,
emodel,
Expand Down Expand Up @@ -1243,29 +1235,17 @@ def plot_IV_curves(
if efel_settings is None:
efel_settings = bluepyefe.tools.DEFAULT_EFEL_SETTINGS.copy()

# get simulation data points matching the experimental data points
sim_amp_points = []
for emodel in emodels:
cells = read_extraction_output_cells(emodel.emodel_metadata.emodel)
if cells is None:
continue
exp_peak, exp_vd = extract_experimental_data_for_IV_curve(
cells, efel_settings, prot_name, n_bin
)
lower_bound = -100
upper_bound = 300

sim_amp_points += exp_peak["amp_rel"] + exp_vd["amp_rel"]
# Generate amplitude points
sim_amp_points = list(numpy.linspace(lower_bound, upper_bound, n_bin))

sim_amp_points = list(set(sim_amp_points))
sim_amp_points = [int(value) for value in sim_amp_points]
# add missing features (if any) to evaluator
updated_evaluator = fill_in_IV_curve_evaluator(
evaluator, efel_settings, prot_name, sim_amp_points
)

updated_evaluator.fitness_protocols["main_protocol"].execution_order = (
updated_evaluator.fitness_protocols["main_protocol"].compute_execution_order()
)

emodels = compute_responses(
access_point,
updated_evaluator,
Expand Down Expand Up @@ -1417,141 +1397,29 @@ def plot_FI_curves_comparison(
make_dir(figures_dir)

emodel_name, cells = None, None

def load_cells_for_emodel(emodel):
nonlocal cells, emodel_name

if custom_bluepyefe_cells_pklpath:
updated_evaluator = copy.deepcopy(evaluator)
for emodel in emodels:
# do not re-extract data if the emodel is the same as previously
if custom_bluepyefe_cells_pklpath is not None:
if cells is None:
cells = read_extraction_output(custom_bluepyefe_cells_pklpath)
return cells is not None
if cells is None:
continue

if emodel_name != emodel.emodel_metadata.emodel or cells is None:
# experimental FI curve
expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)
elif emodel_name != emodel.emodel_metadata.emodel or cells is None:
# take extraction data from pickle file and rearange it for plotting
cells = read_extraction_output_cells(emodel.emodel_metadata.emodel)
emodel_name = emodel.emodel_metadata.emodel
return cells is not None
return True

def get_original_protocol_name(prot_name, evaluator):
"""Retrieve the protocol name as defined by the user, preserving the original case"""
for protocol_name in evaluator.fitness_protocols["main_protocol"].protocols:
if prot_name.lower() in protocol_name.lower():
return protocol_name
return prot_name

def update_evaluator(expt_amp_rel, prot_name, evaluator):
"""update evaluator with new simulation protocols."""
updated_evaluator = copy.deepcopy(evaluator)
for amp_rel in expt_amp_rel:
protocol_name = f"{prot_name.split('_')[0]}_{int(amp_rel)}"
protocol = evaluator.fitness_protocols["main_protocol"].protocols[prot_name]
if protocol_name not in evaluator.fitness_protocols["main_protocol"].protocols:
stimuli = [
{
"holding_current": protocol.stimuli[0].holding_current,
"threshold_current": protocol.stimuli[0].threshold_current,
"thresh_perc": int(amp_rel),
"delay": protocol.stimuli[0].delay,
"duration": protocol.stimuli[0].duration,
"totduration": protocol.stimuli[0].total_duration,
}
]
recordings = [
{
"type": "CompRecording",
"name": f"{protocol_name}.soma.v",
"location": "soma",
"variable": "v",
}
]

my_protocol_config = ProtocolConfiguration(
name=protocol_name, stimuli=stimuli, recordings=recordings, validation=False
)
main_protocol = updated_evaluator.fitness_protocols["main_protocol"]
main_protocol.protocols[protocol_name] = define_protocol(my_protocol_config)

for objective in evaluator.fitness_calculator.objectives:
feat = objective.features[0]
if (
protocol_name.split("_", maxsplit=1)[0] in feat.recording_names[""]
and "mean_frequency" in feat.efel_feature_name
):
feat_name = f"{protocol_name}.soma.v.mean_frequency"
amp_rel = float(protocol_name.split("_")[1])
amp = float(feat.recording_names[""].split(".")[0].split("_")[-1])
updated_evaluator.fitness_calculator.objectives.append(
SingletonWeightObjective(
feat_name,
eFELFeatureBPEM(
feat_name,
efel_feature_name="mean_frequency",
recording_names={"": f"{protocol_name}.soma.v"},
stim_start=feat.stim_start,
stim_end=feat.stim_end,
exp_mean=1.0, # fodder: not used
exp_std=1.0, # fodder: not used
threshold=feat.threshold,
stimulus_current=feat.stimulus_current() * amp_rel / amp,
weight=1.0,
),
1.0,
)
)
break
return updated_evaluator

def plot_fi_curves(expt_data, sim_data, figures_dir, emodel, write_fig):
"""Plot and save the FI curves."""
(
expt_amp_rel,
expt_freq_rel,
expt_freq_rel_err,
expt_amp,
expt_freq_abs,
expt_freq_abs_err,
) = expt_data
simulated_amp_rel, simulated_amp, simulated_freq = sim_data

_, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))
ax[0].errorbar(
expt_amp_rel,
expt_freq_rel,
yerr=expt_freq_rel_err,
marker="o",
color="grey",
label="experiment",
)
ax[0].plot(simulated_amp_rel, simulated_freq, "o", color="blue", label="model")
ax[0].set_xlabel("Amplitude (% of rheobase)")
ax[0].set_ylabel("Mean Frequency (Hz)")
ax[0].set_title("FI curve (relative amplitude)")
ax[0].legend()

ax[1].errorbar(
expt_amp,
expt_freq_abs,
yerr=expt_freq_abs_err,
marker="o",
color="grey",
label="experiment",
)
ax[1].plot(simulated_amp, simulated_freq, "o", color="blue", label="model")
ax[1].set_xlabel("Amplitude (nA)")
ax[1].set_ylabel("Voltage (mV)")
ax[1].set_title("FI curve (absolute amplitude)")
ax[1].legend()
if cells is None:
continue

if write_fig:
filename = f"{emodel.emodel_metadata.as_string(emodel.seed)}__FI_curve_comparison.pdf"
save_fig(figures_dir, filename)
# experimental FI curve
expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)

updated_evaluator = copy.deepcopy(evaluator)
for emodel in emodels:
if not load_cells_for_emodel(emodel):
continue
emodel_name = emodel.emodel_metadata.emodel

expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)
expt_data_amp_rel = expt_data[0]
prot_name_original = get_original_protocol_name(prot_name, evaluator)
updated_evaluator = update_evaluator(
Expand All @@ -1564,7 +1432,27 @@ def plot_fi_curves(expt_data, sim_data, figures_dir, emodel, write_fig):

emodels = compute_responses(access_point, updated_evaluator, mapper, seeds)
for emodel in emodels:
expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)
# do not re-extract data if the emodel is the same as previously
if custom_bluepyefe_cells_pklpath is not None:
if cells is None:
cells = read_extraction_output(custom_bluepyefe_cells_pklpath)
if cells is None:
continue

# experimental FI curve
expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)
elif emodel_name != emodel.emodel_metadata.emodel or cells is None:
# take extraction data from pickle file and rearange it for plotting
cells = read_extraction_output_cells(emodel.emodel_metadata.emodel)
emodel_name = emodel.emodel_metadata.emodel
if cells is None:
continue

# experimental FI curve
expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin)

emodel_name = emodel.emodel_metadata.emodel

sim_data = get_simulated_FI_curve_for_plotting(
updated_evaluator, emodel.responses, prot_name
)
Expand Down
Loading

0 comments on commit b63047b

Please sign in to comment.