Skip to content

Commit

Permalink
add optional postProcessing steps
Browse files Browse the repository at this point in the history
1. remove_ks2_duplicate_spikes - removes duplicate spikes resulting from multiple templates explaining variance from a single waveform
  • Loading branch information
agbondy committed Aug 27, 2020
1 parent baf5e64 commit c4ed558
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 0 deletions.
29 changes: 29 additions & 0 deletions postProcess/recompute_clusters.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
function rez = recompute_clusters(rez)
% RECOMPUTE_CLUSTERS Recompute cluster statistics after a post-processing
% operation.
%
% Re-runs the SVD on rez.dWU (mexSVDsmall2) to refresh rez.W, rez.U and
% rez.mu, and then rebuilds rez.iList, rez.simScore, rez.iNeigh,
% rez.iNeighPC and rez.Wphy from the refreshed templates.
%
%=INPUT
%
%   rez     structure; the fields read here are W, dWU and
%           ops.{Nchan, sigmaMask, nt0min}
%
%=OUTPUT
%
%   rez     the same structure with the fields listed above overwritten

ops = rez.ops;

nChan = ops.Nchan;
nRank = 3;
nChanNear = min(nChan, 32);
nNearest = min(nChan, 32);

closestChans = getClosestChannels(rez, ops.sigmaMask, nChanNear);

nTemplates = size(rez.W, 2); % new number of templates
nTimepoints = size(rez.W, 1);

% make a new Params to pass on parameters to CUDA
cudaParams = double([0 nTemplates 0 0 nTimepoints nNearest ...
    nRank 0 0 nChan nChanNear ops.nt0min 0]);

% we need to re-estimate the spatial profiles
[Ka, Kb] = getKernels(ops, 10, 1); % we get the time upsampling kernels again

% find the peak abs channel for each template (evaluated at sample nt0min)
absAtPeakSample = abs(rez.dWU(ops.nt0min, :, :));
[~, peakChan] = max(absAtPeakSample, [], 2);
peakChan = squeeze(int32(peakChan));

% we run SVD; channel indices are passed zero-based
[rez.W, rez.U, rez.mu] = mexSVDsmall2(cudaParams, rez.dWU, rez.W, ...
    closestChans-1, peakChan-1, Ka, Kb);

% we re-compute similarity scores between templates
[WtW, nearestTemplates] = getMeWtW(single(rez.W), single(rez.U), nNearest);

rez.iList = nearestTemplates; % over-write the list of nearest templates
rez.simScore = gather(max(WtW, [], 3));
rez.iNeigh = gather(nearestTemplates(:, 1:nTemplates)); % get the new neighbor templates
rez.iNeighPC = gather(closestChans(:, peakChan(1:nTemplates))); % get the new neighbor channels

% for Phy, we need to pad the spikes with zeros so the spikes are aligned
% to the center of the window
zeroPad = zeros(1+ops.nt0min, nTemplates, nRank);
rez.Wphy = cat(1, zeroPad, rez.W);
end
27 changes: 27 additions & 0 deletions postProcess/remove_inactive_clusters.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function rez = remove_inactive_clusters(rez,varargin)
% REMOVE_INACTIVE_CLUSTERS Remove all spikes of clusters that have too few
% spikes.
%
% A cluster (template index > 0 in rez.st3(:,2)) is flagged when its spike
% count is below MIN_SPIKES or below rez.ops.minFR multiplied by the
% recording duration (rez.ops.sampsToRead / rez.ops.fs). The flagged
% cluster indices are stored in rez.inactive, their spikes are handed to
% REMOVE_SPIKES with the label 'inactive', and RECOMPUTE_CLUSTERS is run.
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   min_spikes
%       nonnegative numeric scalar, 50 by default.

parser = inputParser;
addParameter(parser, 'min_spikes', 50, ...
    @(x)validateattributes(x,{'numeric'},{'scalar','nonnegative'}));
parse(parser, varargin{:});
min_spikes = parser.Results.min_spikes;

% count the spikes assigned to each cluster
spike_cluster = rez.st3(:,2);
[cluster_ids, ~, cluster_of_spike] = unique(spike_cluster);
spike_counts = accumarray(cluster_of_spike, 1);

