From 20c9d32cf77a79781baf8a9e8b026f58132d8486 Mon Sep 17 00:00:00 2001 From: Jeffrey Gill Date: Mon, 20 Jan 2020 19:35:30 -0500 Subject: [PATCH] Reordered elephant functions by their original modules --- neurotic/_elephant_tools.py | 237 +++++++++++++++++++----------------- 1 file changed, 123 insertions(+), 114 deletions(-) diff --git a/neurotic/_elephant_tools.py b/neurotic/_elephant_tools.py index a412ffe..46aeede 100644 --- a/neurotic/_elephant_tools.py +++ b/neurotic/_elephant_tools.py @@ -43,6 +43,9 @@ from neo import SpikeTrain +############################################################################### +# elephant.signal_processing + def butter(signal, highpass_freq=None, lowpass_freq=None, order=4, filter_function='filtfilt', fs=1.0, axis=-1): """ @@ -164,120 +167,6 @@ def butter(signal, highpass_freq=None, lowpass_freq=None, order=4, else: return filtered_data -def isi(spiketrain, axis=-1): - """ - Return an array containing the inter-spike intervals of the SpikeTrain. - - Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array. - If either a SpikeTrain or Quantity array is provided, the return value will - be a quantities array, otherwise a plain NumPy array. The units of - the quantities array will be the same as spiketrain. - - Parameters - ---------- - - spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray - The spike times. - axis : int, optional - The axis along which the difference is taken. - Default is the last axis. - - Returns - ------- - - NumPy array or quantities array. - - """ - if axis is None: - axis = -1 - if isinstance(spiketrain, neo.SpikeTrain): - intervals = np.diff( - np.sort(spiketrain.times.view(pq.Quantity)), axis=axis) - else: - intervals = np.diff(np.sort(spiketrain), axis=axis) - return intervals - -def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None): - """ - Return the peak times for all events that cross threshold. - Usually used for extracting spike times from a membrane potential. - Similar to spike_train_generation.threshold_detection. - - Parameters - ---------- - signal : neo AnalogSignal object - 'signal' is an analog signal. - threshold : A quantity, e.g. in mV - 'threshold' contains a value that must be reached - for an event to be detected. - sign : 'above' or 'below' - 'sign' determines whether to count thresholding crossings that - cross above or below the threshold. Default: 'above'. - format : None or 'raw' - Whether to return as SpikeTrain (None) or as a plain array - of times ('raw'). Default: None. - - Returns - ------- - result_st : neo SpikeTrain object - 'result_st' contains the spike times of each of the events - (spikes) extracted from the signal. - """ - assert threshold is not None, "A threshold must be provided" - - if sign == 'above': - cutout = np.where(signal > threshold)[0] - peak_func = np.argmax - elif sign == 'below': - cutout = np.where(signal < threshold)[0] - peak_func = np.argmin - else: - raise ValueError("sign must be 'above' or 'below'") - - if len(cutout) <= 0: - events_base = np.zeros(0) - else: - # Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2 - # This avoids empty slices - border_start = np.where(np.diff(cutout) > 1)[0] - border_end = border_start + 1 - borders = np.concatenate((border_start, border_end)) - borders = np.append(0, borders) - borders = np.append(borders, len(cutout)-1) - borders = np.sort(borders) - true_borders = cutout[borders] - right_borders = true_borders[1::2] + 1 - true_borders = np.sort(np.append(true_borders[0::2], right_borders)) - - # Workaround for bug that occurs when signal goes below thr for 1 dtp, - # Workaround eliminates empy slices from np. split - backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0 - forward_mask = np.absolute(np.ediff1d(true_borders[::-1], - to_begin=1)[::-1]) > 0 - true_borders = true_borders[backward_mask * forward_mask] - split_signal = np.split(np.array(signal), true_borders)[1::2] - - maxima_idc_split = np.array([peak_func(x) for x in split_signal]) - - max_idc = maxima_idc_split + true_borders[0::2] - - events = signal.times[max_idc] - events_base = events.magnitude - - if events_base is None: - # This occurs in some Python 3 builds due to some - # bug in quantities. - events_base = np.array([event.magnitude for event in events]) # Workaround - if format is None: - result_st = SpikeTrain(events_base, units=signal.times.units, - t_start=signal.t_start, t_stop=signal.t_stop) - elif 'raw': - result_st = events_base - else: - raise ValueError("Format argument must be None or 'raw'") - - return result_st - def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): ''' Calculate the rectified area under the curve (RAUC) for an AnalogSignal. @@ -395,3 +284,123 @@ def rauc(signal, baseline=None, bin_duration=None, t_start=None, t_stop=None): t_start=signal.t_start.rescale(bin_duration.units)+bin_duration/2, sampling_period=bin_duration) return rauc_sig + +############################################################################### +# elephant.spike_train_generation + +def peak_detection(signal, threshold=0.0 * mV, sign='above', format=None): + """ + Return the peak times for all events that cross threshold. + Usually used for extracting spike times from a membrane potential. + Similar to spike_train_generation.threshold_detection. + + Parameters + ---------- + signal : neo AnalogSignal object + 'signal' is an analog signal. + threshold : A quantity, e.g. in mV + 'threshold' contains a value that must be reached + for an event to be detected. + sign : 'above' or 'below' + 'sign' determines whether to count thresholding crossings that + cross above or below the threshold. Default: 'above'. + format : None or 'raw' + Whether to return as SpikeTrain (None) or as a plain array + of times ('raw'). Default: None. + + Returns + ------- + result_st : neo SpikeTrain object + 'result_st' contains the spike times of each of the events + (spikes) extracted from the signal. + """ + assert threshold is not None, "A threshold must be provided" + + if sign == 'above': + cutout = np.where(signal > threshold)[0] + peak_func = np.argmax + elif sign == 'below': + cutout = np.where(signal < threshold)[0] + peak_func = np.argmin + else: + raise ValueError("sign must be 'above' or 'below'") + + if len(cutout) <= 0: + events_base = np.zeros(0) + else: + # Select thr crossings lasting at least 2 dtps, np.diff(cutout) > 2 + # This avoids empty slices + border_start = np.where(np.diff(cutout) > 1)[0] + border_end = border_start + 1 + borders = np.concatenate((border_start, border_end)) + borders = np.append(0, borders) + borders = np.append(borders, len(cutout)-1) + borders = np.sort(borders) + true_borders = cutout[borders] + right_borders = true_borders[1::2] + 1 + true_borders = np.sort(np.append(true_borders[0::2], right_borders)) + + # Workaround for bug that occurs when signal goes below thr for 1 dtp, + # Workaround eliminates empy slices from np. split + backward_mask = np.absolute(np.ediff1d(true_borders, to_begin=1)) > 0 + forward_mask = np.absolute(np.ediff1d(true_borders[::-1], + to_begin=1)[::-1]) > 0 + true_borders = true_borders[backward_mask * forward_mask] + split_signal = np.split(np.array(signal), true_borders)[1::2] + + maxima_idc_split = np.array([peak_func(x) for x in split_signal]) + + max_idc = maxima_idc_split + true_borders[0::2] + + events = signal.times[max_idc] + events_base = events.magnitude + + if events_base is None: + # This occurs in some Python 3 builds due to some + # bug in quantities. + events_base = np.array([event.magnitude for event in events]) # Workaround + if format is None: + result_st = SpikeTrain(events_base, units=signal.times.units, + t_start=signal.t_start, t_stop=signal.t_stop) + elif 'raw': + result_st = events_base + else: + raise ValueError("Format argument must be None or 'raw'") + + return result_st + +############################################################################### +# elephant.statistics + +def isi(spiketrain, axis=-1): + """ + Return an array containing the inter-spike intervals of the SpikeTrain. + + Accepts a Neo SpikeTrain, a Quantity array, or a plain NumPy array. + If either a SpikeTrain or Quantity array is provided, the return value will + be a quantities array, otherwise a plain NumPy array. The units of + the quantities array will be the same as spiketrain. + + Parameters + ---------- + + spiketrain : Neo SpikeTrain or Quantity array or NumPy ndarray + The spike times. + axis : int, optional + The axis along which the difference is taken. + Default is the last axis. + + Returns + ------- + + NumPy array or quantities array. + + """ + if axis is None: + axis = -1 + if isinstance(spiketrain, neo.SpikeTrain): + intervals = np.diff( + np.sort(spiketrain.times.view(pq.Quantity)), axis=axis) + else: + intervals = np.diff(np.sort(spiketrain), axis=axis) + return intervals