forked from MouseLand/Kilosort
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. remove_ks2_duplicate_spikes - removes duplicate spikes resulting from multiple templates explaining variance from a single waveform
- Loading branch information
Showing
5 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
function rez = recompute_clusters(rez)
% RECOMPUTE_CLUSTERS Refresh template-derived fields of REZ.
%
%   rez = recompute_clusters(rez) re-runs the low-rank template
%   decomposition (mexSVDsmall2) on rez.dWU and rebuilds the fields that are
%   derived from it: rez.W, rez.U, rez.mu, rez.iList, rez.simScore,
%   rez.iNeigh, rez.iNeighPC and rez.Wphy. It is intended to be called after
%   a post-processing operation has modified REZ.

    ops = rez.ops;

    % sizes handed to the CUDA routines
    nChannels      = ops.Nchan;
    nNearChannels  = min(nChannels, 32);
    nNearTemplates = min(nChannels, 32);
    nTemplates     = size(rez.W, 2);   % current number of templates
    nRank          = 3;

    closestChannels = getClosestChannels(rez, ops.sigmaMask, nNearChannels);

    % parameter vector passed on to CUDA
    cudaParams = double([0 nTemplates 0 0 size(rez.W, 1) nNearTemplates ...
        nRank 0 0 nChannels nNearChannels ops.nt0min 0]);

    % time-upsampling kernels, needed to re-estimate the spatial profiles
    [Ka, Kb] = getKernels(ops, 10, 1);

    % channel with the largest absolute value at sample nt0min, per template
    peakSlice = abs(rez.dWU(ops.nt0min, :, :));
    [~, peakChannel] = max(peakSlice, [], 2);
    peakChannel = squeeze(int32(peakChannel));

    % SVD of the templates (indices are passed zero-based)
    [rez.W, rez.U, rez.mu] = mexSVDsmall2(cudaParams, rez.dWU, rez.W, ...
        closestChannels - 1, peakChannel - 1, Ka, Kb);

    % similarity scores between templates and the list of nearest templates
    [WtW, nearestTemplates] = getMeWtW(single(rez.W), single(rez.U), nNearTemplates);
    rez.iList    = nearestTemplates;
    rez.simScore = gather(max(WtW, [], 3));

    % neighbouring templates and neighbouring channels of every template
    rez.iNeigh   = gather(nearestTemplates(:, 1:nTemplates));
    rez.iNeighPC = gather(closestChannels(:, peakChannel(1:nTemplates)));

    % for Phy, the templates are zero-padded at the start so that the spikes
    % are aligned to the center of the window
    zeroPad  = zeros(1 + ops.nt0min, nTemplates, nRank);
    rez.Wphy = cat(1, zeroPad, rez.W);
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
function rez = remove_inactive_clusters(rez,varargin)
% REMOVE_INACTIVE_CLUSTERS Remove the spikes of clusters with too few spikes.
%
%   rez = remove_inactive_clusters(rez) removes every spike that belongs to
%   a cluster (column 2 of rez.st3, ids > 0 only) whose spike count is below
%   either of two limits:
%     * 'min_spikes' (name-value pair, default 50, nonnegative scalar)
%     * rez.ops.minFR * (rez.ops.sampsToRead / rez.ops.fs)
%
%   The ids of the removed clusters are stored in rez.inactive. The spikes
%   are moved with REMOVE_SPIKES under the label 'inactive'.

    parser = inputParser;
    parser.addParameter('min_spikes',50,@(x)validateattributes(x,{'numeric'},{'scalar','nonnegative'}));
    parser.parse(varargin{:});
    min_spikes = parser.Results.min_spikes;

    % spike count per cluster id, ignoring ids <= 0
    [cluster_ids,~,spike_cluster] = unique(rez.st3(:,2));
    spike_counts = accumarray(spike_cluster,1);
    is_positive_id = cluster_ids>0;
    spike_counts = spike_counts(is_positive_id);
    cluster_ids = cluster_ids(is_positive_id);

    % a cluster is inactive when it fails the count OR the rate criterion
    duration_sec = rez.ops.sampsToRead./rez.ops.fs;
    too_few_spikes = spike_counts<min_spikes;
    too_low_rate = spike_counts<rez.ops.minFR*duration_sec;
    is_inactive = too_few_spikes | too_low_rate;
    inactive_ids = cluster_ids(is_inactive);

    spike_mask = ismember(rez.st3(:,2),inactive_ids);
    rez.inactive = inactive_ids;

    rez = remove_spikes(rez,spike_mask,'inactive');
    % NOTE(review): remove_spikes already ends with recompute_clusters when it
    % removes anything, so this call repeats that work in that case.
    rez = recompute_clusters(rez);

    fprintf('Removed %g inactive clusters.\n',sum(is_inactive));

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
% REMOVE_KS2_DUPLICATE_SPIKES Double-counted spikes are hard to avoid with
% Kilosort's template matching algorithm since the overall fit can be
% improved by having multiple templates jointly account for a single variable waveform.
%
% This function takes the kilosort2 output rez and identifies pairs of
% spikes that are close together in time and space. The temporal threshold
% is given by the parameter OVERLAP_S, which is 5e-4 (0.5 ms) by default, and
% the spatial threshold (applied to the template primary sites) is given by
% CHANNEL_SEPARATION_UM and is 50 by default.
%
% Within each such pair, the spike whose template has the larger amplitude
% is treated as the "main" or "reference" spike and the other spike of the
% pair is removed as a duplicate.
%
% All spike pairs are considered, not just those from CCG-contaminated
% pairs, as in REMOVE_KS2_DUPLICATE_SPIKES2.
%
%=INPUT
%
%   rez structure
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   overlap_s
%       the time interval, in seconds, within which a sequence of spikes is
%       vetted for duplicates.
%
%   channel_separation_um
%       When the primary channels of two spikes are within this distance, in
%       microns, then the two spikes are vetted for duplicates.
%
%=OUTPUT
%
%   rez structure with the duplicate spikes moved to rez.removed (see
%   REMOVE_SPIKES) under the label 'duplicate', together with the time and
%   cluster of the reference spike of each removed spike. rez.U and rez.W
%   are returned gathered.
%
%=EXAMPLE
%
%   >> rez = remove_ks2_duplicate_spikes(rez)
function rez = remove_ks2_duplicate_spikes(rez, varargin)
    input_parser = inputParser;
    addParameter(input_parser, 'overlap_s', 5e-4, @(x) (isnumeric(x)))
    % this is a distance in microns, so it has to validate as numeric
    addParameter(input_parser, 'channel_separation_um', 50, @(x) (isnumeric(x)))
    parse(input_parser, varargin{:});
    P = input_parser.Results;

    spike_times = uint64(rez.st3(:,1));
    spike_templates = uint32(rez.st3(:,2));
    n_spikes = numel(spike_times);

    rez.U = gather(rez.U);
    rez.W = gather(rez.W);
    n_samples = size(rez.W,1);
    n_templates = size(rez.W,2);
    templates = zeros(rez.ops.Nchan, n_samples, n_templates, 'single');
    for iNN = 1:n_templates
        templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))';
    end
    templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels

    %% Make sure that the spike times are sorted
    % sort_order maps a position in the sorted vectors back to its row in
    % rez.st3, so that the final removal mask addresses the right rows.
    sort_order = (1:n_spikes)';
    if ~issorted(spike_times)
        [spike_times, sort_order] = sort(spike_times);
        spike_templates = spike_templates(sort_order);
    end
    %% deal with cluster 0
    if any(spike_templates==0)
        error('Currently this function can''t deal with existence of cluster 0. Should be OK since it ought to be run first in the post-processing.');
    end
    %% Determine the channel where each spike had that largest amplitude (i.e., the primary) and determine the template amplitude of each cluster
    whiteningMatrix = rez.Wrot/rez.ops.scaleproc;
    whiteningMatrixInv = whiteningMatrix^-1;

    % unwhiten all the templates
    tempsUnW = zeros(size(templates));
    for t = 1:size(templates,1)
        % reshape (rather than squeeze) keeps the nSamples x nChannels layout
        % even when one of the dimensions is 1
        oneTemplate = reshape(templates(t,:,:), size(templates,2), size(templates,3));
        tempsUnW(t,:,:) = oneTemplate*whiteningMatrixInv;
    end

    % The amplitude on each channel is the positive peak minus the negative.
    % permute keeps the result nTemplates x nChannels for a single template too.
    tempChanAmps = permute(max(tempsUnW,[],2) - min(tempsUnW,[],2), [1 3 2]);

    % The template amplitude is the amplitude of its largest channel
    % (one entry per TEMPLATE); template_primary is that channel's index.
    [tempAmpsUnscaled,template_primary] = max(tempChanAmps,[],2);

    template_primary = cast(template_primary, class(spike_templates));
    spike_primary = template_primary(spike_templates);

    %% Number of samples in the overlap
    n_samples_overlap = round(P.overlap_s * rez.ops.fs);
    n_samples_overlap = cast(n_samples_overlap, class(spike_times));
    %% Distance between each channel
    chan_dist = ((rez.xcoords - rez.xcoords').^2 + (rez.ycoords - rez.ycoords').^2).^0.5;

    n_duplicates = 1; % set to 1 to initialize while loop
    count = 0;
    remove_idx = [];
    reference_idx = [];
    spike_idx = (1:n_spikes)';
    % only check nearest temporal neighbors in the list of spikes times.
    % but go recursively until no nearest neighbors are left that are both within the overlap
    % period and sufficiently nearby.
    % this means only ever computing a vector operation (i.e. diff(spike_times))
    % rather than a matrix one (i.e. spike_times - spike_times').
    diff_order = 1;
    while true
        count = count+1;
        if n_duplicates==0
            % nothing was removed last round, so widen the comparison to
            % spikes that are one more position apart in the list
            diff_order = diff_order+1;
            fprintf('No duplicates but simultaneous spikes haven''t been fully explored.\nNow comparing spikes separated by %g other spikes.\n',diff_order-1);
        end
        keep_idx = ~ismember(spike_idx,remove_idx);
        current_spike_idx = spike_idx(keep_idx);
        current_spike_times = spike_times(keep_idx);
        current_primaries = spike_primary(keep_idx);
        current_templates = spike_templates(keep_idx);
        isis = current_spike_times(1+diff_order:end) - current_spike_times(1:end-diff_order);
        simultaneous = isis<n_samples_overlap;
        if any(isis<0)
            error('ISIs less than zero? Something is wrong.');
        end
        if ~any(simultaneous)
            fprintf('No remaining simultaneous spikes.\n');
            break
        end
        nearby = chan_dist(sub2ind(size(chan_dist),current_primaries(1:end-diff_order),current_primaries(1+diff_order:end)))<P.channel_separation_um;
        first_duplicate = find(simultaneous & nearby); % indexes the first member of the pair
        n_duplicates = length(first_duplicate);
        if ~isempty(first_duplicate)
            fprintf('On iteration %g, %g duplicate spike pairs were identified.\n',count,n_duplicates);
            second_duplicate = first_duplicate + diff_order;
            % tempAmpsUnscaled holds one amplitude per template, so it must
            % be looked up by each spike's template, not by its channel.
            first_amp = tempAmpsUnscaled(current_templates(first_duplicate));
            second_amp = tempAmpsUnscaled(current_templates(second_duplicate));
            first_is_bigger = second_amp<=first_amp; % ties keep the first spike
            remove_idx = [remove_idx ; current_spike_idx([first_duplicate(~first_is_bigger);second_duplicate(first_is_bigger)])];
            reference_idx = [reference_idx ; current_spike_idx([second_duplicate(~first_is_bigger);first_duplicate(first_is_bigger)])];
            [remove_idx,idx] = unique(remove_idx);
            reference_idx = reference_idx(idx);
        end
    end

    % Translate positions in the time-sorted vectors into rows of rez.st3 and
    % order the reference information by ascending row, which is the order
    % in which a logical mask extracts the removed spikes.
    removed_rows = sort_order(remove_idx);
    [removed_rows, row_order] = sort(removed_rows);
    reference_idx = reference_idx(row_order);
    logical_remove_idx = false(n_spikes,1);
    logical_remove_idx(removed_rows) = true;

    rez = remove_spikes(rez,logical_remove_idx,'duplicate','reference_time',spike_times(reference_idx),...
        'reference_cluster',spike_templates(reference_idx));
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
function rez = remove_lowamp_clusters(rez, threshold)
% REMOVE_LOWAMP_CLUSTERS Remove the spikes of low-amplitude templates.
%
%   rez = remove_lowamp_clusters(rez) removes all spikes whose template has
%   a mean amplitude below 30 (uVpp).
%
%   rez = remove_lowamp_clusters(rez, threshold) uses THRESHOLD instead of
%   30. An empty THRESHOLD selects the default.
%
%   The template amplitude is computed from the unwhitened templates
%   (peak-to-peak on the largest channel), scaled by the per-spike amplitudes
%   in rez.st3(:,3), averaged per template and multiplied by rez.ops.gain
%   (1 when that field is absent). The spikes are moved with REMOVE_SPIKES
%   under the label 'amp_threshold'.

    if nargin < 2 || isempty(threshold)
        threshold = 30; % uVpp
    end
    amplitudes = rez.st3(:,3);
    spikeTemplates = uint32(rez.st3(:,2));

    % rebuild the full templates from their low-rank factors
    templates = gpuArray.zeros(rez.ops.Nchan, size(rez.W,1), size(rez.W,2), 'single');
    for iNN = 1:size(templates,3)
        templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))';
    end
    templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels

    whiteningMatrix = rez.Wrot/rez.ops.scaleproc;
    whiteningMatrixInv = whiteningMatrix^-1;

    % here we compute the amplitude of every template...

    % unwhiten all the templates
    tempsUnW = gpuArray.zeros(size(templates));
    for t = 1:size(templates,1)
        tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv;
    end

    % The amplitude on each channel is the positive peak minus the negative
    tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2));

    % The template amplitude is the amplitude of its largest channel
    tempAmpsUnscaled = max(tempChanAmps,[],2);

    % assign all spikes the amplitude of their template multiplied by their
    % scaling amplitudes
    spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes;

    % take the average of all spike amps to get actual template amps (since
    % tempScalingAmps are equal mean for all templates)
    ta = clusterAverage(spikeTemplates, spikeAmps);
    tids = unique(spikeTemplates);
    tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike
    gain = getOr(rez.ops, 'gain', 1);
    tempAmps = gain*tempAmps';

    below_threshold = tempAmps<threshold;

    % The amplitudes above live on the GPU, so the mask may as well. It is
    % brought back to host memory before remove_spikes uses it to index and
    % count rows of CPU arrays.
    to_remove = gather(ismember(spikeTemplates,find(below_threshold)));

    rez = remove_spikes(rez,to_remove,'amp_threshold');

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
function rez = remove_spikes(rez,remove_idx,label,varargin)
% REMOVE_SPIKES Move selected spikes from REZ into REZ.REMOVED.
%
%   rez = remove_spikes(rez, remove_idx, label) moves the rows of rez.st3,
%   rez.st2 and (when rez.cProj is not empty) rez.cProj / rez.cProjPC that
%   are selected by the logical vector REMOVE_IDX into the matching fields
%   of rez.removed, records LABEL for each moved spike in rez.removed.label,
%   and finally calls RECOMPUTE_CLUSTERS. If no spike is selected, REZ is
%   returned unchanged.
%
%   rez = remove_spikes(..., name, value, ...) additionally appends VALUE
%   (one row per removed spike) to rez.removed.(name). If that field has
%   fewer rows than the spikes removed by earlier calls, the gap is filled
%   first: with '' for cell values and with NaN otherwise. A char matrix
%   VALUE is stored as a cell column of strings.
%
%   Errors:
%     * REMOVE_IDX is not logical
%     * REMOVE_IDX does not have one element per row of rez.st3
%     * the optional arguments are not name-value pairs
%     * a VALUE does not have one row per removed spike

    if ~islogical(remove_idx)
        error('remove_idx must be logical.');
    end
    remove_idx = remove_idx(:);
    if ~any(remove_idx)
        return
    end
    % a shorter mask would silently drop the trailing rows below via ~remove_idx
    if numel(remove_idx) ~= size(rez.st3,1)
        error('remove_idx must have one element per row of rez.st3.');
    end
    if mod(numel(varargin),2) ~= 0
        error('optional args must be given as name-value pairs.');
    end
    n_remove = sum(remove_idx);
    fprintf('Removing %g spikes from rez structure.\n',n_remove);
    if ~isfield(rez,'removed')
        [rez.removed.cProj,rez.removed.cProjPC,rez.removed.st2,rez.removed.st3] = deal([]);
        rez.removed.label = {};
    end

    % Number of spikes removed by earlier calls. This is taken from st3
    % because cProj stays empty whenever rez.cProj is empty.
    n_previous = size(rez.removed.st3,1);

    if ~isempty(rez.cProj)
        rez.removed.cProj = cat(1,rez.removed.cProj,rez.cProj(remove_idx,:));
        rez.removed.cProjPC = cat(1,rez.removed.cProjPC,rez.cProjPC(remove_idx,:,:));
    end
    rez.removed.st3 = cat(1,rez.removed.st3,rez.st3(remove_idx,:));
    rez.removed.st2 = cat(1,rez.removed.st2,rez.st2(remove_idx,:));

    rez.removed.label = cat(1,rez.removed.label,repmat({label},n_remove,1));

    for k = 1:2:numel(varargin)
        name = varargin{k};
        value = varargin{k+1};
        if size(value,1)~=n_remove
            error('optional arg in incorrect number of rows.');
        end
        if ischar(value)
            value = cellstr(value); % one string per removed spike
        end

        if isfield(rez.removed,name)
            stored = rez.removed.(name);
        elseif iscell(value)
            stored = {};
        else
            stored = [];
        end

        % fill the rows belonging to spikes removed before this field existed
        n_missing = n_previous - size(stored,1);
        if n_missing > 0
            if iscell(value)
                filler = repmat({''},n_missing,size(value,2));
            else
                % NOTE: when VALUE is of an integer class the concatenation
                % below yields that integer class, which cannot hold NaN.
                filler = nan(n_missing,size(value,2));
            end
            stored = cat(1,stored,filler);
        end
        rez.removed.(name) = cat(1,stored,value);
    end

    if ~isempty(rez.cProj)
        rez.cProj = rez.cProj(~remove_idx,:);
        rez.cProjPC = rez.cProjPC(~remove_idx,:,:);
    end
    rez.st3 = rez.st3(~remove_idx,:);
    rez.st2 = rez.st2(~remove_idx,:);

    rez = recompute_clusters(rez);

end