diff --git a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py
index c4e16f3af..55c32465d 100644
--- a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py
+++ b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py
@@ -11,8 +11,6 @@ from allensdk.brain_observatory.nwb.nwb_api import NwbApi
 import allensdk.brain_observatory.ecephys.nwb  # noqa Necessary to import pyNWB namespaces
 from allensdk.brain_observatory.ecephys import get_unit_filter_value
-from allensdk.brain_observatory.ecephys.write_nwb.__main__ import \
-    remove_invalid_spikes_from_units
 
 
 color_triplet_re = re.compile(r"\[(-{0,1}\d*\.\d*,\s*)*(-{0,1}\d*\.\d*)\]")
 
@@ -329,7 +327,6 @@ def _get_full_units_table(self) -> pd.DataFrame:
         units = units[units["presence_ratio"] >= self.presence_ratio_minimum]
         units = units[units["isi_violations"] <= self.isi_violations_maximum]
 
-        units = remove_invalid_spikes_from_units(units)
         return units
 
     def get_metadata(self):
diff --git a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py
index 9304a0676..47b0cc423 100644
--- a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py
+++ b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py
@@ -1,6 +1,6 @@
 import logging
 import sys
-from typing import Dict
+from typing import Any, Dict, List, Tuple
 from pathlib import Path, PurePath
 import multiprocessing as mp
 from functools import partial
@@ -187,79 +187,43 @@ def scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=
     return spike_amplitudes
 
 
-def remove_invalid_spikes(
-    row: pd.Series,
-    times_key: str = "spike_times",
-    amps_key: str = "spike_amplitudes"
-) -> pd.Series:
-    """ Given a row from a units table, ensure that invalid spike times and
-    corresponding amplitudes are removed. Also ensure the spikes are sorted
-    ascending in time.
+def filter_and_sort_spikes(spike_times_mapping: Dict[int, np.ndarray],
+                           spike_amplitudes_mapping: Dict[int, np.ndarray]) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
+    """Filter out invalid spike timepoints and sort spike data
+    (times + amplitudes) by times.
 
     Parameters
     ----------
-    row : a row representing a single sorted unit
-    times_key : name of column containing spike times
-    amps_key : name of column containing spike amplitudes
+    spike_times_mapping : Dict[int, np.ndarray]
+        Keys: unit identifiers, Values: spike time arrays
+    spike_amplitudes_mapping : Dict[int, np.ndarray]
+        Keys: unit identifiers, Values: spike amplitude arrays
 
     Returns
     -------
-    A version of the input row, with spike times sorted and invalid times
-    removed
-
-    Notes
-    -----
-    This function is needed because currently released NWB files might have
-    invalid spike times. It can be removed if these files are updated.
-
+    Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]
+        A tuple containing filtered and sorted spike_times_mapping and
+        spike_amplitudes_mapping data.
     """
+    sorted_spike_times_mapping = {}
+    sorted_spike_amplitudes_mapping = {}
 
-    out = row.copy(deep=True)
-
-    spike_times = np.array(out.pop(times_key))
-    amps = np.array(out.pop(amps_key))
-
-    valid = spike_times >= 0
-    spike_times = spike_times[valid]
-    amps = amps[valid]
-
-    order = np.argsort(spike_times)
-    out[times_key] = spike_times[order]
-    out[amps_key] = amps[order]
-
-    return out
-
-
-def remove_invalid_spikes_from_units(
-    units: pd.DataFrame,
-    times_key: str = "spike_times",
-    amps_key: str = "spike_amplitudes"
-) -> pd.DataFrame:
-    """ Given a units table, ensure that invalid spike times and
-    corresponding amplitudes are removed. Also ensure the spikes are sorted
-    ascending in time.
-
-    Parameters
-    ----------
-    units : A units table
-    times_key : name of column containing spike times
-    amps_key : name of column containing spike amplitudes
+    for unit_id, _ in spike_times_mapping.items():
+        spike_times = spike_times_mapping[unit_id]
+        spike_amplitudes = spike_amplitudes_mapping[unit_id]
 
-    Returns
-    -------
-    A version of the input table, with spike times sorted and invalid times
-    removed
+        valid = spike_times >= 0
+        filtered_spike_times = spike_times[valid]
+        filtered_spike_amplitudes = spike_amplitudes[valid]
 
-    Notes
-    -----
-    This function is needed because currently released NWB files might have
-    invalid spike times. It can be removed if these files are updated.
+        order = np.argsort(filtered_spike_times)
+        sorted_spike_times = filtered_spike_times[order]
+        sorted_spike_amplitudes = filtered_spike_amplitudes[order]
 
-    """
+        sorted_spike_times_mapping[unit_id] = sorted_spike_times
+        sorted_spike_amplitudes_mapping[unit_id] = sorted_spike_amplitudes
 
-    remover = partial(
-        remove_invalid_spikes, times_key=times_key, amps_key=amps_key)
-    return units.apply(remover, axis=1)
+    return (sorted_spike_times_mapping, sorted_spike_amplitudes_mapping)
 
 
 def group_1d_by_unit(data, data_unit_map, local_to_global_unit_map=None):
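
[Note for reviewers] A minimal sanity-check sketch of the new helper's contract; the unit id and values below are invented for illustration. Negative spike times are treated as invalid and dropped, and both mappings come back sorted by ascending spike time:

    import numpy as np
    from allensdk.brain_observatory.ecephys.write_nwb.__main__ import \
        filter_and_sort_spikes

    times = {42: np.array([0.1, -1.0, 0.05])}  # hypothetical unit 42; -1.0 is an invalid time
    amps = {42: np.array([10.0, 20.0, 30.0])}

    sorted_times, sorted_amps = filter_and_sort_spikes(times, amps)
    # sorted_times -> {42: array([0.05, 0.1])}
    # sorted_amps  -> {42: array([30., 10.])}
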
""" + sorted_spike_times_mapping = {} + sorted_spike_amplitudes_mapping = {} - out = row.copy(deep=True) - - spike_times = np.array(out.pop(times_key)) - amps = np.array(out.pop(amps_key)) - - valid = spike_times >= 0 - spike_times = spike_times[valid] - amps = amps[valid] - - order = np.argsort(spike_times) - out[times_key] = spike_times[order] - out[amps_key] = amps[order] - - return out - - -def remove_invalid_spikes_from_units( - units: pd.DataFrame, - times_key: str = "spike_times", - amps_key: str = "spike_amplitudes" -) -> pd.DataFrame: - """ Given a units table, ensure that invalid spike times and - corresponding amplitudes are removed. Also ensure the spikes are sorted - ascending in time. - - Parameters - ---------- - units : A units table - times_key : name of column containing spike times - amps_key : name of column containing spike amplitudes + for unit_id, _ in spike_times_mapping.items(): + spike_times = spike_times_mapping[unit_id] + spike_amplitudes = spike_amplitudes_mapping[unit_id] - Returns - ------- - A version of the input table, with spike times sorted and invalid times - removed + valid = spike_times >= 0 + filtered_spike_times = spike_times[valid] + filtered_spike_amplitudes = spike_amplitudes[valid] - Notes - ----- - This function is needed because currently released NWB files might have - invalid spike times. It can be removed if these files are updated. + order = np.argsort(filtered_spike_times) + sorted_spike_times = filtered_spike_times[order] + sorted_spike_amplitudes = filtered_spike_amplitudes[order] - """ + sorted_spike_times_mapping[unit_id] = sorted_spike_times + sorted_spike_amplitudes_mapping[unit_id] = sorted_spike_amplitudes - remover = partial( - remove_invalid_spikes, times_key=times_key, amps_key=amps_key) - return units.apply(remover, axis=1) + return (sorted_spike_times_mapping, sorted_spike_amplitudes_mapping) def group_1d_by_unit(data, data_unit_map, local_to_global_unit_map=None): @@ -670,32 +634,43 @@ def write_probewise_lfp_files(probes, session_start_time, pool_size=3): return output_paths -def add_probewise_data_to_nwbfile(nwbfile, probes): - """ Adds channel and spike data for a single probe to the session-level nwb file. +ParsedProbeData = Tuple[pd.DataFrame, # unit_tables + Dict[int, np.ndarray], # spike_times + Dict[int, np.ndarray], # spike_amplitudes + Dict[int, np.ndarray]] # mean_waveforms + + +def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData: + """Given a list of probe dictionaries specifying data file locations, load + and parse probe data into intermediate data structures needed for adding + probe data to an nwbfile. + + Parameters + ---------- + probes : List[Dict[str, Any]] + A list of dictionaries (one entry for each probe), where each probe + dictionary contains metadata (id, name, sampling_rate, etc...) as well + as filepaths pointing to where probe lfp data can be found. + + Returns + ------- + ParsedProbeData : Tuple[...] + unit_tables : pd.DataFrame + A table containing unit metadata from all probes. 
diff --git a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py
index 1f2611df1..17e131045 100644
--- a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py
+++ b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py
@@ -695,61 +695,100 @@ def test_read_spike_amplitudes_to_dictionary(tmpdir_factory, spike_amplitudes, t
     assert np.allclose(expected_amplitudes[3:], obtained[1])
 
 
-def test_remove_invalid_spikes():
-    row = pd.Series({
-        "spike_times": np.array([0, 1, 2, -1, 5, 4]),
-        "spike_amplitudes": np.arange(6),
-        "a": "b"
-    }, name=1000)
-
-    expct = pd.Series({
-        "spike_times": np.array([0, 1, 2, 4, 5]),
-        "spike_amplitudes": np.array([0, 1, 2, 5, 4]),
-        "a": "b"
-    }, name=1000)
-
-    obt = write_nwb.remove_invalid_spikes(row)
-    assert np.allclose(obt["spike_times"], expct["spike_times"])
-    assert np.allclose(obt["spike_amplitudes"], expct["spike_amplitudes"])
-
-
-def test_remove_invalid_spikes_from_units():
-
-    units = pd.DataFrame({
-        "spike_times": [
-            np.array([0, 1, 2, -1, 5, 4]),
-            np.arange(2),
-            np.arange(3)
-        ],
-        "spike_amplitudes": [
-            np.arange(6),
-            np.arange(2),
-            np.arange(3)
-        ],
-        "a": ["b", "c", "d"]
-    }, index=pd.Index(name="id", data=[5, 6, 7]))
-
-    expct = pd.DataFrame({
-        "spike_times": [
-            np.array([0, 1, 2, 4, 5]),
-            np.arange(2),
-            np.arange(3)
-        ],
-        "spike_amplitudes": [
-            np.array([0, 1, 2, 5, 4]),
-            np.arange(2),
-            np.arange(3)
-        ],
-        "a": ["b", "c", "d"]
-    }, index=pd.Index(name="id", data=[5, 6, 7]))
-
-    obt = write_nwb.remove_invalid_spikes_from_units(units)
-    for (_, expct_row), (_, obt_row) in zip(expct.iterrows(), obt.iterrows()):
-        assert np.allclose(obt_row["spike_times"], expct_row["spike_times"])
-        assert np.allclose(
-            obt_row["spike_amplitudes"],
-            expct_row["spike_amplitudes"])
-        assert obt_row["a"] == expct_row["a"]
+@pytest.mark.parametrize("spike_times_mapping, spike_amplitudes_mapping, expected", [
+
+    ({12345: np.array([0, 1, 2, -1, 5, 4])},  # spike_times_mapping
+
+     {12345: np.array([0, 1, 2, 3, 4, 5])},  # spike_amplitudes_mapping
+
+     ({12345: np.array([0, 1, 2, 4, 5])},  # expected
+      {12345: np.array([0, 1, 2, 5, 4])})),
+
+    ({12345: np.array([0, 1, 2, -1, 5, 4]),  # spike_times_mapping
+      54321: np.array([5, 4, 3, -1, 6])},
+
+     {12345: np.array([0, 1, 2, 3, 4, 5]),  # spike_amplitudes_mapping
+      54321: np.array([0, 1, 2, 3, 4])},
+
+     ({12345: np.array([0, 1, 2, 4, 5]),  # expected
+       54321: np.array([3, 4, 5, 6])},
+      {12345: np.array([0, 1, 2, 5, 4]),
+       54321: np.array([2, 1, 0, 4])})),
+])
+def test_filter_and_sort_spikes(spike_times_mapping, spike_amplitudes_mapping, expected):
+    expected_spike_times, expected_spike_amplitudes = expected
+
+    obtained_spike_times, obtained_spike_amplitudes = write_nwb.filter_and_sort_spikes(spike_times_mapping,
+                                                                                       spike_amplitudes_mapping)
+
+    np.testing.assert_equal(obtained_spike_times, expected_spike_times)
+    np.testing.assert_equal(obtained_spike_amplitudes, expected_spike_amplitudes)
+
+
+@pytest.mark.parametrize('roundtrip', [True, False])
+@pytest.mark.parametrize('probes, parsed_probe_data, expected', [
+    ([{"id": 1234,
+       "name": "probeA",
+       "sampling_rate": 29999.9655245905,
+       "lfp_sampling_rate": 2499.99712704921,
+       "temporal_subsampling_factor": 2.0,
+       "lfp": {},
+       "spike_times_path": "/dummy_path",
+       "spike_clusters_files": "/dummy_path",
+       "mean_waveforms_path": "/dummy_path",
+       "channels": [{"id": 1,
+                     "probe_id": 1234,
+                     "valid_data": True,
+                     "local_index": 0,
+                     "a": 42.0},
+                    {"id": 2,
+                     "probe_id": 1234,
+                     "valid_data": True,
+                     "local_index": 1,
+                     "a": 84.0}],
+       "units": [{"id": 777,
+                  "local_index": 7,
+                  "quality": "good",
+                  "a": 0.5,
+                  "b": 5},
+                 {"id": 778,
+                  "local_index": 9,
+                  "quality": "noise",
+                  "a": 1.0,
+                  "b": 10}]}],
+
+     (pd.DataFrame({"id": [777, 778], "local_index": [7, 9],  # units_table
+                    "a": [0.5, 1.0], "b": [5, 10]}).set_index(keys='id', drop=True),
+      {777: np.array([0., 1., 2., -1., 5., 4.]),  # spike_times
+       778: np.array([5., 4., 3., -1., 6.])},
+      {777: np.array([0., 1., 2., 3., 4., 5.]),  # spike_amplitudes
+       778: np.array([0., 1., 2., 3., 4.])},
+      {777: np.array([1., 2., 3., 4., 5., 6.]),  # mean_waveforms
+       778: np.array([1., 2., 3., 4., 5.])}),
+
+     pd.DataFrame({"id": [777, 778], "local_index": [7, 9],  # units_table
+                   "a": [0.5, 1.0], "b": [5, 10],
+                   "spike_times": [[0., 1., 2., 4., 5.], [3., 4., 5., 6.]],
+                   "spike_amplitudes": [[0., 1., 2., 5., 4.], [2., 1., 0., 4.]],
+                   "waveform_mean": [[1., 2., 3., 4., 5., 6.], [1., 2., 3., 4., 5.]]}
+                  ).set_index(keys='id', drop=True)),
+])
+def test_add_probewise_data_to_nwbfile(monkeypatch, nwbfile, roundtripper,
+                                       roundtrip, probes, parsed_probe_data,
+                                       expected):
+
+    def mock_parse_probes_data(probes):
+        return parsed_probe_data
+
+    monkeypatch.setattr(write_nwb, 'parse_probes_data', mock_parse_probes_data)
+
+    nwbfile = write_nwb.add_probewise_data_to_nwbfile(nwbfile, probes)
+
+    if roundtrip:
+        obt = roundtripper(nwbfile, EcephysNwbSessionApi)
+    else:
+        obt = EcephysNwbSessionApi.from_nwbfile(nwbfile)
+
+    pd.testing.assert_frame_equal(obt.nwbfile.units.to_dataframe(), expected)
 
 
 @pytest.mark.parametrize('roundtrip', [True, False])
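
[Note for reviewers] The tests above compare whole {unit_id: array} mappings in a single call because np.testing.assert_equal recurses into dict values and compares each array elementwise. A tiny self-contained illustration (the unit id is arbitrary):

    import numpy as np

    # Passes: keys are checked first, then each value array elementwise.
    np.testing.assert_equal({777: np.array([1, 2])}, {777: np.array([1, 2])})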