Skip to content

Commit

Permalink
Reordered elephant functions by their original modules
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgill86 committed Jan 21, 2020
1 parent 558a88a commit 20c9d32
Showing 1 changed file with 123 additions and 114 deletions.
237 changes: 123 additions & 114 deletions neurotic/_elephant_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from neo import SpikeTrain


###############################################################################
# elephant.signal_processing

def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
filter_function='filtfilt', fs=1.0, axis=-1):
"""
Expand Down Expand Up @@ -164,120 +167,6 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4,
else:
return filtered_data

def isi(spiketrain, axis=-1):
"""
Return an array containing the inter-spike intervals of the SpikeTrain.
Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array.
If either a SpikeTrain or Quantity array is provided, the return value will
be a quantities array, otherwise a plain NumPy array. The units of
the quantities array will be the same as spiketrain.
Parameters
----------
spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray
The spike times.
axis : int, optional
The axis along which the difference is taken.
Default is the last axis.
Returns
-------
NumPy array or quantities array.
"""
if axis is None:
axis = -1
if isinstance(spiketrain, neo.SpikeTrain):
intervals = np.diff(
np.sort(spiketrain.times.view(pq.Quantity)), axis=axis)
else:
intervals = np.diff(np.sort(spiketrain), axis=axis)
return intervals

def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : neo AnalogSignal object
'signal' is an analog signal.
threshold : A quantity, e.g. in mV
'threshold' contains a value that must be reached
for an event to be detected.
sign : 'above' or 'below'
'sign' determines whether to count thresholding crossings that
cross above or below the threshold. Default: 'above'.
format : None or 'raw'
Whether to return as SpikeTrain (None) or as a plain array
of times ('raw'). Default: None.
Returns
-------
result_st : neo SpikeTrain object
'result_st' contains the spike times of each of the events
(spikes) extracted from the signal.
"""
assert threshold is not None, "A threshold must be provided"

if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign must be 'above' or 'below'")

if len(cutout) <= 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
borders = np.concatenate((border_start, border_end))
borders = np.append(0, borders)
borders = np.append(borders, len(cutout)-1)
borders = np.sort(borders)
true_borders = cutout[borders]
right_borders = true_borders[1::2] + 1
true_borders = np.sort(np.append(true_borders[0::2], right_borders))

# Workaround for bug that occurs when signal goes below thr for 1 dtp,
# Workaround eliminates empy slices from np. split
backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
to_begin=1)[::-1]) > 0
true_borders = true_borders[backward_mask * forward_mask]
split_signal = np.split(np.array(signal), true_borders)[1::2]

maxima_idc_split = np.array([peak_func(x) for x in split_signal])

max_idc = maxima_idc_split + true_borders[0::2]

events = signal.times[max_idc]
events_base = events.magnitude

if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
elif 'raw':
result_st = events_base
else:
raise ValueError("Format argument must be None or 'raw'")

return result_st

def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
'''
Calculate the rectified area under the curve (RAUC) for an AnalogSignal.
Expand Down Expand Up @@ -395,3 +284,123 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None):
t_start=signal.t_start.rescale(bin_duration.units)+bin_duration/2,
sampling_period=bin_duration)
return rauc_sig

###############################################################################
# elephant.spike_train_generation

def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None):
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Similar to spike_train_generation.threshold_detection.
Parameters
----------
signal : neo AnalogSignal object
'signal' is an analog signal.
threshold : A quantity, e.g. in mV
'threshold' contains a value that must be reached
for an event to be detected.
sign : 'above' or 'below'
'sign' determines whether to count thresholding crossings that
cross above or below the threshold. Default: 'above'.
format : None or 'raw'
Whether to return as SpikeTrain (None) or as a plain array
of times ('raw'). Default: None.
Returns
-------
result_st : neo SpikeTrain object
'result_st' contains the spike times of each of the events
(spikes) extracted from the signal.
"""
assert threshold is not None, "A threshold must be provided"

if sign == 'above':
cutout = np.where(signal > threshold)[0]
peak_func = np.argmax
elif sign == 'below':
cutout = np.where(signal < threshold)[0]
peak_func = np.argmin
else:
raise ValueError("sign must be 'above' or 'below'")

if len(cutout) <= 0:
events_base = np.zeros(0)
else:
# Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2
# This avoids empty slices
border_start = np.where(np.diff(cutout) > 1)[0]
border_end = border_start + 1
borders = np.concatenate((border_start, border_end))
borders = np.append(0, borders)
borders = np.append(borders, len(cutout)-1)
borders = np.sort(borders)
true_borders = cutout[borders]
right_borders = true_borders[1::2] + 1
true_borders = np.sort(np.append(true_borders[0::2], right_borders))

# Workaround for bug that occurs when signal goes below thr for 1 dtp,
# Workaround eliminates empy slices from np. split
backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0
forward_mask = np.absolute(np.ediff1d(true_borders[::-1],
to_begin=1)[::-1]) > 0
true_borders = true_borders[backward_mask * forward_mask]
split_signal = np.split(np.array(signal), true_borders)[1::2]

maxima_idc_split = np.array([peak_func(x) for x in split_signal])

max_idc = maxima_idc_split + true_borders[0::2]

events = signal.times[max_idc]
events_base = events.magnitude

if events_base is None:
# This occurs in some Python 3 builds due to some
# bug in quantities.
events_base = np.array([event.magnitude for event in events]) # Workaround
if format is None:
result_st = SpikeTrain(events_base, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop)
elif 'raw':
result_st = events_base
else:
raise ValueError("Format argument must be None or 'raw'")

return result_st

###############################################################################
# elephant.statistics

def isi(spiketrain, axis=-1):
"""
Return an array containing the inter-spike intervals of the SpikeTrain.
Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array.
If either a SpikeTrain or Quantity array is provided, the return value will
be a quantities array, otherwise a plain NumPy array. The units of
the quantities array will be the same as spiketrain.
Parameters
----------
spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray
The spike times.
axis : int, optional
The axis along which the difference is taken.
Default is the last axis.
Returns
-------
NumPy array or quantities array.
"""
if axis is None:
axis = -1
if isinstance(spiketrain, neo.SpikeTrain):
intervals = np.diff(
np.sort(spiketrain.times.view(pq.Quantity)), axis=axis)
else:
intervals = np.diff(np.sort(spiketrain), axis=axis)
return intervals

0 comments on commit 20c9d32

Please sign in to comment.