diff --git a/examples/ieeg_locate.py b/examples/ieeg_locate.py index 3813d7c..fd3df9e 100644 --- a/examples/ieeg_locate.py +++ b/examples/ieeg_locate.py @@ -397,20 +397,24 @@ def plot_overlay(image, compare, title, thresh=None): montage = raw.get_montage() montage.apply_trans(subj_trans) # convert to surface RAS # convert to scanner RAS -mne_bids.convert_montage_to_ras(montage, subject='sample_seeg', subjects_dir=misc_path / "seeg") +mne_bids.convert_montage_to_ras( + montage, subject="sample_seeg", subjects_dir=misc_path / "seeg" +) raw.set_montage(None) # clear already found montage # fake surgical plans from already-found contact locations -targets = {''.join([letter for letter in ch if not letter.isdigit()]): pos - for ch, pos in montage.get_positions()['ch_pos'].items() - if [letter for letter in ch if letter.isdigit()] == ['1']} +targets = { + "".join([letter for letter in ch if not letter.isdigit()]): pos + for ch, pos in montage.get_positions()["ch_pos"].items() + if [letter for letter in ch if letter.isdigit()] == ["1"] +} mne_gui.locate_ieeg( raw.info, subj_trans, CT_aligned, subject="sample_seeg", subjects_dir=misc_path / "seeg", - targets=targets + targets=targets, ) # %% @@ -444,6 +448,45 @@ def plot_overlay(image, compare, title, thresh=None): subjects_dir=misc_path / "ecog", ) +# %% +# Similarly for ECoG, we can try to automatically detect contact locations. +# In the case of ECoG, there can often be a lot of contacts so this can be +# a big time saver. + +montage = raw_ecog.get_montage() +montage.apply_trans(subj_trans_ecog) # convert to surface RAS +# convert to scanner RAS +mne_bids.convert_montage_to_ras( + montage, subject="sample_ecog", subjects_dir=misc_path / "ecog" +) + +raw_ecog.set_montage(None) # clear already found montage +# fake surgical plans from already-found contact locations +targets = dict() +ch_pos = montage.get_positions()["ch_pos"] +for elec in set( + [ + "".join([letter for letter in ch if not letter.isdigit()]) + for ch in raw_ecog.ch_names + ] +): + names = [ch for ch in raw_ecog.ch_names if ch.replace(elec, "").isdigit()] + ch1 = [name for name in names if name.replace(elec, "") == "1"][0] + ch2 = [name for name in names if name.replace(elec, "") == "2"][0] + targets[elec] = ( + ch_pos[ch1], + ch_pos[ch2], + ) # use second channel so grid counts the right way + +mne_gui.locate_ieeg( + raw_ecog.info, + subj_trans_ecog, + CT_aligned_ecog, + subject="sample_ecog", + subjects_dir=misc_path / "ecog", + targets=targets, +) + # %% # For ECoG, we typically want to account for "brain shift" or shrinking of the # brain away from the skull/dura due to changes in pressure during the diff --git a/mne_gui_addons/_ieeg_locate.py b/mne_gui_addons/_ieeg_locate.py index db4f079..6948397 100644 --- a/mne_gui_addons/_ieeg_locate.py +++ b/mne_gui_addons/_ieeg_locate.py @@ -142,7 +142,7 @@ def __init__( ) if targets: - self._auto_find_contacts(targets) + self.auto_find_contacts(targets) # set current position as current contact location if exists if not np.isnan(self._chs[self._ch_names[self._ch_index]]).any(): @@ -486,70 +486,66 @@ def _group_channels(self, groups): base_names[base_name] = i i += 1 - def _find_local_maxima(self, target, check_nearest=5, max_search_radius=20): + def _deduplicate_local_maxima(self, local_maxima): + """De-duplicate peaks by finding center of mass of high-intensity voxels.""" + local_maxima2 = set() + for local_max in local_maxima: + neighbors = _voxel_neighbors( + local_max, + self._ct_data, + thresh=0.5, + voxels_max=self._radius**3, + use_relative=True, + ) + loc = np.array(list(neighbors)).mean(axis=0) + if ( + not local_maxima2 + or np.min( + [np.linalg.norm(np.array(loc2) - loc) for loc2 in local_maxima2] + ) + > 1 + ): + local_maxima2.add(tuple(loc)) + return np.array([np.array(local_max) for local_max in local_maxima2]) + + def _find_local_maxima(self, target, check_nearest=3, max_search_radius=50): target_vox = ( apply_trans(self._scan_ras_ras_vox_t, target * 1000).round().astype(int) ) search_radius = 1 - while ( - np.nansum( - self._ct_maxima[ - tuple( - slice( - target_vox[i] - search_radius, target_vox[i] + search_radius + local_maxima = None + while local_maxima is None or local_maxima.shape[0] < check_nearest: + local_maxima = ( + np.array( + np.where( + ~np.isnan( + self._ct_maxima[ + tuple( + slice( + target_vox[i] - search_radius, + target_vox[i] + search_radius, + ) + for i in range(3) + ) + ] ) - for i in range(3) ) - ] + ).T + + target_vox + - search_radius ) - < check_nearest - ): + local_maxima = self._deduplicate_local_maxima(local_maxima) search_radius += 1 if search_radius > max_search_radius: break if search_radius > max_search_radius: return - local_maxima = ( - np.array( - np.where( - ~np.isnan( - self._ct_maxima[ - tuple( - slice( - target_vox[i] - search_radius, - target_vox[i] + search_radius, - ) - for i in range(3) - ) - ] - ) - ) - ).T - + target_vox - - search_radius - ) - # de-duplicate peaks by finding center of mass of high-intensity voxels - shape = np.mean(self._voxel_sizes) # Freesurfer shape (256) - voxels_max = int( - 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3 - ) - local_maxima2 = set() - for local_max in local_maxima: - neighbors = _voxel_neighbors( - local_max, - self._ct_data, - thresh=0.5, - voxels_max=voxels_max, - use_relative=True, - ) - local_maxima2.add(tuple(np.array(list(neighbors)).mean(axis=0))) - local_maxima = np.array([np.array(local_max) for local_max in local_maxima2]) local_maxima = local_maxima[ np.argsort(np.linalg.norm(local_maxima - target_vox, axis=1)) ] return local_maxima - def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2): + def _auto_find_line(self, tv, r, max_search_radius=50, voxel_tol=2): """Look for local maxima on a line.""" # move in that direction and count to number of contact in group locs = [tuple(tv)] @@ -558,7 +554,9 @@ def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2): # stop when all the contacts or found or you have moved more than # max_search radius without finding another one while abs(t) < max_search_radius: - check_vox = (locs[-1] + t * r).round().astype(int) + check_vox = ( + (locs[-1 if direction == 1 else 0] + t * r).round().astype(int) + ) next_locs = ( np.array( np.where( @@ -578,28 +576,35 @@ def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2): + check_vox - voxel_tol ) + next_locs = self._deduplicate_local_maxima(next_locs) for next_loc in next_locs: - if tuple(next_loc) not in locs: + if np.min([np.linalg.norm(next_loc - loc) for loc in locs]) > 1: t = 0 locs.insert(len(locs) if direction == 1 else 0, tuple(next_loc)) t += direction return locs def _auto_find_grid( - self, tv, r, check_nearest=5, max_search_radius=20, voxel_tol=2 + self, tv, r, check_nearest=3, max_search_radius=50, voxel_tol=2 ): """Automatically find a series of lines to form a grid.""" - local_maxima = self._find_local_maxima( - tv, check_nearest=check_nearest, max_search_radius=max_search_radius - ) # first, find first line of contacts locs = self._auto_find_line( tv, r, max_search_radius=max_search_radius, voxel_tol=voxel_tol ) if len(locs) < 3: - return - local_maxima = [loc for loc in local_maxima if tuple(loc) not in locs] - tv = locs[0] # re-pick target value in case shifted to second contact + return [] + tv = np.array(locs[0]) # re-pick target value in case shifted to second contact + local_maxima = self._find_local_maxima( + apply_trans(self._ras_vox_scan_ras_t, tv) / 1000, + check_nearest=check_nearest, + max_search_radius=max_search_radius, + ) + local_maxima = [ + loc + for loc in local_maxima + if np.min([np.linalg.norm(loc - loc2) for loc2 in locs]) > 1 + ] # next fine a line of contacts in a different direction for tv2 in local_maxima: # find specified direction vector/direction vector to next contact @@ -619,14 +624,33 @@ def _auto_find_grid( locs += locs3 return locs - def _auto_find_contacts( + def auto_find_contacts( self, targets, - check_nearest=5, - max_search_radius=20, + check_nearest=3, + max_search_radius="auto", voxel_tol=2, ): - """Try automatically finding contact locations from targets.""" + """Try automatically finding contact locations from targets. + + Parameters + ---------- + targets : dict + Keys are names of groups (electrodes/grids) and values are target and + entry (optional) locations in scanner RAS. + check_nearest : int + The number of nearest neighbors to check for completing lines. Increase + if locations are not found because artifactual high-intensity areas + are causing the wrong line directions. + max_search_radius : int | 'auto' + The maximum distance to search for a high-intensity voxel away from + the last point found. ``auto`` uses 50 for sEEG in order to find + electrodes across spanning gaps and 10 for ECoG so as not to be + confused by all the extra points (especially if there are two grids). + voxel_tol : int + The number of voxels away from the line local maxima are allowed to + be in order to be marked. + """ _validate_type(targets, (dict,), "targets") self._update_ct_maxima() for elec, target in targets.items(): @@ -650,6 +674,8 @@ def _auto_find_contacts( is_ecog = all( [self._info.ch_names.index(name) in self._ecog_idx for name in names] ) + if max_search_radius == "auto": + max_search_radius = 10 if is_ecog else 50 if not names or not all( [np.isnan(self._chs[name]).all() for name in names] ): @@ -670,7 +696,7 @@ def _auto_find_contacts( v /= np.linalg.norm(v) for i, tv in enumerate( local_maxima[:check_nearest] - ): # try sequentially based on closest + ): # try neartest sequentially # only try entry if given, otherwise try other local maxima as direction vectors for tv2 in local_maxima[i + 1 :] if entry is None else [tv + v]: # find specified direction vector/direction vector to next contact @@ -705,7 +731,9 @@ def _auto_find_contacts( # assign locations for name, loc in zip(names, locs): vox = apply_trans(self._ras_vox_scan_ras_t, loc) - self._chs[name][:] = apply_trans(self._scan_ras_mri_t, vox) # to surface RAS + self._chs[name][:] = apply_trans( + self._scan_ras_mri_t, vox + ) # to surface RAS self._color_list_item(name) self._save_ch_coords() @@ -717,11 +745,24 @@ def _auto_mark_group(self): if self._groups[name] == self._groups[self._ch_names[self._ch_index]] and not np.isnan(self._chs[name]).any() ] + names = [ + name + for name in self._groups + if self._groups[name] == self._groups[self._ch_names[self._ch_index]] + ] if len(locs) > 1: - if self._ch_idx in self._ecog_idx: - self._auto_find_grid(locs[0], locs[1] - locs[0]) + if self._ch_index in self._ecog_idx: + locs = self._auto_find_grid(locs[0], locs[1] - locs[0]) else: - self._auto_find_line(locs[0], locs[1] - locs[0]) + locs = self._auto_find_line(locs[0], locs[1] - locs[0]) + # assign locations + for name, loc in zip(names, locs): + vox = apply_trans(self._ras_vox_scan_ras_t, loc) + self._chs[name][:] = apply_trans( + self._scan_ras_mri_t, vox + ) # to surface RAS + self._color_list_item(name) + self._save_ch_coords() else: QMessageBox.information( self, @@ -881,15 +922,11 @@ def mark_channel(self, ch=None): self._scan_ras_mri_t, self._ras ) # stored as surface RAS else: - shape = np.mean(self._voxel_sizes) # Freesurfer shape (256) - voxels_max = int( - 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3 - ) neighbors = _voxel_neighbors( self._vox, self._ct_data, thresh=0.5, - voxels_max=voxels_max, + voxels_max=self._radius**3, use_relative=True, ) self._chs[name][:] = apply_trans( # to surface RAS