% only clusters with a positive index are considered
is_positive = cluster_ids > 0;
spike_counts = spike_counts(is_positive);
cluster_ids = cluster_ids(is_positive);

% the two inactivity criteria
duration_sec = rez.ops.sampsToRead ./ rez.ops.fs;
too_few_spikes = spike_counts < min_spikes;
too_low_rate = spike_counts < rez.ops.minFR * duration_sec;
is_inactive = too_few_spikes | too_low_rate;

inactive_ids = cluster_ids(is_inactive);
rez.inactive = inactive_ids;

in_inactive_cluster = ismember(spike_cluster, inactive_ids);
rez = remove_spikes(rez, in_inactive_cluster, 'inactive');
rez = recompute_clusters(rez);

fprintf('Removed %g inactive clusters.\n',sum(is_inactive));

end
145 changes: 145 additions & 0 deletions postProcess/remove_ks2_duplicate_spikes.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
% REMOVE_KS2_DUPLICATE_SPIKES Double-counted spikes are hard to avoid with
% Kilosort's template matching algorithm since the overall fit can be
% improved by having multiple templates jointly account for a single variable waveform.
%
% This function takes the kilosort2 output rez and identifies pairs of
% spikes that are close together in time and space. The temporal threshold
% is given by the parameter OVERLAP_S, which is 5e-4 s (0.5 ms) by default,
% and the spatial threshold (applied to the template primary sites) is
% given by CHANNEL_SEPARATION_UM and is 50 by default.
%
% Within each such pair, the spike whose template has the larger
% (unwhitened, peak-to-peak) amplitude is kept as the "main" or
% "reference" spike and the other spike of the pair is removed. When the
% two amplitudes are equal the earlier spike is kept.
%
% All spike pairs are considered, not just those from CCG-contaminated
% pairs, as in REMOVE_KS2_DUPLICATE_SPIKES2.
%
%=INPUT
%
%   rez structure
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   overlap_s
%       numeric scalar; the time interval, in seconds, within which a
%       sequence of spikes is vetted for duplicates.
%
%   channel_separation_um
%       numeric scalar; when the primary channels of two spikes are within
%       this distance, in microns, the two spikes are vetted as
%       duplicates.
%
%=OUTPUT
%
%   rez structure with the duplicate spikes removed through REMOVE_SPIKES
%   (label 'duplicate'; the time and cluster of each kept reference spike
%   are passed along as 'reference_time' and 'reference_cluster').
%
%=EXAMPLE
%
%   >> rez = remove_ks2_duplicate_spikes(rez)
function rez = remove_ks2_duplicate_spikes(rez, varargin)
    input_parser = inputParser;
    addParameter(input_parser, 'overlap_s', 5e-4, @(x) isnumeric(x) && isscalar(x))
    % both thresholds are numbers; the validator must not require a char
    addParameter(input_parser, 'channel_separation_um', 50, @(x) isnumeric(x) && isscalar(x))
    parse(input_parser, varargin{:});
    P = input_parser.Results;

    spike_times = uint64(rez.st3(:,1));
    spike_templates = uint32(rez.st3(:,2));

    rez.U = gather(rez.U);
    rez.W = gather(rez.W);
    templates = zeros(rez.ops.Nchan, size(rez.W,1), size(rez.W,2), 'single');
    for iNN = 1:size(templates,3)
        templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))';
    end
    templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels

    %% Make sure that the spike times are sorted
    % sort_order maps a position in the sorted vectors back to its row in
    % rez.st3, so that the removal mask can be expressed in rez.st3 rows.
    sort_order = (1:numel(spike_times))';
    if ~issorted(spike_times)
        [spike_times, sort_order] = sort(spike_times);
        spike_templates = spike_templates(sort_order);
    end
    %% deal with cluster 0
    if any(spike_templates==0)
        error('Currently this function can''t deal with existence of cluster 0. Should be OK since it ought to be run first in the post-processing.');
    end
    %% Determine the channel where each spike had that largest amplitude (i.e., the primary) and determine the template amplitude of each cluster
    whiteningMatrix = rez.Wrot/rez.ops.scaleproc;
    whiteningMatrixInv = whiteningMatrix^-1;

    % here we compute the amplitude of every template...
    % unwhiten all the templates
    tempsUnW = zeros(size(templates));
    for t = 1:size(templates,1)
        tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv;
    end

    % The amplitude on each channel is the positive peak minus the negative
    tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2));

    % The template amplitude is the amplitude of its largest channel.
    % tempAmpsUnscaled has one entry per TEMPLATE; template_primary is the
    % channel index of that largest channel.
    [tempAmpsUnscaled,template_primary] = max(tempChanAmps,[],2);
    %without undoing the whitening
    %template_amplitude = squeeze(max(templates, [], 2) - min(templates, [], 2));
    %[~, template_primary] = max(template_amplitude, [], 2);

    template_primary = cast(template_primary, class(spike_templates));
    spike_primary = template_primary(spike_templates);

    %% Number of samples in the overlap
    n_samples_overlap = round(P.overlap_s * rez.ops.fs);
    n_samples_overlap = cast(n_samples_overlap, class(spike_times));
    %% Distance between each channel
    chan_dist = ((rez.xcoords - rez.xcoords').^2 + (rez.ycoords - rez.ycoords').^2).^0.5;
    %imagesc(chan_dist)

    n_spikes = numel(spike_times);
    n_duplicates = 1; % set to 1 to initialize while loop
    count = 0;
    remove_idx = [];
    reference_idx = [];
    spike_idx = (1:n_spikes)';
    % only check nearest temporal neighbors in the list of spikes times.
    % but go recursively until no nearest neighbors are left that are both within the overlap
    % period and sufficiently nearby.
    % this means only ever computing a vector operation (i.e. diff(spike_times))
    % rather than a matrix one (i.e. spike_times - spike_times').
    diff_order = 1;
    while true
        count = count+1;
        if n_duplicates==0
            diff_order = diff_order+1;
            fprintf('No duplicates but simultaneous spikes haven''t been fully explored.\nNow comparing spikes separated by %g other spikes.\n',diff_order-1);
        end
        % restrict to the spikes that have not been marked for removal yet
        keep_idx = ~ismember(spike_idx,remove_idx);
        current_spike_idx = spike_idx(keep_idx);
        current_spike_times = spike_times(keep_idx);
        current_primaries = spike_primary(keep_idx);
        current_templates = spike_templates(keep_idx);
        isis = current_spike_times(1+diff_order:end) - current_spike_times(1:end-diff_order);
        simultaneous = isis<n_samples_overlap;
        % NOTE: the spike times are unsigned integers, whose subtraction
        % saturates at zero, so this guard cannot trigger as written.
        if any(isis<0)
            error('ISIs less than zero? Something is wrong.');
        end
        if ~any(simultaneous)
            fprintf('No remaining simultaneous spikes.\n');
            break
        end
        nearby = chan_dist(sub2ind(size(chan_dist),current_primaries(1:end-diff_order),current_primaries(1+diff_order:end)))<P.channel_separation_um;
        first_duplicate = find(simultaneous & nearby); % indexes the first member of the pair
        n_duplicates = length(first_duplicate);
        if ~isempty(first_duplicate)
            fprintf('On iteration %g, %g duplicate spike pairs were identified.\n',count,n_duplicates);
            % Look the amplitudes up by TEMPLATE index (not by primary
            % channel): tempAmpsUnscaled has one entry per template.
            % reshape(...,[],2) keeps one row per pair, including the
            % single-pair case where indexing returns a 2x1 column.
            pair_members = [first_duplicate(:), first_duplicate(:)+diff_order];
            amps_to_compare = reshape(tempAmpsUnscaled(current_templates(pair_members)), [], 2);
            first_is_bigger = diff(amps_to_compare,[],2)<=0;
            remove_idx = [remove_idx ; current_spike_idx([first_duplicate(~first_is_bigger);(first_duplicate(first_is_bigger)+diff_order)])];
            reference_idx = [reference_idx ; current_spike_idx([(first_duplicate(~first_is_bigger)+diff_order);first_duplicate(first_is_bigger)])];
            [remove_idx,idx] = unique(remove_idx);
            reference_idx = reference_idx(idx);
        end
    end

    % remove_idx and reference_idx index the time-sorted vectors. Translate
    % the removals into rows of rez.st3 and order the references the same
    % way REMOVE_SPIKES will pick the rows (ascending row number). When
    % rez.st3 was already sorted, sort_order is the identity and nothing
    % changes here.
    rows_to_remove = sort_order(remove_idx);
    [rows_to_remove, row_order] = sort(rows_to_remove);
    reference_idx = reference_idx(row_order);
    logical_remove_idx = false(n_spikes,1);
    logical_remove_idx(rows_to_remove) = true;

    rez = remove_spikes(rez,logical_remove_idx,'duplicate','reference_time',spike_times(reference_idx),...
        'reference_cluster',spike_templates(reference_idx));
end
52 changes: 52 additions & 0 deletions postProcess/remove_lowamp_clusters.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
function rez = remove_lowamp_clusters(rez, varargin)
% REMOVE_LOWAMP_CLUSTERS Remove all spikes of templates whose mean spike
% amplitude is below a threshold.
%
% Each template is reconstructed from rez.U and rez.W, unwhitened with the
% inverse of rez.Wrot/rez.ops.scaleproc, and its peak-to-peak amplitude on
% its largest channel is computed. Every spike is assigned that amplitude
% multiplied by its scaling amplitude (rez.st3(:,3)); the per-template
% average, multiplied by rez.ops.gain (1 when the field is absent), is
% compared against the threshold. Spikes of templates below the threshold
% are handed to REMOVE_SPIKES with the label 'amp_threshold'.
%
%=INPUT
%
%   rez structure
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   threshold
%       nonnegative numeric scalar, 30 by default (documented as uVpp;
%       NOTE(review): this only holds if rez.ops.gain converts to
%       microvolts - confirm).
%
%=EXAMPLE
%
%   >> rez = remove_lowamp_clusters(rez)
%   >> rez = remove_lowamp_clusters(rez, 'threshold', 50)

p = inputParser;
p.addParameter('threshold', 30, @(x)validateattributes(x,{'numeric'},{'scalar','nonnegative'}));
p.parse(varargin{:});
threshold = p.Results.threshold; % uVpp

amplitudes = rez.st3(:,3);
spikeTemplates = uint32(rez.st3(:,2));

templates = gpuArray.zeros(rez.ops.Nchan, size(rez.W,1), size(rez.W,2), 'single');
for iNN = 1:size(templates,3)
    templates(:,:,iNN) = squeeze(rez.U(:,iNN,:)) * squeeze(rez.W(:,iNN,:))';
end
templates = permute(templates, [3 2 1]); % now it's nTemplates x nSamples x nChannels

whiteningMatrix = rez.Wrot/rez.ops.scaleproc;
whiteningMatrixInv = whiteningMatrix^-1;

% here we compute the amplitude of every template...

% unwhiten all the templates
tempsUnW = gpuArray.zeros(size(templates));
for t = 1:size(templates,1)
    tempsUnW(t,:,:) = squeeze(templates(t,:,:))*whiteningMatrixInv;
end

% The amplitude on each channel is the positive peak minus the negative
tempChanAmps = squeeze(max(tempsUnW,[],2))-squeeze(min(tempsUnW,[],2));

% The template amplitude is the amplitude of its largest channel
tempAmpsUnscaled = max(tempChanAmps,[],2);

% assign all spikes the amplitude of their template multiplied by their
% scaling amplitudes
spikeAmps = tempAmpsUnscaled(spikeTemplates).*amplitudes;

% take the average of all spike amps to get actual template amps (since
% tempScalingAmps are equal mean for all templates)
ta = clusterAverage(spikeTemplates, spikeAmps);
tids = unique(spikeTemplates);
tempAmps(tids) = ta; % because ta only has entries for templates that had at least one spike
gain = getOr(rez.ops, 'gain', 1);
tempAmps = gain*tempAmps';

below_threshold = tempAmps<threshold;

% gather so that the mask handed to remove_spikes is an ordinary logical
% array even when the amplitudes above were computed on the GPU
to_remove = gather(ismember(spikeTemplates,find(below_threshold)));

rez = remove_spikes(rez,to_remove,'amp_threshold');

end
55 changes: 55 additions & 0 deletions postProcess/remove_spikes.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
function rez = remove_spikes(rez,remove_idx,label,varargin)
% REMOVE_SPIKES Move a set of spikes out of rez and archive them in
% rez.removed.
%
% The rows of rez.st3, rez.st2 (and of rez.cProj / rez.cProjPC when
% rez.cProj is not empty) selected by REMOVE_IDX are appended to the
% fields of the same name in rez.removed and deleted from rez.
% rez.removed.label receives one copy of LABEL per removed spike.
% RECOMPUTE_CLUSTERS is then run on the result. When REMOVE_IDX selects no
% spike, rez is returned unchanged.
%
%=INPUT
%
%   rez         structure
%   remove_idx  logical vector selecting the rows of rez.st3 to remove
%   label       value stored in rez.removed.label for each removed spike
%
%=OPTIONAL INPUT, NAME-VALUE PAIRS
%
%   Any NAME followed by a VALUE with one row per removed spike. VALUE is
%   appended to rez.removed.(NAME). If that field is shorter than the
%   number of spikes archived by earlier calls, it is first padded with
%   NaN (numeric values) or '' (char / cell values) so that its rows stay
%   aligned with rez.removed.st3. A char VALUE is stored as a cell array
%   of its rows.

if ~islogical(remove_idx)
    error('remove_idx must be logical.');
end
if ~any(remove_idx)
    return
end
if mod(numel(varargin),2)~=0
    error('optional arguments must be given as name-value pairs.');
end

n_remove = sum(remove_idx);
fprintf('Removing %g spikes from rez structure.\n',n_remove);

if ~isfield(rez,'removed')
    [rez.removed.cProj,rez.removed.cProjPC,rez.removed.st2,rez.removed.st3] = deal([]);
    rez.removed.label={};
end

% Number of spikes archived by earlier calls. Taken from st3, which is
% appended on every call, rather than from cProj, which stays empty when
% rez.cProj is empty.
n_previous = size(rez.removed.st3,1);

if ~isempty(rez.cProj)
    rez.removed.cProj = cat(1,rez.removed.cProj,rez.cProj(remove_idx,:));
    rez.removed.cProjPC = cat(1,rez.removed.cProjPC,rez.cProjPC(remove_idx,:,:));
end
rez.removed.st3 = cat(1,rez.removed.st3,rez.st3(remove_idx,:));
rez.removed.st2 = cat(1,rez.removed.st2,rez.st2(remove_idx,:));

rez.removed.label = cat(1,rez.removed.label,repmat({label},n_remove,1));

for k = 1:2:numel(varargin)
    name = varargin{k};
    value = varargin{k+1};
    if size(value,1)~=n_remove
        error('optional arg in incorrect number of rows.');
    end
    if ischar(value)
        value = cellstr(value); % one cell per row, so rows stay aligned
    end
    if ~isfield(rez.removed,name)
        if iscell(value)
            rez.removed.(name) = cell(0,1);
        else
            rez.removed.(name) = [];
        end
    end
    % pad the rows that earlier calls did not supply for this field
    n_have = size(rez.removed.(name),1);
    if n_have<n_previous
        if iscell(rez.removed.(name))
            rez.removed.(name)(n_have+1:n_previous,1) = {''};
        elseif isnumeric(value)
            % NOTE: if value is of an integer class, the concatenation
            % below converts these NaN entries to 0.
            rez.removed.(name)(n_have+1:n_previous,1) = NaN;
        end
    end
    rez.removed.(name) = cat(1,rez.removed.(name),value);
end

if ~isempty(rez.cProj)
    rez.cProj = rez.cProj(~remove_idx,:);
    rez.cProjPC = rez.cProjPC(~remove_idx,:,:);
end
rez.st3 = rez.st3(~remove_idx,:);
rez.st2 = rez.st2(~remove_idx,:);

rez = recompute_clusters(rez);

end

0 comments on commit c4ed558

Please sign in to comment.