Fix and refactor ecephys nwb spike filter and sorting
Previously (#1365), filtering and sorting of ecephys spike
data was implemented both when writing and when loading
ecephys nwbfiles. It turns out the filtering/sorting
implementation on the nwb writing side assumed data that
had not yet been added to the units table, resulting in
errors when trying to write ecephys nwb files.

This commit 1) fixes the spike filtering and sorting
functionality, 2) removes filtering/sorting when loading
from ecephys nwbfiles, 3) refactors so that adding probe
data (e.g. unit tables) to ecephys nwbfiles is easier to
test, and 4) adds tests.

Relates to: #1510
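
As a rough illustration of the intended write-side behavior, here is a minimal
sketch that exercises the filter_and_sort_spikes helper added in this commit;
the unit ids and spike arrays below are made up for illustration only.

import numpy as np
from allensdk.brain_observatory.ecephys.write_nwb.__main__ import \
    filter_and_sort_spikes

# Hypothetical spike data for two units; unit 1 has an invalid (negative)
# spike time and its entries are out of order.
spike_times = {
    0: np.array([0.1, 0.5, 0.3]),
    1: np.array([-1.0, 2.0, 1.5]),
}
spike_amplitudes = {
    0: np.array([10.0, 30.0, 20.0]),
    1: np.array([5.0, 15.0, 25.0]),
}

sorted_times, sorted_amps = filter_and_sort_spikes(spike_times, spike_amplitudes)
# sorted_times[1] -> array([1.5, 2.0]): the -1.0 spike time is dropped,
# and sorted_amps[1] -> array([25., 15.]) stays aligned with its times.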
njmei committed Apr 23, 2020
1 parent 40d59f1 commit 614e0ab
Showing 3 changed files with 183 additions and 142 deletions.
@@ -11,8 +11,6 @@
from allensdk.brain_observatory.nwb.nwb_api import NwbApi
import allensdk.brain_observatory.ecephys.nwb # noqa Necessary to import pyNWB namespaces
from allensdk.brain_observatory.ecephys import get_unit_filter_value
-from allensdk.brain_observatory.ecephys.write_nwb.__main__ import \
-    remove_invalid_spikes_from_units

color_triplet_re = re.compile(r"\[(-{0,1}\d*\.\d*,\s*)*(-{0,1}\d*\.\d*)\]")

@@ -329,7 +327,6 @@ def _get_full_units_table(self) -> pd.DataFrame:
        units = units[units["presence_ratio"] >= self.presence_ratio_minimum]
        units = units[units["isi_violations"] <= self.isi_violations_maximum]

-        units = remove_invalid_spikes_from_units(units)
        return units

    def get_metadata(self):
allensdk/brain_observatory/ecephys/write_nwb/__main__.py (173 changes: 89 additions & 84 deletions)
@@ -1,6 +1,6 @@
import logging
import sys
-from typing import Dict
+from typing import Any, Dict, List, Tuple
from pathlib import Path, PurePath
import multiprocessing as mp
from functools import partial
@@ -187,79 +187,43 @@ def scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=
    return spike_amplitudes


-def remove_invalid_spikes(
-    row: pd.Series,
-    times_key: str = "spike_times",
-    amps_key: str = "spike_amplitudes"
-) -> pd.Series:
-    """ Given a row from a units table, ensure that invalid spike times and
-    corresponding amplitudes are removed. Also ensure the spikes are sorted
-    ascending in time.
-
-    Parameters
-    ----------
-    row : a row representing a single sorted unit
-    times_key : name of column containing spike times
-    amps_key : name of column containing spike amplitudes
-
-    Returns
-    -------
-    A version of the input row, with spike times sorted and invalid times
-    removed
-
-    Notes
-    -----
-    This function is needed because currently released NWB files might have
-    invalid spike times. It can be removed if these files are updated.
-
-    """
-
-    out = row.copy(deep=True)
-
-    spike_times = np.array(out.pop(times_key))
-    amps = np.array(out.pop(amps_key))
-
-    valid = spike_times >= 0
-    spike_times = spike_times[valid]
-    amps = amps[valid]
-
-    order = np.argsort(spike_times)
-    out[times_key] = spike_times[order]
-    out[amps_key] = amps[order]
-
-    return out
-
-
-def remove_invalid_spikes_from_units(
-    units: pd.DataFrame,
-    times_key: str = "spike_times",
-    amps_key: str = "spike_amplitudes"
-) -> pd.DataFrame:
-    """ Given a units table, ensure that invalid spike times and
-    corresponding amplitudes are removed. Also ensure the spikes are sorted
-    ascending in time.
-
-    Parameters
-    ----------
-    units : A units table
-    times_key : name of column containing spike times
-    amps_key : name of column containing spike amplitudes
-
-    Returns
-    -------
-    A version of the input table, with spike times sorted and invalid times
-    removed
-
-    Notes
-    -----
-    This function is needed because currently released NWB files might have
-    invalid spike times. It can be removed if these files are updated.
-
-    """
-
-    remover = partial(
-        remove_invalid_spikes, times_key=times_key, amps_key=amps_key)
-    return units.apply(remover, axis=1)
+def filter_and_sort_spikes(spike_times_mapping: Dict[int, np.ndarray],
+                           spike_amplitudes_mapping: Dict[int, np.ndarray]) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
+    """Filter out invalid spike timepoints and sort spike data
+    (times + amplitudes) by times.
+
+    Parameters
+    ----------
+    spike_times_mapping : Dict[int, np.ndarray]
+        Keys: unit identifiers, Values: spike time arrays
+    spike_amplitudes_mapping : Dict[int, np.ndarray]
+        Keys: unit identifiers, Values: spike amplitude arrays
+
+    Returns
+    -------
+    Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]
+        A tuple containing filtered and sorted spike_times_mapping and
+        spike_amplitudes_mapping data.
+    """
+    sorted_spike_times_mapping = {}
+    sorted_spike_amplitudes_mapping = {}
+
+    for unit_id, _ in spike_times_mapping.items():
+        spike_times = spike_times_mapping[unit_id]
+        spike_amplitudes = spike_amplitudes_mapping[unit_id]
+
+        valid = spike_times >= 0
+        filtered_spike_times = spike_times[valid]
+        filtered_spike_amplitudes = spike_amplitudes[valid]
+
+        order = np.argsort(filtered_spike_times)
+        sorted_spike_times = filtered_spike_times[order]
+        sorted_spike_amplitudes = filtered_spike_amplitudes[order]
+
+        sorted_spike_times_mapping[unit_id] = sorted_spike_times
+        sorted_spike_amplitudes_mapping[unit_id] = sorted_spike_amplitudes
+
+    return (sorted_spike_times_mapping, sorted_spike_amplitudes_mapping)


def group_1d_by_unit(data, data_unit_map, local_to_global_unit_map=None):
@@ -670,32 +634,43 @@ def write_probewise_lfp_files(probes, session_start_time, pool_size=3):
    return output_paths


-def add_probewise_data_to_nwbfile(nwbfile, probes):
-    """ Adds channel and spike data for a single probe to the session-level nwb file.
-    """
+ParsedProbeData = Tuple[pd.DataFrame,           # unit_tables
+                        Dict[int, np.ndarray],  # spike_times
+                        Dict[int, np.ndarray],  # spike_amplitudes
+                        Dict[int, np.ndarray]]  # mean_waveforms
+
+
+def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData:
+    """Given a list of probe dictionaries specifying data file locations, load
+    and parse probe data into intermediate data structures needed for adding
+    probe data to an nwbfile.
+
+    Parameters
+    ----------
+    probes : List[Dict[str, Any]]
+        A list of dictionaries (one entry for each probe), where each probe
+        dictionary contains metadata (id, name, sampling_rate, etc...) as well
+        as filepaths pointing to where probe lfp data can be found.
+
+    Returns
+    -------
+    ParsedProbeData : Tuple[...]
+        unit_tables : pd.DataFrame
+            A table containing unit metadata from all probes.
+        spike_times : Dict[int, np.ndarray]
+            Keys: unit identifiers, Values: spike time arrays
+        spike_amplitudes : Dict[int, np.ndarray]
+            Keys: unit identifiers, Values: spike amplitude arrays
+        mean_waveforms : Dict[int, np.ndarray]
+            Keys: unit identifiers, Values: mean waveform arrays
+    """

-    channel_tables = {}
    unit_tables = []
    spike_times = {}
    spike_amplitudes = {}
    mean_waveforms = {}

    for probe in probes:
-        logging.info(f'found probe {probe["id"]} with name {probe["name"]}')
-
-        if probe.get("temporal_subsampling_factor", None) is not None:
-            probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / probe["temporal_subsampling_factor"]
-
-        nwbfile, probe_nwb_device, probe_nwb_electrode_group = add_probe_to_nwbfile(
-            nwbfile,
-            probe_id=probe["id"],
-            description=probe["name"],
-            sampling_rate=probe["sampling_rate"],
-            lfp_sampling_rate=probe["lfp_sampling_rate"],
-            has_lfp_data=probe["lfp"] is not None
-        )
-
-        channel_tables[probe["id"]] = prepare_probewise_channel_table(probe['channels'], probe_nwb_electrode_group)
        unit_tables.append(pd.DataFrame(probe['units']))

        local_to_global_unit_map = {unit['cluster_id']: unit['id'] for unit in probe['units']}
@@ -714,22 +689,52 @@ def add_probewise_data_to_nwbfile(nwbfile, probes):
            scale_factor=probe["amplitude_scale_factor"]
        ))

+    units_table = pd.concat(unit_tables).set_index(keys='id', drop=True)
+
+    return (units_table, spike_times, spike_amplitudes, mean_waveforms)
+
+
+def add_probewise_data_to_nwbfile(nwbfile, probes):
+    """ Adds channel and spike data for a single probe to the session-level nwb file.
+    """
+
+    channel_tables = {}
+
+    for probe in probes:
+        logging.info(f'found probe {probe["id"]} with name {probe["name"]}')
+
+        if probe.get("temporal_subsampling_factor", None) is not None:
+            probe["lfp_sampling_rate"] = probe["lfp_sampling_rate"] / probe["temporal_subsampling_factor"]
+
+        nwbfile, probe_nwb_device, probe_nwb_electrode_group = add_probe_to_nwbfile(
+            nwbfile,
+            probe_id=probe["id"],
+            description=probe["name"],
+            sampling_rate=probe["sampling_rate"],
+            lfp_sampling_rate=probe["lfp_sampling_rate"],
+            has_lfp_data=probe["lfp"] is not None
+        )
+
+        channel_tables[probe["id"]] = prepare_probewise_channel_table(probe['channels'], probe_nwb_electrode_group)
+
    electrodes_table = fill_df(pd.concat(list(channel_tables.values())))
    nwbfile.electrodes = pynwb.file.ElectrodeTable().from_dataframe(electrodes_table, name='electrodes')
-    units_table = pd.concat(unit_tables).set_index(keys='id', drop=True)
-    units_table = remove_invalid_spikes_from_units(units_table)
+
+    units_table, spike_times, spike_amplitudes, mean_waveforms = parse_probes_data(probes)
    nwbfile.units = pynwb.misc.Units.from_dataframe(fill_df(units_table), name='units')

+    sorted_spike_times, sorted_spike_amplitudes = filter_and_sort_spikes(spike_times, spike_amplitudes)
+
    add_ragged_data_to_dynamic_table(
        table=nwbfile.units,
-        data=spike_times,
+        data=sorted_spike_times,
        column_name="spike_times",
        column_description="times (s) of detected spiking events",
    )

    add_ragged_data_to_dynamic_table(
        table=nwbfile.units,
-        data=spike_amplitudes,
+        data=sorted_spike_amplitudes,
        column_name="spike_amplitudes",
        column_description="amplitude (s) of detected spiking events"
    )
