Skip to content

Commit

Permalink
Merge pull request MouseLand#234 from Brody-Lab/master
Browse files Browse the repository at this point in the history
optional post-processing step for duplicate spike removal

Former-commit-id: d703473
  • Loading branch information
marius10p authored Oct 28, 2020
2 parents c15f7d2 + 1d429f6 commit a7d1edd
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 24 deletions.
4 changes: 4 additions & 0 deletions main_kilosort.m
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
% main tracking and template matching algorithm
rez = learnAndSolve8b(rez);

% OPTIONAL: remove double-counted spikes - solves issue in which individual spikes are assigned to multiple templates.
% See issue 29: https://github.com/MouseLand/Kilosort2/issues/29
%rez = remove_ks2_duplicate_spikes(rez);

% final merges
rez = find_merges(rez, 1);

Expand Down
134 changes: 134 additions & 0 deletions postProcess/remove_ks2_duplicate_spikes.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
% REMOVE_KS2_DUPLICATE_SPIKES Double-counted spikes are hard to avoid with
% Kilosort's template matching algorithm since the overall fit can be
% improved by having multiple templates jointly account for a single variable waveform.
%
% This function takes the kilosort2 output rez and identifies pairs of
% spikes that are close together in time and space. The temporal threshold
% is given by the parameter OVERLAP_S, which is 0.5 ms by default, and
% the spatial threshold (applied to the template primary sites) is given by
% CHANNEL_SEPARATION_UM and is 100 by default.
%
% For each such spike pair, the spike whose template has the larger
% (unwhitened, peak-to-peak) amplitude is treated as the "main" or
% "reference" spike and the other spike of the pair is removed. Ties keep
% the earlier spike.
%
% All spike pairs are considered, not just those from CCG-contaminated
% pairs, as in REMOVE_KS2_DUPLICATE_SPIKES2.
%
% Removed spikes are handed to REMOVE_SPIKES with the label 'duplicate',
% together with the time and template of the reference spike they
% duplicated ('reference_time' and 'reference_cluster').
%
% Adrian Bondy, 2020
%
%=INPUT
%
% rez structure
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
% overlap_s
% the time interval, in seconds, within which a sequence of spikes is
% vetted for duplicates. Numeric scalar.
%
% channel_separation_um
% When the primary channels of two spikes are within this distance, in
% microns, then the two spikes are vetted for duplicates. Numeric scalar.
%
%=EXAMPLE
%
% >> rez = remove_ks2_duplicate_spikes(rez)
% >> rez = remove_ks2_duplicate_spikes(rez, 'overlap_s', 2.5e-4, 'channel_separation_um', 50)
function rez = remove_ks2_duplicate_spikes(rez, varargin)
    input_parser = inputParser;
    % Both thresholds are numeric scalars. (The spatial threshold was
    % previously validated with ischar, which rejected every numeric value.)
    is_numeric_scalar = @(x) isnumeric(x) && isscalar(x);
    addParameter(input_parser, 'overlap_s', 5e-4, is_numeric_scalar) % the temporal window within which pairs of spikes will be considered duplicates (if they are also within the spatial window)
    addParameter(input_parser, 'channel_separation_um', 100, is_numeric_scalar) % the spatial window within which pairs of spikes will be considered duplicates (if they are also within the temporal window)
    parse(input_parser, varargin{:});
    P = input_parser.Results;

    spike_times = uint64(rez.st3(:,1));
    spike_templates = uint32(rez.st3(:,2));

    %% Reconstruct the templates from their low-rank factors U and W
    rez.U = gather(rez.U);
    rez.W = gather(rez.W);
    templates = zeros(rez.ops.Nchan, size(rez.W,1), size(rez.W,2), 'single');
    for iNN = 1:size(templates,3)
        templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))';
    end
    templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels
    n_spikes = numel(spike_times);

    %% Make sure that the spike times are sorted
    % spike_idx maps a position in the sorted list back to the row of rez.st3.
    if ~issorted(spike_times)
        [spike_times, spike_idx] = sort(spike_times);
        spike_templates = spike_templates(spike_idx);
    else
        spike_idx = (1:n_spikes)';
    end

    %% deal with cluster 0
    if any(spike_templates==0)
        error('Currently this function can''t deal with existence of cluster 0. Should be OK since it ought to be run first in the post-processing.');
    end

    %% Determine the channel where each spike had that largest amplitude (i.e., the primary) and determine the template amplitude of each cluster
    whiteningMatrix = rez.Wrot/rez.ops.scaleproc;
    whiteningMatrixInv = whiteningMatrix^-1;

    % here we compute the amplitude of every template...
    % unwhiten all the templates
    [n_templates, n_samples, n_chan] = size(templates);
    tempsUnW = zeros(size(templates));
    for t = 1:n_templates
        % reshape (rather than squeeze) keeps the slice nSamples x nChannels
        % even when one of those dimensions is a singleton
        tempsUnW(t,:,:) = reshape(templates(t,:,:), n_samples, n_chan) * whiteningMatrixInv;
    end

    % The amplitude on each channel is the positive peak minus the negative.
    % permute (rather than squeeze) guarantees an nTemplates x nChannels
    % result even when there is only a single template.
    tempChanAmps = permute(max(tempsUnW,[],2) - min(tempsUnW,[],2), [1 3 2]);

    % The template amplitude is the amplitude of its largest channel
    [tempAmpsUnscaled, template_primary] = max(tempChanAmps,[],2);

    template_primary = cast(template_primary, class(spike_templates));
    spike_primary = template_primary(spike_templates);

    %% Number of samples in the overlap
    n_samples_overlap = round(P.overlap_s * rez.ops.fs);
    n_samples_overlap = cast(n_samples_overlap, class(spike_times));

    %% Distance between each channel
    % (:) forces column vectors so the subtraction expands to a full
    % nChannels x nChannels matrix regardless of how the coordinates are stored
    xc = rez.xcoords(:);
    yc = rez.ycoords(:);
    chan_dist = ((xc - xc').^2 + (yc - yc').^2).^0.5;

    remove_idx = [];
    reference_idx = [];
    % check pairs of spikes in the time-ordered list for being close together in space and time.
    % Check pairs that are separated by N other spikes,
    % starting with N=0. Increasing N until there are no spikes within the temporal overlap window.
    % This means only ever computing a vector operation (i.e. diff(spike_times))
    % rather than a matrix one (i.e. spike_times - spike_times').
    diff_order = 0;
    while true
        diff_order = diff_order+1;
        fprintf('Now comparing spikes separated by %g other spikes.\n',diff_order-1);
        later_times = spike_times(1+diff_order:end);
        earlier_times = spike_times(1:end-diff_order);
        % Unsigned integer subtraction saturates at zero, so the ordering has
        % to be checked on the times themselves, not on their difference.
        if any(later_times < earlier_times)
            error('ISIs less than zero? Something is wrong because spike times should be sorted.');
        end
        isis = later_times - earlier_times;
        simultaneous = isis<n_samples_overlap;
        if ~any(simultaneous)
            fprintf('No remaining simultaneous spikes.\n');
            break
        end
        nearby = chan_dist(sub2ind(size(chan_dist),spike_primary(1:end-diff_order),spike_primary(1+diff_order:end)))<P.channel_separation_um;
        first_duplicate = find(simultaneous & nearby); % indexes the first (earliest in time) member of the pair
        first_duplicate = first_duplicate(:);
        n_duplicates = length(first_duplicate);
        if ~isempty(first_duplicate)
            fprintf('On iteration %g, %g duplicate spike pairs were identified.\n',diff_order,n_duplicates);
            second_duplicate = first_duplicate + diff_order; % the later member of each pair
            amp_first = tempAmpsUnscaled(spike_templates(first_duplicate));
            amp_second = tempAmpsUnscaled(spike_templates(second_duplicate));
            % keep the member with the larger template; ties keep the first
            first_is_bigger = amp_first(:) >= amp_second(:);
            remove_idx = [remove_idx ; spike_idx([first_duplicate(~first_is_bigger); second_duplicate(first_is_bigger)])]; %#ok<AGROW>
            reference_idx = [reference_idx ; spike_idx([second_duplicate(~first_is_bigger); first_duplicate(first_is_bigger)])]; %#ok<AGROW>
        end
    end

    % A spike can be flagged on several iterations; remove it only once and
    % keep one matching reference spike for it.
    [remove_idx,idx] = unique(remove_idx);
    reference_idx = reference_idx(idx);
    logical_remove_idx = ismember((1:n_spikes)',remove_idx);

    % remove_idx and reference_idx are rows of rez.st3 (original order), so
    % the reference time/cluster must be read from rez.st3 itself rather than
    % from the time-sorted local copies.
    rez = remove_spikes(rez,logical_remove_idx,'duplicate', ...
        'reference_time',uint64(rez.st3(reference_idx,1)), ...
        'reference_cluster',uint32(rez.st3(reference_idx,2)));
end
52 changes: 52 additions & 0 deletions postProcess/remove_spikes.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
function rez = remove_spikes(rez,remove_idx,label,varargin)
% REMOVE_SPIKES Move a subset of spikes out of the main fields of rez and
% archive them under rez.removed.
%
%=INPUT
%
% rez
%   structure with fields st3, st2, cProj and cProjPC, each having one row
%   per spike (cProj/cProjPC may be empty).
%
% remove_idx
%   logical vector with one element per spike; true marks a spike to remove.
%
% label
%   label stored in rez.removed.label for every spike removed by this call.
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   Any number of name-value pairs. Each value must have one row per removed
%   spike and is appended to rez.removed.(name). If earlier calls archived
%   spikes without supplying that name, the missing rows are padded (NaN for
%   numeric values, '' for char/cell values) so that rez.removed.(name)
%   stays row-aligned with rez.removed.st3.
%
%=OUTPUT
%
% rez
%   the input structure with the flagged rows deleted from st3, st2, cProj
%   and cProjPC, and appended to the corresponding fields of rez.removed.
%   Returned unchanged when no element of remove_idx is true.
    if ~islogical(remove_idx)
        error('remove_idx must be logical.');
    end
    if ~any(remove_idx)
        return
    end
    if mod(numel(varargin),2)~=0
        error('Optional arguments must be supplied as name-value pairs.');
    end
    n_removed = sum(remove_idx);
    fprintf('Removing %g spikes from rez structure.\n',n_removed);

    if ~isfield(rez,'removed')
        [rez.removed.cProj,rez.removed.cProjPC,rez.removed.st2,rez.removed.st3] = deal([]);
        rez.removed.label={};
    end

    % Number of spikes archived by previous calls. This is taken from st3
    % because cProj can legitimately be empty, in which case it never grows.
    n_previous = size(rez.removed.st3,1);

    if ~isempty(rez.cProj)
        rez.removed.cProj = cat(1,rez.removed.cProj,rez.cProj(remove_idx,:));
        rez.removed.cProjPC = cat(1,rez.removed.cProjPC,rez.cProjPC(remove_idx,:,:));
    end
    rez.removed.st3 = cat(1,rez.removed.st3,rez.st3(remove_idx,:));
    rez.removed.st2 = cat(1,rez.removed.st2,rez.st2(remove_idx,:));
    rez.removed.label = cat(1,rez.removed.label,repmat({label},n_removed,1));

    % Append the optional per-spike annotations
    for k = 1:2:numel(varargin)
        name = varargin{k};
        value = varargin{k+1};
        if size(value,1)~=n_removed
            error('optional arg in incorrect number of rows.');
        end
        if isfield(rez.removed,name)
            existing = rez.removed.(name);
        else
            existing = [];
        end
        n_existing = size(existing,1);
        if n_existing<n_previous
            % back-fill the rows belonging to spikes archived before this
            % annotation was first supplied
            if ischar(value) || iscell(value)
                if isempty(existing) && ~iscell(existing)
                    existing = cell(0,1);
                end
                existing(n_existing+1:n_previous,1) = {''};
            elseif isnumeric(value)
                existing(n_existing+1:n_previous,1) = NaN;
            end
        end
        rez.removed.(name) = cat(1,existing,value);
    end

    % Finally drop the removed spikes from the main fields
    if ~isempty(rez.cProj)
        rez.cProj = rez.cProj(~remove_idx,:);
        rez.cProjPC = rez.cProjPC(~remove_idx,:,:);
    end
    rez.st3 = rez.st3(~remove_idx,:);
    rez.st2 = rez.st2(~remove_idx,:);
end
9 changes: 2 additions & 7 deletions postProcess/set_cutoff.m
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,5 @@
% we sometimes get NaNs, why? replace with full contamination
rez.est_contam_rate(isnan(rez.est_contam_rate)) = 1;

% remove spikes assigned to the 0 cluster
ix = rez.st3(:,2)==0;
rez.st3(ix, :) = [];
if ~isempty(rez.cProj)
rez.cProj(ix, :) = []; % remove their template projections too
rez.cProjPC(ix, :,:) = []; % and their PC projections
end
% remove spikes from the 0th cluster
rez = remove_spikes(rez,rez.st3(:,2)==0,'below_cutoff');
39 changes: 24 additions & 15 deletions preProcess/clusterSingleBatches.m
Original file line number Diff line number Diff line change
Expand Up @@ -178,24 +178,33 @@
% sort by manifold embedding algorithm
% iorig is the sorting of the batches
% ccbsort is the resorted matrix (useful for diagnosing drift)
[ccbsort, iorig] = sortBatches2(ccb0);
[ccbsort, iorig, xs] = sortBatches2(rez.ccb);
rez.iorig = gather(iorig);
rez.ccbsort = gather(ccbsort);

% some mandatory diagnostic plots to understand drift in this dataset
figure;
subplot(1,2,1)
imagesc(ccb0, [-5 5]); drawnow
xlabel('batches')
ylabel('batches')
title('batch to batch distance')

subplot(1,2,2)
imagesc(ccbsort, [-5 5]); drawnow
xlabel('sorted batches')
ylabel('sorted batches')
title('AFTER sorting')

rez.iorig = gather(iorig);
rez.ccbsort = gather(ccbsort);
% distance matrices
subplot(2,2,1)
imagesc(rez.ccb, [-5 5]); axis tight
xlabel('Batches')
ylabel('Batches')
title('Distance Matrix, before sorting')
subplot(2,2,2)
imagesc(rez.ccbsort, [-5 5]); axis tight
xlabel('Sorted Batches')
ylabel('Sorted Batches')
title('Distance Matrix, after sorting')
% drift plots (in a 1D embedding)
subplot(2,2,3);
plot(xs);set(gca,'ytick',[]);xlabel('Batches');
title('Drift Plot, before sorting');axis tight
ylabel({'Manifold Position'});
subplot(2,2,4);
plot(xs(iorig));set(gca,'ytick',[]);xlabel('Sorted batches');
title('Drift Plot, after sorting');axis tight
ylabel({'Manifold Position'});
drawnow;

fprintf('time %2.2f, Re-ordered %d batches. \n', toc, nBatches)
%%
11 changes: 10 additions & 1 deletion preProcess/preprocessDataSub.m
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@
fprintf('Time %3.0fs. Loading raw data and applying filters... \n', toc);

fid = fopen(ops.fbinary, 'r'); % open for reading raw data
if fid<3
error('Could not open %s for reading.',ops.fbinary);
end
fidW = fopen(ops.fproc, 'w'); % open for writing processed data
if fidW<3
error('Could not open %s for writing.',ops.fproc);
end

for ibatch = 1:Nbatch
% we'll create a binary file of batches of NT samples, which overlap consecutively on ops.ntbuff samples
Expand Down Expand Up @@ -103,7 +109,10 @@
datr = datr * Wrot; % whiten the data and scale by 200 for int16 range

datcpu = gather(int16(datr)); % convert to int16, and gather on the CPU side
fwrite(fidW, datcpu, 'int16'); % write this batch to binary file
count = fwrite(fidW, datcpu, 'int16'); % write this batch to binary file
if count~=numel(datcpu)
error('Error writing batch %g to %s. Check available disk space.',ibatch,ops.fproc);
end
end

rez.Wrot = gather(Wrot); % gather the whitening matrix as a CPU variable
Expand Down
2 changes: 1 addition & 1 deletion preProcess/sortBatches2.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [ccb1, isort] = sortBatches2(ccb0)
function [ccb1, isort, xs] = sortBatches2(ccb0)
% takes as input a matrix of nBatches by nBatches containing
% dissimilarities.
% outputs a matrix of sorted batches, and the sorting order, such that
Expand Down

0 comments on commit a7d1edd

Please sign in to comment.