Skip to content

Commit

Permalink
Merge pull request #203 from jenniferColonell/master
Browse files Browse the repository at this point in the history
deterministic version of KS2
  • Loading branch information
marius10p authored May 18, 2020
2 parents 48bf2b8 + d97d42a commit f833820
Show file tree
Hide file tree
Showing 19 changed files with 854 additions and 234 deletions.
72 changes: 60 additions & 12 deletions CUDA/mexClustering2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,16 @@ __global__ void bestFilter(const double *Params, const bool *match,
//////////////////////////////////////////////////////////////////////////////////////////
__global__ void average_snips(const double *Params, const int *iC, const int *call,
const int *id, const float *uproj, const float *cmax, float *WU){


//Nfilt blocks
//Thread grid = (NrankPC, NchanNear)
//This implementation does not work correctly for real data!
//Since this_chan is function of the spike -- spikes assigned to a given template
//will have max channels that span a 2-3 channel range -- different (tidx, tidy)
//pairs can wind up trying to add to the same element of dWU, resulting in
//collisions and incorrect results. Use the single-threaded version
//average_snips_v2 instead. Speed hit is only ~ 5-6 seconds out of 360 sec for a
//typical 2 hour Neuropixels 1.0 dataset.
int my_chan, this_chan, tidx, tidy, bid, ind, Nspikes, NrankPC, NchanNear, Nchan;
float xsum = 0.0f;

Expand All @@ -106,15 +115,51 @@ __global__ void average_snips(const double *Params, const int *iC, const int *ca
tidy = threadIdx.y;
bid = blockIdx.x;

for(ind=0; ind<Nspikes;ind++)
for(ind=0; ind<Nspikes;ind++) {
if (id[ind]==bid){
my_chan = call[ind];
this_chan = iC[tidy + NchanNear * my_chan];
xsum = uproj[tidx + NrankPC*tidy + NrankPC*NchanNear * ind];
WU[tidx + NrankPC*this_chan + NrankPC*Nchan * bid] += xsum;
WU[tidx + NrankPC*this_chan + NrankPC*Nchan * bid] += xsum;
}
}
}

}

//////////////////////////////////////////////////////////////////////////////////////////
__global__ void average_snips_v2(const double *Params, const int *iC, const int *call,
    const int *id, const float *uproj, const float *cmax, float *WU){

    // Accumulate the projections of every spike assigned to one template into WU.
    //
    // Launch layout: one block per template (blockIdx.x selects the template).
    // The kernel never consults threadIdx, so each thread of a block would repeat
    // the full accumulation; the host code launches it with a single thread per
    // block so that additions into WU never collide and always happen in spike order.
    //
    // Params[0] = Nspikes, Params[1] = NrankPC, Params[6] = NchanNear, Params[7] = Nchan
    // iC    : channel lookup, indexed as iC[nearIdx + NchanNear * channel]
    // call  : per-spike channel used to select that spike's row of iC
    // id    : per-spike template index; only spikes with id == blockIdx.x are used
    // uproj : per-spike values, indexed as uproj[pc + NrankPC*nearIdx + NrankPC*NchanNear*spike]
    // cmax  : unused by this kernel (kept so the signature matches average_snips)
    // WU    : output, indexed as WU[pc + NrankPC*channel + NrankPC*Nchan*template];
    //         values are added with +=, so the buffer is presumably zeroed by the
    //         caller beforehand -- TODO confirm

    const int spikeCount = (int) Params[0];
    const int pcCount    = (int) Params[1];
    const int nearCount  = (int) Params[6];
    const int chanCount  = (int) Params[7];
    const int filt       = blockIdx.x;

    // start of this template's slab of the output
    float *filtWU = WU + pcCount * chanCount * filt;

    for (int spike = 0; spike < spikeCount; ++spike) {
        // skip spikes that belong to other templates
        if (id[spike] != filt)
            continue;

        // channels neighbouring this spike's channel, and the spike's own values
        const int   *nearChans = iC + nearCount * call[spike];
        const float *snip      = uproj + pcCount * nearCount * spike;

        for (int c = 0; c < nearCount; ++c) {
            float       *dst = filtWU + pcCount * nearChans[c];
            const float *src = snip + pcCount * c;

            for (int p = 0; p < pcCount; ++p)
                dst[p] += src[p];
        }
    }
}


//////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -162,8 +207,8 @@ __global__ void sum_dWU(const double *Params, const float *bigArray, float *WU)
int tid,bid, ind, Nfilters, Nthreads, Nfeatures, Nblocks, NfeatW, nWU, nElem;
float sum = 0.0f;

Nfeatures = (int) Params[1];
NfeatW = (int) Params[4];
Nfeatures = (int) Params[1]; //NrankPC, number of pcs
NfeatW = (int) Params[4]; //Nchan*nPC
Nfilters = (int) Params[2];
Nthreads = blockDim.x;
Nblocks = gridDim.x;
Expand Down Expand Up @@ -248,7 +293,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
{
/* Declare input variables*/
double *Params, *d_Params;
unsigned int Nchan, NrankPC, Nspikes, Nfilters, NchanNear;
unsigned int Nchan, NrankPC, Nspikes, Nfilters;


/* Initialize the MathWorks GPU API. */
Expand All @@ -259,7 +304,6 @@ void mexFunction(int nlhs, mxArray *plhs[],
Nspikes = (unsigned int) Params[0];
Nfilters = (unsigned int) Params[2];
NrankPC = (unsigned int) Params[1];
NchanNear = (unsigned int) Params[6];
Nchan = (unsigned int) Params[7];

// copy Params to GPU
Expand Down Expand Up @@ -319,15 +363,19 @@ void mexFunction(int nlhs, mxArray *plhs[],
bestFilter<<<40, 256>>>(d_Params, d_iMatch, d_iC, d_call, d_cmax, d_id, d_x);

// average all spikes for same template -- ORIGINAL
dim3 thNN(NrankPC, NchanNear);
average_snips<<<Nfilters, thNN>>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU);

// dim3 thNN(NrankPC, NchanNear);
// average_snips<<<Nfilters, thNN>>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU);

// average all spikes for same template -- threaded over filters, but not features
// avoid collision when adding to elements of d_dWU
average_snips_v2<<<Nfilters, 1>>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU);

//-------------------------------------------------
//jic for running average_snips_v3 with Nfeature threads
// float *d_bigArray;
// int bSize;
// int NfeatW = (int) Params[4];
// int Nfeatures = (int) Params[1];
// bSize = Nfeatures*NfeatW*Nfilters;
// cudaMalloc(&d_bigArray, bSize*sizeof(float) );
// cudaMemset(d_bigArray, 0, bSize*sizeof(float) );
Expand All @@ -336,7 +384,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
// d_uproj, d_cmax, d_bigArray);
// sum_dWU<<<128,1024>>>( d_Params, d_bigArray, d_dWU );
// cudaFree(d_bigArray);
//-------------------------------------------------
//-------------------------------------------------


count_spikes<<<7, 256>>>(d_Params, d_id, d_nsp, d_x, d_V);
Expand Down
2 changes: 1 addition & 1 deletion CUDA/mexDistances2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ __global__ void computeCost(const double *Params, const float *Ws, const float *
int j, tid, bid, Nspikes, my_chan, this_chan, Nchan, NrankPC, NchanNear, Nthreads, k;
float xsum = 0.0f, Ci;

Nspikes = (int) Params[0];
Nspikes = (int) Params[0]; //more accurately, number of comparisons, Nfilt*Nbatch
Nchan = (int) Params[7];
NrankPC = (int) Params[1];
NchanNear = (int) Params[6];
Expand Down
14 changes: 13 additions & 1 deletion CUDA/mexGPUall.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@
% Matlab GPU library first (see README files for platform-specific
% information)

enableStableMode = true;

mexcuda -largeArrayDims mexThSpkPC.cu
mexcuda -largeArrayDims mexGetSpikes2.cu
mexcuda -largeArrayDims mexMPnu8.cu

if enableStableMode
% For algorithm development purposes which require guaranteed
% deterministic calculations, add the -DENSURE_DETERM switch to the
% compile line for mexMPnu8.cu. -DENABLE_STABLEMODE must also
% be specified. This version will run ~2X slower than the
% non-deterministic version.
mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8.cu
else
mexcuda -largeArrayDims mexMPnu8.cu
end

mexcuda -largeArrayDims mexSVDsmall2.cu
mexcuda -largeArrayDims mexWtW2.cu
Expand Down
Loading

0 comments on commit f833820

Please sign in to comment.