Merge pull request #183 from jpgill86/elephant-tools
Rename elephant functions
jpgill86 authored Jan 21, 2020
2 parents 9a9897b + 20c9d32 commit 6329e5d
Showing 4 changed files with 136 additions and 128 deletions.
247 changes: 128 additions & 119 deletions neurotic/elephant_tools.py → neurotic/_elephant_tools.py
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
The :mod:`neurotic.elephant_tools` module contains functions copied from the
The :mod:`neurotic._elephant_tools` module contains functions copied from the
elephant package, which are included for convenience and to eliminate
dependency on that package.
This module and the functions it contains are not intended to be part of
neurotic's public API, so all function names begin with underscores. This
neurotic's public API, so the module name begins with an underscore. This
module may be removed at a future date.
elephant is licensed under BSD-3-Clause:
@@ -43,8 +43,11 @@
from neo import SpikeTrain


def _butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
filter_function='filtfilt', fs=1.0, axis=-1):
###############################################################################
# elephant.signal_processing

def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
filter_function='filtfilt', fs=1.0, axis=-1):
"""
Butterworth filtering function for neo.AnalogSignal. Filter type is
determined according to how values of `highpass_freq` and `lowpass_freq`
@@ -164,121 +167,7 @@ def _butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
else:
return filtered_data
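
For orientation, here is a minimal sketch of how the renamed function could be called after this commit; the signal, cutoff frequencies, and sampling rate below are invented for illustration and are not part of the diff.

import numpy as np
import quantities as pq
import neo
from neurotic import _elephant_tools

# a hypothetical 1-second signal sampled at 1 kHz
sig = neo.AnalogSignal(np.random.randn(1000, 1), units='mV',
                       sampling_rate=1000*pq.Hz)

# both cutoffs given, so a band-pass Butterworth filter is applied
filtered = _elephant_tools.butter(sig, highpass_freq=10*pq.Hz,
                                  lowpass_freq=200*pq.Hz)

Passing only highpass_freq or only lowpass_freq would instead select a high-pass or low-pass filter, as the docstring above describes.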

def _isi(spiketrain, axis=-1):
"""
Return an array containing the inter-spike intervals of the SpikeTrain.
Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array.
If either a SpikeTrain or Quantity array is provided, the return value will
be a quantities array, otherwise a plain NumPy array. The units of
the quantities array will be the same as spiketrain.
Parameters
----------
spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray
The spike times.
axis : int, optional
The axis along which the difference is taken.
Default is the last axis.
Returns
-------
NumPy array or quantities array.
"""
if axis is None:
axis = -1
if isinstance(spiketrain, neo.SpikeTrain):
intervals = np.diff(
np.sort(spiketrain.times.view(pq.Quantity)), axis=axis)
else:
intervals = np.diff(np.sort(spiketrain), axis=axis)
return intervals

def _peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : neo AnalogSignal object
'signal' is an analog signal.
threshold : A quantity, e.g. in mV
'threshold' contains a value that must be reached
for an event to be detected.
sign : 'above' or 'below'
'sign' determines whether to count threshold crossings that
cross above or below the threshold. Default: 'above'.
format : None or 'raw'
Whether to return as SpikeTrain (None) or as a plain array
of times ('raw'). Default: None.
Returns
-------
result_st : neo SpikeTrain object
'result_st' contains the spike times of each of the events
(spikes) extracted from the signal.
"""
assert threshold is not None, "A threshold must be provided"

if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign must be 'above' or 'below'")

if len(cutout) <= 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
borders = np.concatenate((border_start, border_end))
borders = np.append(0, borders)
borders = np.append(borders, len(cutout)-1)
borders = np.sort(borders)
true_borders = cutout[borders]
right_borders = true_borders[1::2] + 1
true_borders = np.sort(np.append(true_borders[0::2], right_borders))

# Workaround for bug that occurs when signal goes below thr for 1 dtp,
# Workaround eliminates empty slices from np.split
backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
to_begin=1)[::-1]) > 0
true_borders = true_borders[backward_mask * forward_mask]
split_signal = np.split(np.array(signal), true_borders)[1::2]

maxima_idc_split = np.array([peak_func(x) for x in split_signal])

max_idc = maxima_idc_split + true_borders[0::2]

events = signal.times[max_idc]
events_base = events.magnitude

if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
elif format == 'raw':
result_st = events_base
else:
raise ValueError("Format argument must be None or 'raw'")

return result_st

def _rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
'''
Calculate the rectified area under the curve (RAUC) for an AnalogSignal.
@@ -395,3 +284,123 @@ def _rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
t_start=signal.t_start.rescale(bin_duration.units)+bin_duration/2,
sampling_period=bin_duration)
return rauc_sig
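
A minimal sketch of how the renamed rauc might be called; the baseline choice and bin duration are arbitrary assumptions for the example, and the 'mean' option is taken from elephant's rauc rather than shown in this hunk.

import numpy as np
import quantities as pq
import neo
from neurotic import _elephant_tools

sig = neo.AnalogSignal(np.random.randn(1000, 1), units='mV',
                       sampling_rate=1000*pq.Hz)

# rectified area under the curve in 0.1-second bins, after subtracting the signal mean
rauc_sig = _elephant_tools.rauc(sig, baseline='mean', bin_duration=0.1*pq.s)

With bin_duration given, the result is itself an AnalogSignal sampled once per bin, as the tail of the function above shows.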

###############################################################################
# elephant.spike_train_generation

def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : neo AnalogSignal object
'signal' is an analog signal.
threshold : A quantity, e.g. in mV
'threshold' contains a value that must be reached
for an event to be detected.
sign : 'above' or 'below'
'sign' determines whether to count threshold crossings that
cross above or below the threshold. Default: 'above'.
format : None or 'raw'
Whether to return as SpikeTrain (None) or as a plain array
of times ('raw'). Default: None.
Returns
-------
result_st : neo SpikeTrain object
'result_st' contains the spike times of each of the events
(spikes) extracted from the signal.
"""
assert threshold is not None, "A threshold must be provided"

if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign must be 'above' or 'below'")

if len(cutout) <= 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
borders = np.concatenate((border_start, border_end))
borders = np.append(0, borders)
borders = np.append(borders, len(cutout)-1)
borders = np.sort(borders)
true_borders = cutout[borders]
right_borders = true_borders[1::2] + 1
true_borders = np.sort(np.append(true_borders[0::2], right_borders))

# Workaround for bug that occurs when signal goes below thr for 1 dtp,
# Workaround eliminates empty slices from np.split
backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
to_begin=1)[::-1]) > 0
true_borders = true_borders[backward_mask * forward_mask]
split_signal = np.split(np.array(signal), true_borders)[1::2]

maxima_idc_split = np.array([peak_func(x) for x in split_signal])

max_idc = maxima_idc_split + true_borders[0::2]

events = signal.times[max_idc]
events_base = events.magnitude

if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
elif format == 'raw':
result_st = events_base
else:
raise ValueError("Format argument must be None or 'raw'")

return result_st
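
A minimal sketch of calling the renamed peak_detection; the signal and the 1 mV threshold are made up for illustration.

import numpy as np
import quantities as pq
import neo
from neurotic import _elephant_tools

sig = neo.AnalogSignal(np.random.randn(1000, 1), units='mV',
                       sampling_rate=1000*pq.Hz)

# peaks of all excursions above +1 mV, returned as a neo.SpikeTrain
st = _elephant_tools.peak_detection(sig, threshold=1.0*pq.mV, sign='above')

# the same peak times as a plain array of magnitudes
raw_times = _elephant_tools.peak_detection(sig, threshold=1.0*pq.mV,
                                           sign='above', format='raw')

The format='raw' form is what neurotic's _detect_spikes passes further down in this commit.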

###############################################################################
# elephant.statistics

def isi(spiketrain, axis=-1):
"""
Return an array containing the inter-spike intervals of the SpikeTrain.
Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array.
If either a SpikeTrain or Quantity array is provided, the return value will
be a quantities array, otherwise a plain NumPy array. The units of
the quantities array will be the same as spiketrain.
Parameters
----------
spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray
The spike times.
axis : int, optional
The axis along which the difference is taken.
Default is the last axis.
Returns
-------
NumPy array or quantities array.
"""
if axis is None:
axis = -1
if isinstance(spiketrain, neo.SpikeTrain):
intervals = np.diff(
np.sort(spiketrain.times.view(pq.Quantity)), axis=axis)
else:
intervals = np.diff(np.sort(spiketrain), axis=axis)
return intervals
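
A minimal sketch of calling the renamed isi; the spike times below are invented for the example.

import quantities as pq
import neo
from neurotic import _elephant_tools

st = neo.SpikeTrain([0.1, 0.25, 0.3, 0.9]*pq.s, t_stop=1.0*pq.s)

# inter-spike intervals, returned with the same units as the spike train
intervals = _elephant_tools.isi(st)   # -> 0.15, 0.05, 0.6 s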
10 changes: 5 additions & 5 deletions neurotic/datasets/data.py
@@ -14,7 +14,7 @@
import neo

from ..datasets.metadata import _abs_path
from ..elephant_tools import _butter, _isi, _peak_detection
from .. import _elephant_tools

import logging
logger = logging.getLogger(__name__)
@@ -500,7 +500,7 @@ def _apply_filters(metadata, blk):
high *= pq.Hz
if low:
low *= pq.Hz
blk.segments[0].analogsignals[index] = _butter(
blk.segments[0].analogsignals[index] = _elephant_tools.butter(
signal = blk.segments[0].analogsignals[index],
highpass_freq = high,
lowpass_freq = low,
@@ -555,8 +555,8 @@ def _detect_spikes(sig, discriminator, epochs):
else:
raise ValueError('amplitude discriminator must have two nonnegative thresholds or two nonpositive thresholds: {}'.format(discriminator))

spikes_crossing_min = _peak_detection(sig, pq.Quantity(min_threshold, discriminator['units']), sign, 'raw')
spikes_crossing_max = _peak_detection(sig, pq.Quantity(max_threshold, discriminator['units']), sign, 'raw')
spikes_crossing_min = _elephant_tools.peak_detection(sig, pq.Quantity(min_threshold, discriminator['units']), sign, 'raw')
spikes_crossing_max = _elephant_tools.peak_detection(sig, pq.Quantity(max_threshold, discriminator['units']), sign, 'raw')
if sign == 'above':
spikes_between_min_and_max = np.setdiff1d(spikes_crossing_min, spikes_crossing_max)
elif sign == 'below':
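
A minimal sketch of the amplitude-window idea used in this hunk, with made-up thresholds (0.5 mV and 2 mV) and an 'above' discriminator; the signal is invented for illustration.

import numpy as np
import quantities as pq
import neo
from neurotic import _elephant_tools

sig = neo.AnalogSignal(np.random.randn(1000, 1), units='mV',
                       sampling_rate=1000*pq.Hz)

# peak times crossing the lower and upper thresholds, as raw arrays
spikes_min = _elephant_tools.peak_detection(sig, 0.5*pq.mV, 'above', 'raw')
spikes_max = _elephant_tools.peak_detection(sig, 2.0*pq.mV, 'above', 'raw')

# keep only peaks that exceed the lower threshold but stay below the upper one
spikes_in_window = np.setdiff1d(spikes_min, spikes_max)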
@@ -644,7 +644,7 @@ def _find_bursts(st, start_freq, stop_freq):
``start_freq``, since otherwise bursts may not be detected.
"""

isi = _isi(st).rescale('s')
isi = _elephant_tools.isi(st).rescale('s')
iff = 1/isi

start_mask = iff > start_freq
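
A minimal sketch of the instantaneous-frequency computation shown in this hunk; the spike times and the 10 Hz / 5 Hz thresholds are arbitrary stand-ins for start_freq and stop_freq, and the rest of _find_bursts's burst-pairing logic is not reproduced here.

import quantities as pq
import neo
from neurotic import _elephant_tools

st = neo.SpikeTrain([0.10, 0.12, 0.14, 0.16, 0.80, 0.82]*pq.s, t_stop=1.0*pq.s)

# instantaneous firing frequency between consecutive spikes
intervals = _elephant_tools.isi(st).rescale('s')
iff = 1 / intervals

# candidate burst onsets and offsets relative to hypothetical thresholds
start_mask = iff > 10 * pq.Hz
stop_mask = iff < 5 * pq.Hz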
2 changes: 1 addition & 1 deletion neurotic/example/example-notebook.ipynb
@@ -155,7 +155,7 @@
"rauc_sigs = []\n",
"if not lazy:\n",
" for sig in blk.segments[0].analogsignals:\n",
" rauc = neurotic.elephant_tools._rauc(sig, baseline=metadata['rauc_baseline'], bin_duration=metadata['rauc_bin_duration']*pq.s)\n",
" rauc = neurotic._elephant_tools.rauc(sig, baseline=metadata['rauc_baseline'], bin_duration=metadata['rauc_bin_duration']*pq.s)\n",
" rauc.name = sig.name + ' RAUC'\n",
" rauc_sigs.append(rauc)\n",
"\n",
5 changes: 2 additions & 3 deletions neurotic/gui/standalone.py
@@ -15,10 +15,9 @@
import quantities as pq
from ephyviewer import QT, QT_MODE

from .. import __version__, default_log_level, log_file
from .. import __version__, _elephant_tools, default_log_level, log_file
from ..datasets import MetadataSelector, load_dataset
from ..datasets.metadata import _selector_labels
from ..elephant_tools import _rauc
from ..gui.config import EphyviewerConfigurator

import logging
@@ -285,7 +284,7 @@ def launch(self):
rauc_sigs = []
if not self.lazy:
for sig in blk.segments[0].analogsignals:
rauc = _rauc(sig, baseline=metadata['rauc_baseline'], bin_duration=metadata['rauc_bin_duration']*pq.s)
rauc = _elephant_tools.rauc(sig, baseline=metadata['rauc_baseline'], bin_duration=metadata['rauc_bin_duration']*pq.s)
rauc.name = sig.name + ' RAUC'
rauc_sigs.append(rauc)

