Skip to content

Commit

Permalink
Lussac v2.0 b6
Browse files Browse the repository at this point in the history
Merge pull request #15 from BarbourLab/dev
  • Loading branch information
DradeAW authored Jun 10, 2024
2 parents dad5261 + e8b75c5 commit 02eb5c1
Show file tree
Hide file tree
Showing 40 changed files with 884 additions and 406 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ conda create -n lussac python=3.11 # Must be >= 3.10
conda activate lussac

# Install Lussac.
pip install -e .[dev]
pip install -e ".[dev]"

# To upgrade Lussac.
git pull
Expand All @@ -54,4 +54,4 @@ You can find the documentation [here](https://lussac.readthedocs.io/).

## Migration from Lussac1

Lussac2 is not backwards-compatible with Lussac1. We advise you to make a new conda environment, and to remake your `params.json` file (which is also not backwards-compatible).
Lussac2 is not backwards-compatible with Lussac1. We advise you to make a new conda environment, and to remake your `params.json` file (which is also not backwards-compatible).
19 changes: 10 additions & 9 deletions docs/source/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ This module will label units as belonging to a certain category if they meet som

This module takes as a key the name of the category, and as a value a dictionary containing the criteria. Each criterion return a value for each unit, and a minimum and/or maximum can be set.

For some parameters, you can specify the parameters for :code:`wvf_extraction` (parameters given to the SpikeInterface method :code:`core.extract_waveforms`) and :code:`filter` (a list :code:`[freq_min, freq_max]` for Gaussian bandpass filtering).
You can also specify the parameters for :code:`wvf_extraction` (`max_spikes_per_unit`, `ms_before`, `ms_after`, `filter`).

- :code:`firing_rate`: returns the mean firing rate of the unit (in Hz).
- :code:`contamination`: returns the estimated contamination of the unit (between 0 and 1; 0 being pure). The :code:`refractory_period = [censored_period, refractory_period]` has to be set (in ms).
Expand All @@ -32,6 +32,12 @@ Here is an example for categorizing complex-spikes from more regular spikes (cer
"units_categorization": {
"all": { // Categorize all units.
"wvf_extraction": { // Parameters for the waveform extraction.
"ms_before": 1.0,
"ms_after": 1.5,
"max_spikes_per_unit": 500,
"filter": [150.0, 7000.0] // Gaussian bandpass filter with cutoffs at 150 and 7,000 Hz.
},
"CS": { // Criteria for complex-spikes category.
"firing_rate": { // Firing rate < 5 Hz
"max": 5.0
Expand All @@ -52,13 +58,7 @@ Here is an example for categorizing complex-spikes from more regular spikes (cer
},
"SNR": {
"peak_sign": "neg", // Example of an SI parameter.
"min": 2.5,
"wvf_extraction": { // Specifying how to extract the waveforms for SNR computing.
"ms_before": 1.0,
"ms_after": 1.0,
"max_spikes_per_unit": 500
},
"filter": [150, 7_000] // Gaussian bandpass filter with cutoffs at 150 and 7,000 Hz.
"min": 2.5
}
}
}
Expand Down Expand Up @@ -127,6 +127,7 @@ Example of units removal
"remove_bad_units": {
"CS": { // Remove complex-spike units with contamination > 35%
"wvf_extraction": {}, // If you want to change how the waveforms are extracted.
"contamination": {
"refractory_period": [1.5, 25.0],
"max": 0.35
Expand Down Expand Up @@ -303,7 +304,7 @@ The :code:`export_to_sigui` module
----------------------------------
| This module will export all sortings in their current state to the SpikeInterface GUI format (if :code:`merge_sortings` was called before, will only export the merged sorting).
| This is equivalent to just a :code:`WaveformExtractor` with some extra arguments.
| This is equivalent to just a :code:`SortingAnalyzer` with some extra arguments.
This module's parameters are:
Expand Down
11 changes: 9 additions & 2 deletions docs/source/params.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ There are multiple keys:

- :code:`recording_extractor`: the name of the recording extractor in SpikeInterface.
- :code:`extractor_params`: the parameters for the recording extractor.
- :code:`probe_file`: the path to the probe file (in the `ProbeInterface <https://github.com/SpikeInterface/probeinterface>`_ format) if not already in the recording (optional).
- :code:`probe_file`: (optional) the path to the probe file (in the `ProbeInterface <https://github.com/SpikeInterface/probeinterface>`_ format) if not already in the recording.
- :code:`preprocessing`: (optional) a :code:`dict` mapping a function in `spikeinterface.preprocessing` to a :code:`dict` containing the arguments for that function. "remove_bad_channels" is also added (combining "detect_bad_channels" and "remove_channels")

Example for a binary file:
""""""""""""""""""""""""""
Expand All @@ -95,6 +96,8 @@ Example for a binary file:
"offset_to_uV": 0.0
},
"probe_file": "$PARAMS_FOLDER/probe.json"
// No preprocessing
}
Example for a SpikeGLX recording:
"""""""""""""""""""""""""""""""""
Expand All @@ -106,8 +109,12 @@ Example for a SpikeGLX recording:
"extractor_params": {
"folder_path": "$PARAMS_FOLDER/recording",
"stream_id": "imec0.ap"
}
},
// Probe is already loaded with the SpikeGLXRecordingExtractor.
"preprocessing": {
"phase_shift": {},
"remove_bad_channels": {}
}
}
Creating the probe file for geometry:
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["pip>=21.3", "setuptools>=65.6.3"]

[project]
name = "lussac"
version = "2.0.0b5.post1"
version = "2.0.0b6"
authors = [
{name="Aurélien Wyngaard", email="aurelien.wyngaard@gmail.com"}
]
Expand All @@ -23,7 +23,8 @@ dependencies = [
"tqdm >= 4.64.0",
"requests >= 2.28.0",
"overrides >= 7.3.1",
"spikeinterface >= 0.100.0, < 0.101.0"
"psutil",
"spikeinterface >= 0.101.0rc0"
]

[project.scripts]
Expand Down
1 change: 1 addition & 0 deletions src/lussac/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .lussac_data import LussacData, MonoSortingData, MultiSortingsData
from .lussac_params import LussacParams
from .module import LussacModule, MonoSortingModule, MultiSortingsModule
from .module_factory import ModuleFactory
from .pipeline import LussacPipeline
Expand Down
36 changes: 29 additions & 7 deletions src/lussac/core/lussac_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import spikeinterface.core as si
import spikeinterface.curation as scur
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre


class LussacData:
Expand Down Expand Up @@ -51,6 +52,9 @@ def __init__(self, recording: si.BaseRecording, sortings: dict[str, si.BaseSorti

self.recording = recording
self.sortings = {name: scur.remove_excess_spikes(sorting.remove_empty_units(), recording) for name, sorting in sortings.items()}
for name, sorting in self.sortings.items():
sorting.annotate(name=name)

params['lussac']['pipeline'] = self._format_params(params['lussac']['pipeline'])
self.params = params
self._tmp_directory = self._setup_tmp_directory(params['lussac']['tmp_folder'])
Expand Down Expand Up @@ -146,9 +150,10 @@ def _sanity_check(self) -> None:

# Check that spike trains are valid.
spike_vector = sorting.to_spike_vector()
assert spike_vector['sample_index'][0] >= 0
assert spike_vector['sample_index'][-1] < self.recording.get_num_frames()
assert np.all(np.diff(spike_vector['sample_index']) >= 0)
if len(spike_vector) > 0:
assert spike_vector['sample_index'][0] >= 0
assert spike_vector['sample_index'][-1] < self.recording.get_num_frames()
assert np.all(np.diff(spike_vector['sample_index']) >= 0)

@staticmethod
def _setup_probe(recording: si.BaseRecording, filename: str) -> si.BaseRecording:
Expand All @@ -170,6 +175,7 @@ def _setup_probe(recording: si.BaseRecording, filename: str) -> si.BaseRecording
def _load_recording(params: dict) -> si.BaseRecording:
"""
Loads the recording from the given parameters.
If specified, will apply some pre-processing steps.
@param params: dict
A dictionary containing Lussac's recording parameters.
Expand All @@ -178,7 +184,26 @@ def _load_recording(params: dict) -> si.BaseRecording:
"""

recording_extractor = se.extractorlist.get_recording_extractor_from_name(params['recording_extractor'])
return recording_extractor(**params['extractor_params'])
recording = recording_extractor(**params['extractor_params'])

if 'probe_file' in params:
recording = LussacData._setup_probe(recording, str(pathlib.Path(params['probe_file']).absolute()))

if 'preprocessing' in params and isinstance(params['preprocessing'], dict):
for preprocess_func, arguments in params['preprocessing'].items():
if preprocess_func == "cache":
continue
elif preprocess_func == "remove_bad_channels":
bad_channel_ids, channel_labels = spre.detect_bad_channels(recording, **arguments)
recording = recording.remove_channels(bad_channel_ids)
else:
function = getattr(spre, preprocess_func)
recording = function(recording, **arguments)

if 'cache' in params['preprocessing']:
recording = recording.save(folder=params['preprocessing']['cache'])

return recording

@staticmethod
def _load_sortings(sortings_path: dict[str, str]) -> dict[str, si.BaseSorting]:
Expand All @@ -204,7 +229,6 @@ def _load_sortings(sortings_path: dict[str, str]) -> dict[str, si.BaseSorting]:
sorting = si.load_extractor(path, base_folder=True)
assert isinstance(sorting, si.BaseSorting)

sorting.annotate(name=name)
sortings[name] = sorting

return sortings
Expand Down Expand Up @@ -299,8 +323,6 @@ def create_from_params(params: dict[str, dict]) -> 'LussacData':
"""

recording = LussacData._load_recording(params['recording'])
if 'probe_file' in params['recording']:
recording = LussacData._setup_probe(recording, str(pathlib.Path(params['recording']['probe_file']).absolute()))
sortings = LussacData._load_sortings(params['analyses'] if 'analyses' in params else {})

return LussacData(recording, sortings, params)
Expand Down
75 changes: 75 additions & 0 deletions src/lussac/core/lussac_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import json
import pathlib
import platform

import jsmin

import lussac


class LussacParams:

@staticmethod
def load_from_string(params: str, params_folder: pathlib.Path | str | None = None):
"""
Loads the parameters from a string and returns them as a dict.
@param params: str
Lussac's parameters.
@param params_folder: str
Path to replace the "$PARAMS_FOLDER".
"""

if params_folder is not None:
params_folder = str(pathlib.Path(params_folder).absolute())
params = params.replace("$PARAMS_FOLDER", params_folder)
if platform.system() == "Windows": # pragma: no cover (OS specific).
params = params.replace("\\", "\\\\")

return json.loads(params)

@staticmethod
def load_from_json_file(filename: str, params_folder: pathlib.Path | str | None = None) -> dict:
"""
Loads the JSON parameters file and returns its content as a dict.
@param filename: str
Path to the file containing Lussac's parameters.
@param params_folder: Path | str | None
Path to replace the "$PARAMS_FOLDER".
If None (default), will use the parent folder of the filename.
@return params: dict
Lussac's parameters.
"""

if params_folder is None:
params_folder = str(pathlib.Path(filename).parent.absolute())
else:
params_folder = str(pathlib.Path(params_folder).absolute())

with open(filename) as json_file:
minified = jsmin.jsmin(json_file.read()) # Parses out comments.
return LussacParams.load_from_string(minified, params_folder)

@staticmethod
def load_default_params(name: str, folder: pathlib.Path | str) -> dict:
"""
Loads the default parameters from the "params_example" folder.
@param name: str
The name of the default params file to load.
@param folder: str
Path to the folder where to create the "lussac" folder.
@return params: dict
The default parameters.
"""

if not name.startswith("params_"):
name = f"params_{name}"
if not name.endswith(".json"):
name = f"{name}.json"

params_folder = pathlib.Path(lussac.__file__).parent / "params_examples"
file = params_folder / name

return LussacParams.load_from_json_file(str(file), folder)
Loading

0 comments on commit 02eb5c1

Please sign in to comment.