diff --git a/CUDA/mexClustering2.cu b/CUDA/mexClustering2.cu index 3abef644..8a570019 100644 --- a/CUDA/mexClustering2.cu +++ b/CUDA/mexClustering2.cu @@ -93,7 +93,16 @@ __global__ void bestFilter(const double *Params, const bool *match, ////////////////////////////////////////////////////////////////////////////////////////// __global__ void average_snips(const double *Params, const int *iC, const int *call, const int *id, const float *uproj, const float *cmax, float *WU){ - + + //Nfilt blocks + //Thread grid = (NrankPC, NchanNear) + //This implementation does not work correctly for real data! + //Since this_chan is function of the spike -- spikes assigned to a given template + //will have max channels that span a 2-3 channel range -- different (tidx, tidy) + //pairs can wind up trying to add to the same element of dWU, resulting in + //collisions and incorrect results. Use the single-threaded version + //average_snips_v2 instead. Speed hit is only ~ 5-6 seconds out of 360 sec for a + //typical 2 hour Neuropixels 1.0 dataset. int my_chan, this_chan, tidx, tidy, bid, ind, Nspikes, NrankPC, NchanNear, Nchan; float xsum = 0.0f; @@ -106,15 +115,51 @@ __global__ void average_snips(const double *Params, const int *iC, const int *ca tidy = threadIdx.y; bid = blockIdx.x; - for(ind=0; ind>>(d_Params, d_iMatch, d_iC, d_call, d_cmax, d_id, d_x); // average all spikes for same template -- ORIGINAL - dim3 thNN(NrankPC, NchanNear); - average_snips<<>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU); - +// dim3 thNN(NrankPC, NchanNear); +// average_snips<<>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU); + // average all spikes for same template -- threaded over filters, but not features + // avoid collision when adding to elements of d_dWU + average_snips_v2<<>>(d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU); //------------------------------------------------- //jic for running average_snips_v3 with Nfeature threads // float *d_bigArray; // int bSize; +// int NfeatW = (int) Params[4]; +// int Nfeatures = (int) Params[1]; // bSize = Nfeatures*NfeatW*Nfilters; // cudaMalloc(&d_bigArray, bSize*sizeof(float) ); // cudaMemset(d_bigArray, 0, bSize*sizeof(float) ); @@ -336,7 +384,7 @@ void mexFunction(int nlhs, mxArray *plhs[], // d_uproj, d_cmax, d_bigArray); // sum_dWU<<<128,1024>>>( d_Params, d_bigArray, d_dWU ); // cudaFree(d_bigArray); -//------------------------------------------------- + //------------------------------------------------- count_spikes<<<7, 256>>>(d_Params, d_id, d_nsp, d_x, d_V); diff --git a/CUDA/mexDistances2.cu b/CUDA/mexDistances2.cu index 0017d8cc..84aaa85b 100644 --- a/CUDA/mexDistances2.cu +++ b/CUDA/mexDistances2.cu @@ -26,7 +26,7 @@ __global__ void computeCost(const double *Params, const float *Ws, const float * int j, tid, bid, Nspikes, my_chan, this_chan, Nchan, NrankPC, NchanNear, Nthreads, k; float xsum = 0.0f, Ci; - Nspikes = (int) Params[0]; + Nspikes = (int) Params[0]; //more accurately, number of comparisons, Nfilt*Nbatch Nchan = (int) Params[7]; NrankPC = (int) Params[1]; NchanNear = (int) Params[6]; diff --git a/CUDA/mexGPUall.m b/CUDA/mexGPUall.m index 4e14f214..74b6c4d1 100644 --- a/CUDA/mexGPUall.m +++ b/CUDA/mexGPUall.m @@ -2,9 +2,21 @@ % Matlab GPU library first (see README files for platform-specific % information) + enableStableMode = true; + mexcuda -largeArrayDims mexThSpkPC.cu mexcuda -largeArrayDims mexGetSpikes2.cu - mexcuda -largeArrayDims mexMPnu8.cu + + if enableStableMode + % For algorithm development purposes which require 
guaranteed
+        % deterministic calculations, add -DENSURE_DETERM switch to
+        % compile line for mexMPnu8.cu. -DENABLE_STABLEMODE must also
+        % be specified. This version will run ~2X slower than the
+        % non-deterministic version.
+        mexcuda -largeArrayDims -dynamic -DENABLE_STABLEMODE mexMPnu8.cu
+    else
+        mexcuda -largeArrayDims mexMPnu8.cu
+    end
     mexcuda -largeArrayDims mexSVDsmall2.cu
     mexcuda -largeArrayDims mexWtW2.cu
diff --git a/CUDA/mexMPnu8.cu b/CUDA/mexMPnu8.cu
index b3defb11..f9976df7 100644
--- a/CUDA/mexMPnu8.cu
+++ b/CUDA/mexMPnu8.cu
@@ -17,15 +17,22 @@
 #include
 using namespace std;
-//for sorting according to timestamps
-//#include "mexNvidia_quicksort.cu"
-
-
-
+#ifdef ENABLE_STABLEMODE
+    //for sorting according to timestamps
+    #include "mexNvidia_quicksort.cu"
+#endif
+
 const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nmaxiter = 500, NchanMax = 32;
+
 //////////////////////////////////////////////////////////////////////////////////////////
 __global__ void spaceFilter(const double *Params, const float *data, const float *U,
-        const int *iC, const int *iW, float *dprod){
+        const int *iC, const int *iW, float *dprod){
+
+// <<>>
+// blockIdx = current filter/template
+// blockDim = 1024 (max number of threads)
+// threadIdx = used both to index channel (in synchronized portion)
+// and time (in non-synchronized portion).
 volatile __shared__ float sU[32*NrankMax];
 volatile __shared__ int iU[32];
 float x;
@@ -33,35 +40,44 @@ __global__ void spaceFilter(const double *Params, const float *data, const float
 tid = threadIdx.x;
 bid = blockIdx.x;
- NT = (int) Params[0];
- Nfilt = (int) Params[1];
+ NT = (int) Params[0];
+ Nfilt = (int) Params[1];
 Nrank = (int) Params[6];
- NchanU = (int) Params[10];
+ NchanU = (int) Params[10]; //NchanNear in learnTemplates = 32
 Nchan = (int) Params[9];
 if (tid>>
 // just need to do this for all filters that have overlap with id[bid] and st[id]
- // tidx still represents time, from -nt0 to nt0
+ // as in spaceFilter, tid = threadIdx.x is first used to index over channels and pcs
+ // then used to loop over time, now just from -nt0 to nt0 about the input spike time
+ // tidx represents time, from -nt0 to nt0
 // tidy loops through all filters that have overlap
 if (tid=0 & t>>
+ // just need to do this for all filters that have overlap with id[bid] and st[id]
+ // as in spaceFilter, tid = threadIdx.x is first used to index over channels and pcs
+ // then used to loop over time, now just from -nt0 to nt0 about the input spike time
+ // tidx represents time, from -nt0 to nt0
+ // tidy loops through all filters that have overlap
+
+ if (tid=0 & t>>
+// threadIdx.x used as index over pcs in temporal templates
+// (num PCs * number of timepoints = Nrank * nt0)
+// Applied to data that's already been through filtering with
+// the spatial templates, input data has dim Nrank x NT x Nfilt
- __syncthreads();
+ if(tid>>
+// Same as timeFilter, except timepoints now limited to +/- nt0 about
+// spike times assigned to filters that may overlap the current filter
+// specified by bid. The matrix of potentially overlapping filters
+// is given in UtU.
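The collision described in the average_snips comment at the top of this patch is the classic lost-update race: two threads read the same element of dWU, each adds its own contribution, and one write clobbers the other. A minimal sketch of the failure mode and of the atomicAdd remedy the patched kernels rely on (hypothetical toy_race.cu, not part of this patch):

```cuda
// toy_race.cu -- hypothetical demo; compile with: nvcc toy_race.cu -o toy_race
#include <cstdio>
#include <cuda_runtime.h>

__global__ void addRace(float *acc) {
    acc[0] += 1.0f;            // unguarded read-modify-write: updates collide
}

__global__ void addAtomic(float *acc) {
    atomicAdd(&acc[0], 1.0f);  // hardware-serialized: every update lands
}

int main() {
    float h, *d;
    cudaMalloc(&d, sizeof(float));

    cudaMemset(d, 0, sizeof(float));
    addRace<<<256, 256>>>(d);
    cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
    printf("race   : %.0f of 65536 updates survived\n", h);

    cudaMemset(d, 0, sizeof(float));
    addAtomic<<<256, 256>>>(d);
    cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
    printf("atomic : %.0f of 65536 updates survived\n", h);

    cudaFree(d);
    return 0;
}
```

Note that atomicAdd fixes the lost updates but not the ordering: single-precision sums still depend on the order in which the atomics are serialized, which is why this patch also needs the sorting and double-precision machinery below.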
- if (tid=0 && tid0=0 && tid0>> +// loop over timepoints + tid0 = tid + bid * blockDim.x; while (tid0>> + lockout = (int) Params[4] - 1; // Parms[4] = nt0 tid = threadIdx.x; bid = blockIdx.x; - NT = (int) Params[0]; + NT = (int) Params[0]; tid0 = bid * blockDim.x ; Th = (float) Params[2]; //lam = (float) Params[7]; @@ -325,20 +429,23 @@ __global__ void cleanup_spikes(const double *Params, const float *data, tid0 += blockDim.x * gridDim.x; } } + ////////////////////////////////////////////////////////////////////////////////////////// __global__ void extractFEAT(const double *Params, const int *st, const int *id, const int *counter, const float *dout, const int *iList, const float *mu, float *d_feat){ - int t, tidx, tidy,Nblocks,NthreadsX,idF, bid, NT, ind, tcurr, Nnearest; + float rMax, Ci, Cf, lam; + int t, tidx, tidy, Nblocks, NthreadsX, idF, bid, NT, ind, tcurr, Nnearest; + tidx = threadIdx.x; tidy = threadIdx.y; bid = blockIdx.x; - NT = (int) Params[0]; + NT = (int) Params[0]; Nnearest = (int) Params[5]; NthreadsX = blockDim.x; - Nblocks = gridDim.x; + Nblocks = gridDim.x; lam = (float) Params[7]; // each thread x does a nearby filter @@ -361,6 +468,9 @@ __global__ void extractFEAT(const double *Params, const int *st, const int *id, } ////////////////////////////////////////////////////////////////////////////////////////// +// subtract_spikes version using single precision arithemtic and atomic operations to +// avoid thread interference when threading over spikes. This calculation is not +// deterministic, due to the order dependence of operations in single precision. __global__ void subtract_spikes(const double *Params, const int *st, const int *id, const float *x, const int *counter, float *dataraw, const float *W, const float *U){ @@ -385,7 +495,54 @@ __global__ void subtract_spikes(const double *Params, const int *st, X += W[tidx + id[ind]* nt0 + nt0*Nfilt*k] * U[tidy + id[ind] * Nchan + Nchan*Nfilt*k]; - dataraw[tidx + st[ind] + NT * tidy] -= x[ind] * X; + X = -x[ind]*X; + atomicAdd(&dataraw[tidx + st[ind] + NT * tidy], X); + tidy += blockDim.y; + } + ind += gridDim.x; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +//subtractions from array of doubles +__global__ void subtract_spikes_v4(const double *Params, const int *st, + const int *id, const float *x, const int *counter, double *dataraw, + const float *W, const float *U){ + + double X; + int nt0, tidx, tidy, k, NT, ind, Nchan, Nfilt, Nrank; + + unsigned long long int* address_as_ull; + unsigned long long int old, assumed; + + NT = (int) Params[0]; + nt0 = (int) Params[4]; + Nchan = (int) Params[9]; + Nfilt = (int) Params[1]; + Nrank = (int) Params[6]; + + tidx = threadIdx.x; + ind = counter[1]+blockIdx.x; + + while(indTh){ if (id[currInd]==bid){ @@ -448,11 +705,6 @@ __global__ void average_snips(const double *Params, const int *st, } //end of for loop over spike indicies } //end of function - - - - - ////////////////////////////////////////////////////////////////////////////////////////// __global__ void computePCfeatures(const double *Params, const int *counter, const float *dataraw, const int *st, const int *id, const float *x, @@ -462,8 +714,8 @@ __global__ void computePCfeatures(const double *Params, const int *counter, volatile __shared__ float sPCA[81 * NrankMax], sW[81 * NrankMax], sU[NchanMax * NrankMax]; volatile __shared__ int iU[NchanMax]; - int bid, nt0, t, tidx, tidy, k, NT, ind, Nchan, NchanU, Nfilt, Nrank; float X = 0.0f, Y = 0.0f; + int bid, nt0, t, 
tidx, tidy, k, NT, ind, Nchan, NchanU, Nfilt, Nrank; NT = (int) Params[0]; nt0 = (int) Params[4]; @@ -511,18 +763,20 @@ __global__ void computePCfeatures(const double *Params, const int *counter, } ////////////////////////////////////////////////////////////////////////////////////////// +// This function is not called. __global__ void addback_spikes(const double *Params, const int *st, const int *id, const float *x, const int *count, float *dataraw, const float *W, const float *U, const int iter, const float *spkscore){ - int nt0, tidx, tidy, k, NT, ind, Nchan, Nfilt, Nrank; + float X, ThS; + int nt0, tidx, tidy, k, NT, ind, Nchan, Nfilt, Nrank; NT = (int) Params[0]; nt0 = (int) Params[4]; Nchan = (int) Params[9]; - Nfilt = (int) Params[1]; + Nfilt = (int) Params[1]; Nrank = (int) Params[6]; - ThS = (float) Params[11]; + ThS = (float) Params[11]; tidx = threadIdx.x; ind = count[iter]+blockIdx.x; @@ -537,14 +791,23 @@ __global__ void addback_spikes(const double *Params, const int *st, for (k=0;k> +__global__ void set_idx( unsigned int *idx, const unsigned int nitems ) { + for( int i = 0; i < nitems; ++ i ) { + idx[i] = i; + } +} + ////////////////////////////////////////////////////////////////////////////////////////// /* @@ -558,19 +821,21 @@ void mexFunction(int nlhs, mxArray *plhs[], /* Declare input variables*/ double *Params, *d_Params; - unsigned int nt0, Nchan, NT, Nfilt, Nnearest, Nrank, NchanU; - + unsigned int nt0, Nchan, NT, Nfilt, Nnearest, Nrank, NchanU, useStableMode; + /* read Params and copy to GPU */ - Params = (double*) mxGetData(prhs[0]); - NT = (unsigned int) Params[0]; - Nfilt = (unsigned int) Params[1]; - nt0 = (unsigned int) Params[4]; - Nnearest = (unsigned int) Params[5]; - Nrank = (unsigned int) Params[6]; - NchanU = (unsigned int) Params[10]; - Nchan = (unsigned int) Params[9]; + Params = (double*) mxGetData(prhs[0]); + NT = (unsigned int) Params[0]; + Nfilt = (unsigned int) Params[1]; + nt0 = (unsigned int) Params[4]; + Nnearest = (unsigned int) Params[5]; + Nrank = (unsigned int) Params[6]; + NchanU = (unsigned int) Params[10]; + Nchan = (unsigned int) Params[9]; + useStableMode = (unsigned int) Params[16]; + // Make a local pointer to Params, which can be passed to kernels cudaMalloc(&d_Params, sizeof(double)*mxGetNumberOfElements(prhs[0])); cudaMemcpy(d_Params,Params,sizeof(double)*mxGetNumberOfElements(prhs[0]),cudaMemcpyHostToDevice); @@ -649,6 +914,19 @@ void mexFunction(int nlhs, mxArray *plhs[], cudaMemset(d_err, 0, NT * sizeof(float)); cudaMemset(d_ftype, 0, NT * sizeof(int)); cudaMemset(d_eloss, 0, NT * sizeof(float)); + + //allocate memory for index array, to be filled with 0->N items if sorting + //is not selected, fill with time sorted spike indicies if selected + unsigned int *d_idx; + cudaMalloc(&d_idx, maxFR * sizeof(int)); + cudaMemset(d_idx, 0, maxFR * sizeof(int)); + + //allocate arrays for sorting timestamps prior to spike subtraction from + //the data and averaging. 
Set Params(17) to 1 in the matlab caller
+    unsigned int *d_stSort;
+    cudaMalloc(&d_stSort, maxFR * sizeof(int));
+    cudaMemset(d_stSort, 0, maxFR * sizeof(int));
+
 dim3 tpB(8, 2*nt0-1), tpF(16, Nnearest), tpS(nt0, 16), tpW(Nnearest, Nrank), tpPC(NchanU, Nrank);
@@ -662,44 +940,118 @@ void mexFunction(int nlhs, mxArray *plhs[],
 bestFilter<<>>(d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype);
 // loop to find and subtract spikes
- for(int k=0;k<(int) Params[3];k++){
+
+    double *d_draw64;
+    if (useStableMode) {
+        // create copy of the dataraw, d_dout, d_data as doubles for arithmetic
+        // number of consecutive points to convert = Params(17) (Params(18) in matlab)
+        cudaMalloc(&d_draw64, NT*Nchan * sizeof(double));
+        convToDouble<<<100,Nthreads>>>(d_Params, d_draw, d_draw64);
+    }
+
+    for(int k=0;k<(int) Params[3];k++){  //Params[3] = nInnerIter, set to 60 final pass
 // ignore peaks that are smaller than another nearby peak
 cleanup_spikes<<>>(d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype, d_st, d_id, d_x, d_y, d_z, d_counter);
 // add new spikes to 2nd counter
 cudaMemcpy(counter, d_counter, 2*sizeof(int), cudaMemcpyDeviceToHost);
+        // limit the number of spikes to add to feature arrays AND subtract from drez
+        // to maxFR. maxFR = 100000, so this limit is likely not hit for "standard"
+        // batch size of 65000. However, could lead to duplicate template formation
+        // if the limit were hit in learning templates. Should we add a warning flag?
 if (counter[0]>maxFR){
 counter[0] = maxFR;
 cudaMemcpy(d_counter, counter, sizeof(int), cudaMemcpyHostToDevice);
 }
- // extract template features before subtraction
+        // extract template features before subtraction, for counter[1] to counter[0]
+        // tpF(16, Nnearest), blocks are over spikes
 if (Params[12]>1)
 extractFEAT<<<64, tpF>>>(d_Params, d_st, d_id, d_counter, d_dout, d_iList, d_mu, d_feat);
- // subtract spikes from raw data here
- subtract_spikes<<>>(d_Params, d_st, d_id, d_y, d_counter, d_draw, d_W, d_U);
-
- // filter the data with the spatial templates
- spaceFilterUpdate<<>>(d_Params, d_draw, d_U, d_UtU, d_iC, d_iW, d_data,
- d_st, d_id, d_counter);
+
+
+
+        // subtract spikes from raw data. If compile switch "ENSURE_DETERM" is on,
+        // use subtract_spikes_v2, which threads only over
+        // spikes subtracted = counter[1] up to counter[0].
+        // for this calculation to be reproducible, need to sort the spikes first
+
+
+#ifdef ENSURE_DETERM
+        // create a set of indices from 0 to counter[0] - counter[1] - 1
+        // if useStableMode = 0, this will be passed to subtract_spikes_v2 unaltered
+        // and spikes will be subtracted off in the order found
+        // NOTE: deterministic calculations are dependent on ENABLE_STABLEMODE!
+        set_idx<<< 1, 1 >>>(d_idx, counter[0] - counter[1]);
+        #ifdef ENABLE_STABLEMODE
+        if (useStableMode) {
+            //make a copy of the timestamp array to sort
+            cudaMemcpy( d_stSort, d_st+counter[1], (counter[0] - counter[1])*sizeof(int), cudaMemcpyDeviceToDevice );
+            int left = 0;
+            int right = counter[0] - counter[1] - 1;
+            cdp_simple_quicksort<<< 1, 1 >>>(d_stSort, d_idx, left, right, 0);
+        }
+        #endif
+        if (Nchan < Nthreads) {
+            subtract_spikes_v2<<<1, Nchan>>>(d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U);
+        }
+        else {
+            subtract_spikes_v2<<>>(d_Params, d_st, d_idx, d_id, d_y, d_counter, d_draw, d_W, d_U);
+        }
+        // filter the data with the spatial templates, checking only times where
+        // identified spikes were subtracted.
Needs a version using a single-precision copy of draw.
+        spaceFilterUpdate<<>>(d_Params, d_draw, d_U, d_UtU, d_iC, d_iW, d_data,
+                d_st, d_id, d_counter);
+
+#else
+        //"Normal" mode -- recommend useStableMode, which will give mostly deterministic calculations
+        //useStableMode = 0 will have significant differences from run to run, but is 15-20% faster
+        if (useStableMode) {
+            subtract_spikes_v4<<>>(d_Params, d_st, d_id, d_y, d_counter, d_draw64, d_W, d_U);
+            // filter the data with the spatial templates, checking only times where
+            // identified spikes were subtracted. Needs a version using a double-precision copy of draw
+            spaceFilterUpdate_v2<<>>(d_Params, d_draw64, d_U, d_UtU, d_iC, d_iW, d_data,
+                    d_st, d_id, d_counter);
+        }
+        else {
+            subtract_spikes<<>>(d_Params, d_st, d_id, d_y, d_counter, d_draw, d_W, d_U);
+            // filter the data with the spatial templates, checking only times where
+            // identified spikes were subtracted. Needs a version using a single-precision copy of draw
+            spaceFilterUpdate<<>>(d_Params, d_draw, d_U, d_UtU, d_iC, d_iW, d_data,
+                    d_st, d_id, d_counter);
+        }
+#endif
+
- // filter the data with the temporal templates
+        // filter the data with the temporal templates, checking only times where
+        // identified spikes were subtracted
 timeFilterUpdate<<>>(d_Params, d_data, d_W, d_UtU, d_dout, d_st, d_id, d_counter);
- if (counter[0]-counter[1]>0)
+        // shouldn't the space filter update and time filter update also only
+        // be done if counter[0] - counter[1] > 0?
+        if (counter[0]-counter[1]>0) {
 bestFilterUpdate<<>>(d_Params, d_dout, d_mu, d_err, d_eloss, d_ftype, d_st, d_id, d_counter);
-
+
+        }
+        // d_count records the number of spikes (tracked in d_counter[0]) in each
+        // iteration, but is currently unused.
 cudaMemcpy(d_count+k+1, d_counter, sizeof(int), cudaMemcpyDeviceToDevice);
- // update 1st counter from 2nd counter
+        // copy d_counter[0] to d_counter[1]. cleanup_spikes will look for new
+        // spikes in the data and increment d_counter[0]; features of these new
+        // spikes will be added to d_featPC and then subtracted out of d_out.
 cudaMemcpy(d_counter+1, d_counter, sizeof(int), cudaMemcpyDeviceToDevice);
 }
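The address_as_ull / old / assumed variables declared in subtract_spikes_v4 above implement the standard atomicCAS emulation of a double-precision atomicAdd (native double atomicAdd requires compute capability 6.0+). A self-contained sketch of the same pattern, following the CUDA C Programming Guide; the inline code in the kernel may differ in detail:

```cuda
// CAS loop emulating atomicAdd on a double (CUDA C Programming Guide pattern).
__device__ double atomicAddDouble(double *address, double val) {
    unsigned long long int *address_as_ull = (unsigned long long int *)address;
    unsigned long long int old = *address_as_ull, assumed;
    do {
        assumed = old;
        // install the new sum only if *address is still the value we read
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
    } while (assumed != old);   // another thread won the race: re-read and retry
    return __longlong_as_double(old);
}
```

Even in double precision the accumulation order still varies between runs; the point is that the rounding differences it produces are far below single-precision round-off, which is why this path is described as mostly, rather than strictly, deterministic.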
-
+
+    if (useStableMode) {
+        //convert arrays back to singles for the rest of the process
+        convToSingle<<<100,Nthreads>>>(d_Params, d_draw64, d_draw);
+    }
+
 // compute PC features from residuals + subtractions
 if (Params[12]>0)
 computePCfeatures<<>>(d_Params, d_counter, d_draw, d_st,
@@ -707,23 +1059,26 @@ void mexFunction(int nlhs, mxArray *plhs[],
 //jic addition of time sorting prior to average_snips
 //get a set of indices for the sorted timestamp array
- //make a copy of the timestamp array to sort, plus an array of indicies
-
-// unsigned int *d_stSort, *d_idx;
-// cudaMalloc(&d_stSort, counter[0] * sizeof(int));
-// cudaMemset(d_stSort, 0, counter[0] *sizeof(int));
-// cudaMalloc(&d_idx, counter[0] * sizeof(int));
-// cudaMemset(d_idx, 0, counter[0] *sizeof(int));
-// cudaMemcpy( d_stSort, d_st, counter[0]*sizeof(int), cudaMemcpyDeviceToDevice );
-// set_idx<<< 1, 1 >>>(d_idx, counter[0]);
-// int left = 0;
-// int right = counter[0]-1;
-// cdp_simple_quicksort<<< 1, 1 >>>(d_stSort, d_idx, left, right, 0);
+    //make an array of indices; if useStableMode = 0, this will be passed
+    //to average_snips unaltered
+    set_idx<<< 1, 1 >>>(d_idx, counter[0]);
+
+#ifdef ENABLE_STABLEMODE
+    if (useStableMode) {
+        //make a copy of the timestamp array to sort
+        cudaMemcpy( d_stSort, d_st, counter[0]*sizeof(int), cudaMemcpyDeviceToDevice );
+        int left = 0;
+        int right = counter[0]-1;
+        cdp_simple_quicksort<<< 1, 1 >>>(d_stSort, d_idx, left, right, 0);
+    }
+#endif
+
+
 // update dWU here by adding back to subbed spikes.
+    // additional parameter d_idx = array of time sorted indices
- average_snips<<>>(d_Params, d_st, d_id, d_x, d_y, d_counter,
- d_draw, d_W, d_U, d_dWU, d_nsp,d_mu, d_z);
+    average_snips<<>>(d_Params, d_st, d_idx, d_id, d_x, d_y, d_counter,
+            d_draw, d_W, d_U, d_dWU, d_nsp, d_mu, d_z);
 float *x, *feat, *featPC, *vexp;
 int *st, *id;
@@ -759,8 +1114,37 @@ void mexFunction(int nlhs, mxArray *plhs[],
 cudaMemcpy(vexp, d_x, minSize * sizeof(float), cudaMemcpyDeviceToHost);
 cudaMemcpy(feat, d_feat, minSize * Nnearest*sizeof(float), cudaMemcpyDeviceToHost);
 cudaMemcpy(featPC, d_featPC, minSize * NchanU*Nrank*sizeof(float), cudaMemcpyDeviceToHost);
+
+    // send back an error message if useStableMode was selected but couldn't be used
+    //local array to hold error
+    int *d_errmsg;
+    cudaMalloc(&d_errmsg, 1 * sizeof(int));
+
+    //host array
+    int *errmsg;
+    const mwSize dimErr[] = {1,1};
+    plhs[9] = mxCreateNumericArray(2, dimErr, mxINT32_CLASS, mxREAL);
+    errmsg = (int*) mxGetData(plhs[9]);
+
+    //set to no error
+    cudaMemset(d_errmsg, 0, 1 * sizeof(int));
+    if (useStableMode) {
+        #ifndef ENABLE_STABLEMODE
+        //if caller requested stableMode, but not enabled, set error = 1
+        cudaMemset(d_errmsg, 1, 1);   //set single byte = 1
+        #endif
+    }
+    cudaMemcpy(errmsg, d_errmsg, 1 * sizeof(int), cudaMemcpyDeviceToHost);
+    cudaFree(d_errmsg);
+
+    if (useStableMode) {
+        //only free the memory if it was allocated
+        cudaFree(d_draw64);
+    }
+
+    cudaFree(d_counter);
+    cudaFree(d_count);
 cudaFree(d_Params);
 cudaFree(d_ftype);
 cudaFree(d_err);
@@ -774,8 +1158,10 @@ void mexFunction(int nlhs, mxArray *plhs[],
 cudaFree(d_featPC);
 cudaFree(d_dout);
 cudaFree(d_data);
-// cudaFree(d_idx);
-// cudaFree(d_stSort);
+    cudaFree(d_idx);
+    cudaFree(d_stSort);
+
+
 mxGPUDestroyGPUArray(draw);
 mxGPUDestroyGPUArray(wPCA);
@@ -788,5 +1174,4 @@ void mexFunction(int nlhs, mxArray *plhs[],
 mxGPUDestroyGPUArray(nsp);
 mxGPUDestroyGPUArray(iW);
 mxGPUDestroyGPUArray(iList);
-
 }
diff --git a/CUDA/mexNvidia_quicksort.cu b/CUDA/mexNvidia_quicksort.cu
index 71f14165..15a4c92f 100644
--- a/CUDA/mexNvidia_quicksort.cu
+++ b/CUDA/mexNvidia_quicksort.cu
@@ -140,12 +140,39 @@ __global__ void cdp_simple_quicksort( unsigned int *data, unsigned int *idx, int
 }
 }
-__global__ void set_idx( unsigned int *idx, unsigned int nitems ) {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper functions
+////////////////////////////////////////////////////////////////////////////////
+
+// will need to include this function to set indices in the calling program
+// to allow sorting to be optional.
+// create gpu array of starting index values, 0..nitems-1
+// call with no threads, i.e. <<<1, 1>>>
+// __global__ void set_idx( unsigned int *idx, const unsigned int nitems ) {
+//     for( int i = 0; i < nitems; ++ i ) {
+//         idx[i] = i;
+//     }
+// }
+
+// copy values from an integer array to a new array in sort order given by sort_idx
+// call with no threads, i.e. <<<1, 1>>>
+__global__ void copy_sort_int( const int *orig, const unsigned int *sort_idx,
+        const unsigned int nitems, int *sorted ) {
     for( int i = 0; i < nitems; ++ i ) {
-        idx[i] = i;
+        sorted[sort_idx[i]] = orig[i];
     }
 }
+// copy values from an array of single precision
+// floating point numbers to a new array in sort order given by sort_idx
+// call with no threads, i.e. <<<1, 1>>>
+__global__ void copy_sort_int( const float *orig, const unsigned int *sort_idx,
+        const unsigned int nitems, float *sorted ) {
+    for( int i = 0; i < nitems; ++ i ) {
+        sorted[sort_idx[i]] = orig[i];
+    }
+}
///////////////////////////////////////////////////////////////////
// Host code
diff --git a/CUDA/mexThSpkPC.cu b/CUDA/mexThSpkPC.cu
index 3c8930c1..b5db30c0 100644
--- a/CUDA/mexThSpkPC.cu
+++ b/CUDA/mexThSpkPC.cu
@@ -17,7 +17,7 @@
 #include
 using namespace std;
-const int Nthreads = 1024, maxFR = 10000, NrankMax = 3, nt0max=81, NchanMax = 17;
+const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nt0max=81, NchanMax = 17;
//////////////////////////////////////////////////////////////////////////////////////////
__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){
@@ -62,7 +62,10 @@ __global__ void Conv1D(const double *Params, const float *data, const float *W,
//////////////////////////////////////////////////////////////////////////////////////////
__global__ void computeProjections(const double *Params, const float *dataraw,
    const int *iC, const int *st, const int *id, const float *W, float *feat){
-
+
+    //number of blocks = number of spikes to process = min(number found, maxFR=100000)
+    //Thread grid = (NchanNear, NrankPC)
+
 float x;
 int tidx, nt0min, tidy, my_chan, this_chan, tid, bid, nt0, NchanNear, j, t, NT, NrankPC;
 volatile __shared__ float sW[nt0max*NrankMax], sD[nt0max*NchanMax];
@@ -73,9 +76,9 @@ __global__ void computeProjections(const double *Params, const float *dataraw,
 NrankPC = (int) Params[6];
 nt0min = (int) Params[4];
- tidx = threadIdx.x;
- tidy = threadIdx.y;
- bid = blockIdx.x;
+ tidx = threadIdx.x;   //PC index in W (column index)
+ tidy = threadIdx.y;   //channel index
+ bid = blockIdx.x;     //NchanNear*NrankPC; each spike gets NchanNear*NrankPC values in projection
// move wPCA to shared memory while (tidx>>(d_Params, d_dout, d_dmax); // take max across nearby channels
+ // return spike times in d_st, max channel index in d_id, #spikes in d_counter
+ // note that max channel and spike times are only saved for the first maxFR spikes
 maxChannels<<>>(d_Params, d_dout, d_dmax, d_iC, d_st, d_id, d_counter);
-
+
 cudaMemcpy(counter, d_counter, sizeof(int), cudaMemcpyDeviceToHost);
- // move d_x to the CPU
+ // calculate features for up to maxFR spikes
 unsigned int minSize=1;
 minSize = min(maxFR, counter[0]);
@@ -278,12 +287,13 @@ void mexFunction(int nlhs, mxArray *plhs[],
 computeProjections<<>>(d_Params, d_data, d_iC, d_st, d_id, d_W, d_featPC);
 cudaMemcpy(d_id2, d_id, minSize * sizeof(int), cudaMemcpyDeviceToDevice);
+
-
- // dWU stays a GPU array
+
+ // uproj and array of max channels will remain GPU arrays
 plhs[0] = mxGPUCreateMxArrayOnGPU(featPC);
 plhs[1] = mxGPUCreateMxArrayOnGPU(id);
-
 cudaFree(d_st);
 cudaFree(d_id);
 cudaFree(d_counter);
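Both maxChannels here and cleanup_spikes in mexMPnu8.cu bound their outputs the same way: each detecting thread atomically reserves a slot, only the first maxFR reservations are written, and the host afterwards clamps the counter back to maxFR (as in the mexMPnu8 inner loop above). A sketch of the idiom, with hypothetical names:

```cuda
// Bounded spike buffer: the counter may overshoot maxFR, the writes never do.
__global__ void detectPeaks(const float *v, float Th, int NT,
                            int *st, int *counter, int maxFR) {
    int t = blockIdx.x * blockDim.x + threadIdx.x;
    if (t < NT && v[t] > Th) {
        int slot = atomicAdd(&counter[0], 1);  // reserve a slot
        if (slot < maxFR)
            st[slot] = t;                      // keep only the first maxFR
    }
}
```

This is why the patch raises maxFR in mexThSpkPC.cu from 10000 to 100000, and why the comment in mexMPnu8.cu asks whether hitting the limit should raise a warning: spikes past the limit are silently dropped.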
diff --git a/eMouse_drift/benchmark_drift_simulation.m b/eMouse_drift/benchmark_drift_simulation.m
index 5b58a352..0cba82a3 100644
--- a/eMouse_drift/benchmark_drift_simulation.m
+++ b/eMouse_drift/benchmark_drift_simulation.m
@@ -1,11 +1,17 @@
 function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, bAutoMerge, varargin)
-%for testing outside a script. comment out for normal calling!
-% load('D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\rezFinal.mat');
-% GTfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\eMouseGroundTruth.mat';
-% simRecfilepath = 'D:\drift_simulations\74U_norm_64site_20um_600sec_20min\ks2_master_060919\eMouseSimRecord.mat';
-% sortType = 2;
-% bAutoMerge = 0;
+bOutFile = 0;
+
+%these definitions for testing outside a script. comment out for normal calling!
+%can leave out last 2 lines if output file not desired
+% load('D:\test_new_sim\74U_20um_drift_standard\r28_KS2determ_r26rep\rez2.mat');
+% GTfilepath = 'D:\test_new_sim\74U_20um_drift_standard\eMouseGroundTruth.mat';
+% simRecfilepath = 'D:\test_new_sim\74U_20um_drift_standard\eMouseSimRecord.mat';
+% bOutFile = 1;
+% out_fid = fopen('D:\test_new_sim\74U_20um_drift_standard\r28_benchmark_output.txt','w');
+
+% sortType = 2;
+% bAutoMerge = 0;
 load(GTfilepath);
@@ -15,7 +21,7 @@ function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, b
 testClu = rez.st3(:,2) ;
 end
-bOutFile = 0;
+
%fprintf( 'length of vargin: %d\n', numel(varargin));
 if( numel(varargin) == 1)
 %path for output file
@@ -23,6 +29,7 @@ function benchmark_drift_simulation(rez, GTfilepath, simRecfilepath, sortType, b
 fprintf( 'output filename: %s\n', varargin{1} );
 out_fid = fopen( varargin{1}, 'w' );
 end
+
 testRes = rez.st3(:,1);
diff --git a/eMouse_drift/make_eMouseChannelMap_3B_short.m b/eMouse_drift/make_eMouseChannelMap_3B_short.m
index ace03cd6..27143b04 100644
--- a/eMouse_drift/make_eMouseChannelMap_3B_short.m
+++ b/eMouse_drift/make_eMouseChannelMap_3B_short.m
@@ -55,4 +55,4 @@
 chanMapName = sprintf('chanMap_3B_%dsites.mat', NchanTOT);
-save(fullfile(fpath, chanMapName), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords', 'fs', 'NchanTOT' )
\ No newline at end of file
+save(fullfile(fpath, chanMapName), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords', 'fs' )
\ No newline at end of file
diff --git a/eMouse_drift/make_eMouseData_drift.m b/eMouse_drift/make_eMouseData_drift.m
index 8385bf6b..8bfe1732 100644
--- a/eMouse_drift/make_eMouseData_drift.m
+++ b/eMouse_drift/make_eMouseData_drift.m
@@ -5,16 +5,16 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
% probe sites.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% you can play with the parameters just below here to achieve a signal more similar to your own data!!!
-norm_amp = 20; % if 0, use amplitudes of input waveforms; if > 0, set all amplitudes to norm_amp*rms_noise
+norm_amp = 16.7; % if 0, use amplitudes of input waveforms; if > 0, set all amplitudes to norm_amp*rms_noise
 mu_mean = 0.75; % mean of mean spike amplitudes. Incoming waveforms are in uV; make <1 to make sorting harder
-noise_model = 'fromData'; %'gauss' or 'fromData'; 'fromData' requires a noiseModel.mat built by make_noise_model
-rms_noise = 12; % rms noise in uV. Will be added to the spike signal. 15-20 uV an OK estimate from real data
+noise_model = 'gauss'; %'gauss' or 'fromData'; 'fromData' requires a noiseModel.mat built by make_noise_model
+rms_noise = 10; % rms noise in uV. Will be added to the spike signal. 15-20 uV an OK estimate from real data
 t_record = 1200; % duration in seconds of simulation. longer is better (and slower!) (1000)
 fr_bounds = [1 10]; % min and max of firing rates ([1 10])
 tsmooth = 0.5; % gaussian smooth the noise with sig = this many samples (increase to make it harder) (0.5)
 chsmooth = 0.5; % smooth the noise across channels too, with this sig (increase to make it harder) (0.5)
 amp_std = .1; % standard deviation of single spike amplitude variability (increase to make it harder, technically std of gamma random variable of mean 1) (.25)
-fs = 30000; % sample rate for the simulation. Incoming waveforms must be sampled at this freq.
+fs_rec = 30000; % sample rate for the recording. Waveforms must be sampled at a rate w/in .01% of this rate
 nt = 81; % number of timepoints expected. All waveforms must have this time window
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%drift params.
See the comments in calcYPos_v2 for details
@@ -81,12 +81,37 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
 end
 noiseFromData = load(nmPath);
 end
+
+% Add a SYNC channel to the file
+% 16 bit word with a 1 Hz square wave in 7th bit
+addSYNC = false;
+syncOffset = 0.232; % must be between 0 and 0.5, offset to first on edge
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 rng('default');
-rng(101); % set the seed of the random number generator; default = 101
-bitPerUV = 0.42667; %imec 3A or 3B, gain = 500
+% There are 3 seeds for the random number generator, so some parts of the
+% simulation can be fixed while others vary from run to run
+
+% For unit placement, average amplitude, and spike times
+
+unit_seed = 101;
+
+% For individual spike amplitudes, but still based on the same average
+% Meant to simulate the same spikes showing up in different streams
+
+amp_seed = 101;
+
+% For noise generation
+
+noise_seed = 101;
+
+
+% set the seed of the random number generator used for unit definition
+rng(unit_seed);
+
+%bitPerUV = 0.42667; %imec 3A or 3B, gain = 500
+bitPerUV = 1.3107; %for NP 2.0, fixed gain of 80
% load channel map file built with make_eMouseChannelMap_3A_short.m
@@ -95,7 +120,7 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
 zeroSites = find(connected == 0);
-Nchan = NchanTOT; %physical sites on the probe
+Nchan = numel(chanMap); %physical sites on the probe
%invChanMap(chanMap) = [1:Nchan]; % invert the channel map here--create the order in which to write output
% range of y positions to place units
@@ -133,7 +158,8 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
 for fileIndex = 1:nFile
 % generate waveforms for the simulation
 uData = load(filePath{fileIndex});
- if (uData.fs ~= fs)
+ fs_diff = 100*(abs(uData.fs - fs_rec)/fs_rec);
+ if (fs_diff > 0.01)
 fprintf( 'Waveform file %d has wrong sample rate.\n', fileIndex );
 fprintf( 'Skipping to next file.');
 continue
@@ -369,8 +395,12 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
 bContinue = 1;
+% set the sample rate to that specified in the hard coded params in this
+% file (independent of fs read in through channel map)
+% allows simulation of multiple streams with slightly different clock rates
-
+fs = fs_rec;
+fs_std = 30000; %used to generate spike times
%same for range of firing rates (note that we haven't included any info
%about the original firing rates of the units
@@ -392,12 +422,16 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool)
%before a success (neuron firing)
%second two params for geornd are size of the array, here 2*firing
%rate*total time of the simulation.
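The geornd call below draws each inter-spike interval as a geometric random variable: with per-sample firing probability p = fr(j)/fs, the draw is the number of silent samples before the next spike. For illustration only, the same inverse-CDF draw (X = floor(ln U / ln(1 - p))) sketched in CUDA with cuRAND; the kernel and its names are hypothetical, not part of the simulator:

```cuda
#include <curand_kernel.h>

// One geometric inter-spike interval per thread, X = floor(ln U / ln(1-p)),
// matching MATLAB's geornd(p, n, 1) convention (number of failures, 0-based).
__global__ void geometricISI(unsigned long long seed, float p, int n, int *isi) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        curandState s;
        curand_init(seed, i, 0, &s);   // per-thread subsequence of one stream
        float u = curand_uniform(&s);  // u in (0, 1]
        isi[i] = (int)floorf(logf(u) / logf(1.0f - p));
    }
}
```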
- dspks = int64(geornd(1/(fs/fr(j)), ceil(2*fr(j)*t_record),1)); - dspks(dspks0 dat(1:buff/2, :) = dat_old(NT-buff/2 + [1:buff/2], :); end @@ -559,6 +611,10 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool) dat_old = dat; %convert to 16 bit integers; waveforms are in uV dat = int16(bitPerUV * dat); + if addSYNC + %add the column of sync data + dat = horzcat(dat, sync); + end fwrite(fidW, dat(1:(NT-buff),:)', 'int16'); t_all = t_all + (NT-buff)/fs; @@ -860,3 +916,5 @@ function make_eMouseData_drift(fpath, KS2path, chanMapName, useGPU, useParPool) end + + diff --git a/eMouse_drift/master_eMouse_drift.m b/eMouse_drift/master_eMouse_drift.m index 43f4c9f9..cbbdf57a 100644 --- a/eMouse_drift/master_eMouse_drift.m +++ b/eMouse_drift/master_eMouse_drift.m @@ -4,18 +4,18 @@ sortData = 1; runBenchmark = 1; %set to 1 to compare sorted data to ground truth for the simulation -fpath = 'D:\test_new_sim\SC_noise\'; % where on disk do you want the simulation? ideally an SSD... +fpath = 'C:\Users\labadmin\Documents\jic\050320_sim_test\'; % where on disk do you want the simulation? ideally an SSD... if ~exist(fpath, 'dir'); mkdir(fpath); end %KS2 path -- also has the waveforms for the simulation -KS2path = 'Z:\workstation_backup\full_080119\Documents\KS2_current\'; +KS2path = 'C:\Users\labadmin\Documents\jic\KS2_040920\Kilosort2\'; % add paths to the matlab path -addpath(genpath('Z:\workstation_backup\full_080119\Documents\KS2_current\')); % path to kilosort2 folder -addpath(genpath('D:\KS2\npy-matlab-master\')); +addpath(genpath('C:\Users\labadmin\Documents\jic\KS2_040920\Kilosort2\')); % path to kilosort2 folder +addpath(genpath('C:\Users\labadmin\Documents\jic\npy-matlab-master\')); % path to whitened, filtered proc file (on a fast SSD) -rootH = 'D:\KS2\kilosort_datatemp\'; +rootH = 'D:\kilosort_datatemp\'; % path to config file; if running the default config, no need to change. pathToYourConfigFile = [KS2path,'eMouse_drift\']; % path to config file diff --git a/mainLoop/learnAndSolve8b.m b/mainLoop/learnAndSolve8b.m index f8cbd864..8a602c1a 100644 --- a/mainLoop/learnAndSolve8b.m +++ b/mainLoop/learnAndSolve8b.m @@ -8,10 +8,14 @@ % we learn the templates by going back and forth through some of the data, % in the order specified by iorig (determined by batch reordering). + % standard order -- learn templates from first half of data starting + % from midpoint, counting down to 1, and then returning. iorder0 = rez.iorig([ihalf:-1:1 1:ihalf]); % these are absolute batch ids + rez = learnTemplates(rez, iorder0); rez.istart = rez.iorig(ihalf); % this is the absolute batch id where we start sorting + else rez.WA = []; end diff --git a/mainLoop/learnTemplates.m b/mainLoop/learnTemplates.m index 2eff361f..d2727d54 100644 --- a/mainLoop/learnTemplates.m +++ b/mainLoop/learnTemplates.m @@ -1,12 +1,19 @@ function rez = learnTemplates(rez, iorder) % This is the main optimization. Takes the longest time and uses the GPU heavily. 
-ops = rez.ops;
-ops.fig = getOr(ops, 'fig', 1); % whether to show plots every N batches
+rez.ops.fig = getOr(rez.ops, 'fig', 1); % whether to show plots every N batches
+
+% Turn on sorting of spikes before subtracting and averaging in mpnu8
+rez.ops.useStableMode = getOr(rez.ops, 'useStableMode', 1);
+useStableMode = rez.ops.useStableMode;
 NrankPC = 6; % this one is the rank of the PCs, used to detect spikes with threshold crossings
 Nrank = 3; % this one is the rank of the templates
-rng('default'); rng(1);
+
+rez.ops.LTseed = getOr(rez.ops, 'LTseed', 1);
+rng('default'); rng(rez.ops.LTseed);
+
+ops = rez.ops;
% we need PC waveforms, as well as template waveforms
 [wTEMP, wPCA] = extractTemplatesfromSnippets(rez, NrankPC);
@@ -57,7 +64,7 @@
 Nsum = min(Nchan,7); % how many channels to extend out the waveform in mexgetspikes
% lots of parameters passed into the CUDA scripts
 Params = double([NT Nfilt ops.Th(1) nInnerIter nt0 Nnearest ...
-    Nrank ops.lam pmi(1) Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1)]);
+    Nrank ops.lam pmi(1) Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1) useStableMode]);
% W0 has to be ordered like this
 W0 = permute(double(wPCA), [1 3 2]);
@@ -138,10 +145,19 @@
% gets scores for the template fits to each spike (vexp), outputs the average of
% waveforms assigned to each cluster (dWU0),
% and probably a few more things I forget about
-    [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp] = ...
+    [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp, errmsg] = ...
        mexMPnu8(Params, dataRAW, single(U), single(W), single(mu), iC-1, iW-1, UtU, iList-1, ...
        wPCA);
-
+
+    % errmsg returns 1 if caller requested "stableMode" but mexMPnu8 was
+    % compiled without the sorter enabled (i.e. ENABLE_STABLEMODE was not
+    % defined in mexGPUall). Send an error message to the console just once
+    % if this is the case:
+    if (ibatch == 1)
+        if( (useStableMode == 1) && (errmsg == 1) )
+            fprintf( 'useStableMode selected but STABLEMODE not enabled in compiled mexMPnu8.\n' );
+        end
+    end
% Sometimes nsp can get transposed (think this has to do with it being
% a single element in one iteration, to which elements are added
% nsp, nsp0, and pm must all be row vectors (Nfilt x 1), so force nsp
@@ -244,8 +260,8 @@
 rez = memorizeW(rez, W, dWU, U, mu); % memorize the state of the templates
 rez.ops = ops; % update these (only rez comes out of this script)
-
-fprintf('memorized middle timepoint \n')
+% save('rez_mid.mat', 'rez');
+
 fprintf('Finished learning templates \n')
%%
diff --git a/mainLoop/runTemplates.m b/mainLoop/runTemplates.m
index 7f588a1f..88e7b949 100644
--- a/mainLoop/runTemplates.m
+++ b/mainLoop/runTemplates.m
@@ -33,26 +33,32 @@
 ihalf = find(rez.iorig==istart);
 iorder_sorted = ihalf:-1:1;
-iorder = rez.iorig(iorder_sorted);
+iorder = rez.iorig(iorder_sorted); %batch number in full set
+
 [rez, st3_0, fW_0,fWpc_0] = trackAndSort(rez, iorder);
+
 st3_0(:,5) = iorder_sorted(st3_0(:,5));
 iorder_sorted = (ihalf+1):Nbatches;
 iorder = rez.iorig(iorder_sorted);
+
 [rez, st3_1, fW_1,fWpc_1] = trackAndSort(rez, iorder);
-
-st3_1(:,5) = iorder_sorted(st3_1(:,5));
+
+st3_1(:,5) = iorder_sorted(st3_1(:,5)); %batch number in full set
 st3 = cat(1, st3_0, st3_1);
 fW = cat(2, fW_0, fW_1);
 fWpc = cat(3, fWpc_0, fWpc_1);
-[~, isort] = sort(st3(:,1));
+% sort all spikes by batch -- to keep similar batches together,
+% which avoids false splits in splitAllClusters. Break ties within a batch
+% by spike time (then by the remaining columns of st3).
Break ties +[~, isort] = sortrows(st3,[5,1,2,3,4]); st3 = st3(isort, :); fW = fW(:, isort); fWpc = fWpc(:, :, isort); % just display the total number of spikes -size(st3,1) +fprintf( 'Number of spikes before applying cutoff: %d\n', size(st3,1)); rez.st3 = st3; rez.st2 = st3; % keep also an st2 copy, because st3 will be over-written by one of the post-processing steps diff --git a/mainLoop/trackAndSort.m b/mainLoop/trackAndSort.m index 467c6189..623e0808 100644 --- a/mainLoop/trackAndSort.m +++ b/mainLoop/trackAndSort.m @@ -2,6 +2,10 @@ % This is the extraction phase of the optimization. % iorder is the order in which to traverse the batches +% Turn on sorting of spikes before subtracting and averaging in mpnu8 +rez.ops.useStableMode = getOr(rez.ops, 'useStableMode', 1); +useStableMode = rez.ops.useStableMode; + ops = rez.ops; % revert to the saved templates @@ -18,11 +22,12 @@ dWU(:,:,j) = mu(j) * squeeze(W(:, j, :)) * squeeze(U(:, j, :))'; end + ops.fig = getOr(ops, 'fig', 1); % whether to show plots every N batches NrankPC = 6; % this one is the rank of the PCs, used to detect spikes with threshold crossings Nrank = 3; % this one is the rank of the templates -rng('default'); rng(1); +rng('default'); rng(1); % initializing random number generator % move these to the GPU wPCA = gpuArray(ops.wPCA); @@ -64,7 +69,7 @@ Nsum = min(Nchan,7); % how many channels to extend out the waveform in mexgetspikes % lots of parameters passed into the CUDA scripts Params = double([NT Nfilt ops.Th(1) nInnerIter nt0 Nnearest ... - Nrank ops.lam pm Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1)]); + Nrank ops.lam pm Nchan NchanNear ops.nt0min 1 Nsum NrankPC ops.Th(1) useStableMode]); % initialize average number of spikes per batch for each template nsp = gpuArray.zeros(Nfilt,1, 'double'); @@ -116,7 +121,9 @@ dataRAW = single(gpuArray(dat))/ ops.scaleproc; % decompose dWU by svd of time and space (via covariance matrix of 61 by 61 samples) - % this uses a "warm start" by remembering the W from the previous iteration + % this uses a "warm start" by remembering the W from the previous + % iteration + [W, U, mu] = mexSVDsmall2(Params, dWU, W, iC-1, iW-1, Ka, Kb); % UtU is the gram matrix of the spatial components of the low-rank SVDs @@ -134,7 +141,7 @@ % waveforms assigned to each cluster (dWU0), % and probably a few more things I forget about - [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp] = ... + [st0, id0, x0, featW, dWU0, drez, nsp0, featPC, vexp, errmsg] = ... mexMPnu8(Params, dataRAW, single(U), single(W), single(mu), iC-1, iW-1, UtU, iList-1, ... 
wPCA); @@ -171,7 +178,7 @@ rez.WA(:,:,:,k) = gather(W); rez.UA(:,:,:,k) = gather(U); rez.muA(:,k) = gather(mu); - + % we carefully assign the correct absolute times to spikes found in this batch ioffset = ops.ntbuff; if k==1 diff --git a/master_kilosort.m b/master_kilosort.m index 5230b251..7c9aa19e 100644 --- a/master_kilosort.m +++ b/master_kilosort.m @@ -64,6 +64,10 @@ rez.cProj = []; rez.cProjPC = []; +% final time sorting of spikes, for apps that use st3 directly +[~, isort] = sortrows(rez.st3); +rez.st3 = rez.st3(isort, :); + % save final results as rez2 fprintf('Saving final results in rez2 \n') fname = fullfile(rootZ, 'rez2.mat'); diff --git a/postProcess/splitAllClusters.m b/postProcess/splitAllClusters.m index 9367a076..7ec82f3f 100644 --- a/postProcess/splitAllClusters.m +++ b/postProcess/splitAllClusters.m @@ -7,6 +7,7 @@ % it only uses the PC features for each spike, stored in rez.cProjPC ops = rez.ops; + wPCA = gather(ops.wPCA); % use PCA projections to reconstruct templates when we do splits ccsplit = rez.ops.AUCsplit; % this is the threshold for splits, and is one of the main parameters users can change diff --git a/preProcess/clusterSingleBatches.m b/preProcess/clusterSingleBatches.m index 45b5b375..77ae4ddc 100644 --- a/preProcess/clusterSingleBatches.m +++ b/preProcess/clusterSingleBatches.m @@ -4,10 +4,15 @@ % the resulting cluster means are then compared for all pairs of batches, and a dissimilarity score is assigned to each pair % the matrix of similarity scores is then re-ordered so that low dissimilaity is along the diagonal +ops = rez.ops; -rng('default'); rng(1); +% Turn on sorting of spikes before starting kmeans +rez.ops.useStableMode = getOr(rez.ops, 'useStableMode', 1); +useStableMode = rez.ops.useStableMode; -ops = rez.ops; +rez.ops.CSBseed = getOr(rez.ops, 'CSBseed', 1); %standard seed = 1; +rng('default'); rng(rez.ops.CSBseed); +fprintf('random seed for clusterSingleBatches: %d\n', rez.ops.CSBseed ); if getOr(ops, 'reorder', 0)==0 rez.iorig = 1:rez.temp.Nbatch; % if reordering is turned off, return consecutive order @@ -18,6 +23,8 @@ Nfilt = ceil(rez.ops.Nchan/2); tic wPCA = extractPCfromSnippets(rez, nPCs); % extract PCA waveforms pooled over channels +%JIC -- in what sense is this 7 PC waveforms? A single set of basis PCs is +%calculated from all spikes found in every 100th batch. fprintf('Obtained 7 PC waveforms in %2.2f seconds \n', toc) % 7 is the default, and I don't think it needs to be able to change Nchan = rez.ops.Nchan; @@ -32,6 +39,8 @@ ns = gpuArray.zeros(Nfilt, nBatches, 'single'); % this holds the number of spikes for that cluster Whs = gpuArray.ones(Nfilt, nBatches, 'int32'); % this holds the center channel for each template + + i0 = 0; NrankPC = 3; % I am not sure if this gets used, but it goes into the function @@ -40,15 +49,27 @@ tic for ibatch = 1:nBatches + [uproj, call] = extractPCbatch2(rez, wPCA, min(nBatches-1, ibatch), iC); % extract spikes using PCA waveforms % call contains the center channels for each spike + + % sort rows of uprojDAT (sorts on first component, breaks ties with 2nd, 3rd...) + % the order is arbitrary but ordering makes the k-means + % deterministic + [~,order] = sortrows(uproj'); + uproj = uproj(:,order); + call = call(order); + + if sum(isnan(uproj(:)))>0 %sum(mus(:,ibatch)<.1)>30 break; % I am not sure what case this safeguards against.... 
 end
 if size(uproj,2)>Nfilt % if a batch has at least as many spikes as templates we request, then cluster it
+        % uproj contains all spikes, W will hold the starting points for
+        % k-means.
 [W, mu, Wheights, irand] = initializeWdata2(call, uproj, Nchan, nPCs, Nfilt, iC); % this initialize the k-means
% Params is a whole bunch of parameters sent to the C++ scripts inside a float64 vector
@@ -62,7 +83,7 @@
% get iclust and update W
 [dWU, iclust, dx, nsp, dV] = mexClustering2(Params, uproj, W, mu, ...
    call-1, iMatch, iC-1); % CUDA script to efficiently compute distances for pairs in which iMatch is 1
-
+
 dWU = dWU./(1e-5 + single(nsp')); % divide the cumulative waveform by the number of spikes
 mu = sum(dWU.^2,1).^.5; % norm of cluster template
@@ -73,15 +94,20 @@
 W = reshape(W, Nchan * nPCs, Nfilt);
 [~, Wheights] = max(nW,[], 1); % the new best channel of each cluster template
+
 end
-
+
% carefully keep track of cluster templates in dense format
 W = reshape(W, nPCs, Nchan, Nfilt);
+
 W0 = gpuArray.zeros(nPCs, NchanNear, Nfilt, 'single');
 for t = 1:Nfilt
 W0(:, :, t) = W(:, iC(:, Wheights(t)), t);
 end
 W0 = W0 ./ (1e-5 + sum(sum(W0.^2,1),2).^.5); % I don't really know why this needs another normalization
+    else
+        % make a note when a batch has fewer than Nfilt spikes
+        fprintf( 'Batch %d has fewer than Nfilt spikes.\n', ibatch );
 end
 if exist('W0', 'var')
@@ -99,11 +125,11 @@
 if rem(ibatch, 500)==1
 fprintf('time %2.2f, pre clustered %d / %d batches \n', toc, ibatch, nBatches)
 end
+
 end
-
 tic
-% anothr one of these Params variables transporting parameters to the C++ code
+% another one of these Params variables transporting parameters to the C++ code
 Params = [1 NrankPC Nfilt 0 size(W,1) 0 NchanNear Nchan];
 Params(1) = size(Ws,3) * size(Ws,4); % the total number of templates is the number of templates per batch times the number of batches
@@ -112,7 +138,7 @@
 for ibatch = 1:nBatches % for every batch, compute in parallel its dissimilarity to ALL other batches
-    Wh0 = single(Whs(:, ibatch)); % this one is the primary batch
+    Wh0 = single(Whs(:, ibatch)); % max channels of the primary batch
 W0 = Ws(:, :, ibatch);
 mu = mus(:, ibatch);
@@ -123,6 +149,10 @@
 end
% pairs of templates that live on the same channels are potential "matches"
+    % This calculation finds all channels with at least one neighbor
+    % overlapping the max channel of the cluster.
+    % for probes where channel order does not reflect site position, does
+    % this need to change to a distance calculation?
iMatch = sq(min(abs(single(iC) - reshape(Wh0, 1, 1, [])), [], 1))<.1; diff --git a/preProcess/extractPCfromSnippets.m b/preProcess/extractPCfromSnippets.m index 9bd8d4cc..b7b5d012 100644 --- a/preProcess/extractPCfromSnippets.m +++ b/preProcess/extractPCfromSnippets.m @@ -37,3 +37,4 @@ wPCA = U(:, 1:nPCs); % take as many as needed wPCA(:,1) = - wPCA(:,1) * sign(wPCA(21,1)); % adjust the arbitrary sign of the first PC so its negativity is downward + diff --git a/preProcess/initializeWdata2.m b/preProcess/initializeWdata2.m index a4d27fa9..bc66eea3 100644 --- a/preProcess/initializeWdata2.m +++ b/preProcess/initializeWdata2.m @@ -4,7 +4,11 @@ % uprojDAT are features projections (Nfeatures by Nspikes) % some more parameters need to be passed in from the main workspace -irand = ceil(rand(Nfilt,1) * size(uprojDAT,2)); % pick random spikes from the sample +%get a set of Nfilt unique spike indices +allSpike = randperm(size(uprojDAT,2)); +irand = allSpike(1:Nfilt); +%fprintf( 'in initializeWdata, Nfilt = %d, num unique spikes = %d\n', Nfilt, numel(unique(irand))); +%irand = ceil(rand(Nfilt,1) * size(uprojDAT,2)); % pick random spikes from the sample % irand = 1:Nfilt; W = gpuArray.zeros(nPCs, Nchan, Nfilt, 'single'); @@ -14,7 +18,7 @@ W(:, ich, t) = reshape(uprojDAT(:, irand(t)), nPCs, []); % for each selected spike, get its features end W = reshape(W, [], Nfilt); -W = W + .001 * gpuArray.randn(size(W), 'single'); % add small amount of noise in case we accidentally picked the same spike twice +%W = W + .001 * gpuArray.randn(size(W), 'single'); % add small amount of noise in case we accidentally picked the same spike twice mu = sum(W.^2,1).^.5; % get the mean of the template W = W./(1e-5 + mu); % and normalize the template
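Taken together, the reproducibility recipe running through this patch is always the same three steps: build an index array (set_idx), sort it by spike timestamp (cdp_simple_quicksort), and then perform every order-sensitive update in that fixed order. A condensed end-to-end sketch of the idea, using Thrust in place of the CDP quicksort purely for brevity (all names hypothetical):

```cuda
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

// Apply subtractions one spike at a time, in timestamp order, so the
// floating-point sums are formed identically on every run.
__global__ void subtractInOrder(const unsigned int *idx, const int *st,
                                float *data, int nspikes) {
    if (blockIdx.x == 0 && threadIdx.x == 0)
        for (int i = 0; i < nspikes; ++i)
            data[st[idx[i]]] -= 1.0f;   // stand-in for the template subtraction
}

void deterministicPass(thrust::device_vector<int> &st,
                       thrust::device_vector<float> &data) {
    int n = (int)st.size();
    thrust::device_vector<int> keys = st;        // copy timestamps (like d_stSort)
    thrust::device_vector<unsigned int> idx(n);
    thrust::sequence(idx.begin(), idx.end());    // like set_idx<<<1, 1>>>
    thrust::sort_by_key(keys.begin(), keys.end(), idx.begin());
    subtractInOrder<<<1, 1>>>(thrust::raw_pointer_cast(idx.data()),
                              thrust::raw_pointer_cast(st.data()),
                              thrust::raw_pointer_cast(data.data()), n);
    cudaDeviceSynchronize();
}
```

The serial launch trades the 15-20% speed advantage quoted above for bit-identical output; the MATLAB-side changes (fixed rng seeds, sortrows on uproj and st3, randperm in initializeWdata2) apply the same order-fixing idea on the host.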