Skip to content

Commit

Permalink
works perfectly for seeg, needs a bit of work for ecog
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrockhill committed Nov 21, 2023
1 parent bf8d40f commit 3b822db
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 76 deletions.
53 changes: 48 additions & 5 deletions examples/ieeg_locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,20 +397,24 @@ def plot_overlay(image, compare, title, thresh=None):
montage = raw.get_montage()
montage.apply_trans(subj_trans) # convert to surface RAS
# convert to scanner RAS
mne_bids.convert_montage_to_ras(montage, subject='sample_seeg', subjects_dir=misc_path / "seeg")
mne_bids.convert_montage_to_ras(
montage, subject="sample_seeg", subjects_dir=misc_path / "seeg"
)

raw.set_montage(None) # clear already found montage
# fake surgical plans from already-found contact locations
targets = {''.join([letter for letter in ch if not letter.isdigit()]): pos
for ch, pos in montage.get_positions()['ch_pos'].items()
if [letter for letter in ch if letter.isdigit()] == ['1']}
targets = {
"".join([letter for letter in ch if not letter.isdigit()]): pos
for ch, pos in montage.get_positions()["ch_pos"].items()
if [letter for letter in ch if letter.isdigit()] == ["1"]
}
mne_gui.locate_ieeg(
raw.info,
subj_trans,
CT_aligned,
subject="sample_seeg",
subjects_dir=misc_path / "seeg",
targets=targets
targets=targets,
)

# %%
Expand Down Expand Up @@ -444,6 +448,45 @@ def plot_overlay(image, compare, title, thresh=None):
subjects_dir=misc_path / "ecog",
)

# %%
# Similarly for ECoG, we can try to automatically detect contact locations.
# In the case of ECoG, there can often be a lot of contacts so this can be
# a big time saver.

montage = raw_ecog.get_montage()
montage.apply_trans(subj_trans_ecog) # convert to surface RAS
# convert to scanner RAS
mne_bids.convert_montage_to_ras(
montage, subject="sample_ecog", subjects_dir=misc_path / "ecog"
)

raw_ecog.set_montage(None) # clear already found montage
# fake surgical plans from already-found contact locations
targets = dict()
ch_pos = montage.get_positions()["ch_pos"]
for elec in set(
[
"".join([letter for letter in ch if not letter.isdigit()])
for ch in raw_ecog.ch_names
]
):
names = [ch for ch in raw_ecog.ch_names if ch.replace(elec, "").isdigit()]
ch1 = [name for name in names if name.replace(elec, "") == "1"][0]
ch2 = [name for name in names if name.replace(elec, "") == "2"][0]
targets[elec] = (
ch_pos[ch1],
ch_pos[ch2],
) # use second channel so grid counts the right way

mne_gui.locate_ieeg(
raw_ecog.info,
subj_trans_ecog,
CT_aligned_ecog,
subject="sample_ecog",
subjects_dir=misc_path / "ecog",
targets=targets,
)

# %%
# For ECoG, we typically want to account for "brain shift" or shrinking of the
# brain away from the skull/dura due to changes in pressure during the
Expand Down
179 changes: 108 additions & 71 deletions mne_gui_addons/_ieeg_locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
)

if targets:
self._auto_find_contacts(targets)
self.auto_find_contacts(targets)

# set current position as current contact location if exists
if not np.isnan(self._chs[self._ch_names[self._ch_index]]).any():
Expand Down Expand Up @@ -486,70 +486,66 @@ def _group_channels(self, groups):
base_names[base_name] = i
i += 1

def _find_local_maxima(self, target, check_nearest=5, max_search_radius=20):
def _deduplicate_local_maxima(self, local_maxima):
"""De-duplicate peaks by finding center of mass of high-intensity voxels."""
local_maxima2 = set()
for local_max in local_maxima:
neighbors = _voxel_neighbors(
local_max,
self._ct_data,
thresh=0.5,
voxels_max=self._radius**3,
use_relative=True,
)
loc = np.array(list(neighbors)).mean(axis=0)
if (
not local_maxima2
or np.min(
[np.linalg.norm(np.array(loc2) - loc) for loc2 in local_maxima2]
)
> 1
):
local_maxima2.add(tuple(loc))
return np.array([np.array(local_max) for local_max in local_maxima2])

def _find_local_maxima(self, target, check_nearest=3, max_search_radius=50):
target_vox = (
apply_trans(self._scan_ras_ras_vox_t, target * 1000).round().astype(int)
)
search_radius = 1
while (
np.nansum(
self._ct_maxima[
tuple(
slice(
target_vox[i] - search_radius, target_vox[i] + search_radius
local_maxima = None
while local_maxima is None or local_maxima.shape[0] < check_nearest:
local_maxima = (
np.array(
np.where(
~np.isnan(
self._ct_maxima[
tuple(
slice(
target_vox[i] - search_radius,
target_vox[i] + search_radius,
)
for i in range(3)
)
]
)
for i in range(3)
)
]
).T
+ target_vox
- search_radius
)
< check_nearest
):
local_maxima = self._deduplicate_local_maxima(local_maxima)
search_radius += 1
if search_radius > max_search_radius:
break
if search_radius > max_search_radius:
return
local_maxima = (
np.array(
np.where(
~np.isnan(
self._ct_maxima[
tuple(
slice(
target_vox[i] - search_radius,
target_vox[i] + search_radius,
)
for i in range(3)
)
]
)
)
).T
+ target_vox
- search_radius
)
# de-duplicate peaks by finding center of mass of high-intensity voxels
shape = np.mean(self._voxel_sizes) # Freesurfer shape (256)
voxels_max = int(
4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3
)
local_maxima2 = set()
for local_max in local_maxima:
neighbors = _voxel_neighbors(
local_max,
self._ct_data,
thresh=0.5,
voxels_max=voxels_max,
use_relative=True,
)
local_maxima2.add(tuple(np.array(list(neighbors)).mean(axis=0)))
local_maxima = np.array([np.array(local_max) for local_max in local_maxima2])
local_maxima = local_maxima[
np.argsort(np.linalg.norm(local_maxima - target_vox, axis=1))
]
return local_maxima

def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2):
def _auto_find_line(self, tv, r, max_search_radius=50, voxel_tol=2):
"""Look for local maxima on a line."""
# move in that direction and count to number of contact in group
locs = [tuple(tv)]
Expand All @@ -558,7 +554,9 @@ def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2):
# stop when all the contacts or found or you have moved more than
# max_search radius without finding another one
while abs(t) < max_search_radius:
check_vox = (locs[-1] + t * r).round().astype(int)
check_vox = (
(locs[-1 if direction == 1 else 0] + t * r).round().astype(int)
)
next_locs = (
np.array(
np.where(
Expand All @@ -578,28 +576,35 @@ def _auto_find_line(self, tv, r, max_search_radius=20, voxel_tol=2):
+ check_vox
- voxel_tol
)
next_locs = self._deduplicate_local_maxima(next_locs)
for next_loc in next_locs:
if tuple(next_loc) not in locs:
if np.min([np.linalg.norm(next_loc - loc) for loc in locs]) > 1:
t = 0
locs.insert(len(locs) if direction == 1 else 0, tuple(next_loc))
t += direction
return locs

def _auto_find_grid(
self, tv, r, check_nearest=5, max_search_radius=20, voxel_tol=2
self, tv, r, check_nearest=3, max_search_radius=50, voxel_tol=2
):
"""Automatically find a series of lines to form a grid."""
local_maxima = self._find_local_maxima(
tv, check_nearest=check_nearest, max_search_radius=max_search_radius
)
# first, find first line of contacts
locs = self._auto_find_line(
tv, r, max_search_radius=max_search_radius, voxel_tol=voxel_tol
)
if len(locs) < 3:
return
local_maxima = [loc for loc in local_maxima if tuple(loc) not in locs]
tv = locs[0] # re-pick target value in case shifted to second contact
return []
tv = np.array(locs[0]) # re-pick target value in case shifted to second contact
local_maxima = self._find_local_maxima(
apply_trans(self._ras_vox_scan_ras_t, tv) / 1000,
check_nearest=check_nearest,
max_search_radius=max_search_radius,
)
local_maxima = [
loc
for loc in local_maxima
if np.min([np.linalg.norm(loc - loc2) for loc2 in locs]) > 1
]
# next fine a line of contacts in a different direction
for tv2 in local_maxima:
# find specified direction vector/direction vector to next contact
Expand All @@ -619,14 +624,33 @@ def _auto_find_grid(
locs += locs3
return locs

def _auto_find_contacts(
def auto_find_contacts(
self,
targets,
check_nearest=5,
max_search_radius=20,
check_nearest=3,
max_search_radius="auto",
voxel_tol=2,
):
"""Try automatically finding contact locations from targets."""
"""Try automatically finding contact locations from targets.
Parameters
----------
targets : dict
Keys are names of groups (electrodes/grids) and values are target and
entry (optional) locations in scanner RAS.
check_nearest : int
The number of nearest neighbors to check for completing lines. Increase
if locations are not found because artifactual high-intensity areas
are causing the wrong line directions.
max_search_radius : int | 'auto'
The maximum distance to search for a high-intensity voxel away from
the last point found. ``auto`` uses 50 for sEEG in order to find
electrodes across spanning gaps and 10 for ECoG so as not to be
confused by all the extra points (especially if there are two grids).
voxel_tol : int
The number of voxels away from the line local maxima are allowed to
be in order to be marked.
"""
_validate_type(targets, (dict,), "targets")
self._update_ct_maxima()
for elec, target in targets.items():
Expand All @@ -650,6 +674,8 @@ def _auto_find_contacts(
is_ecog = all(
[self._info.ch_names.index(name) in self._ecog_idx for name in names]
)
if max_search_radius == "auto":
max_search_radius = 10 if is_ecog else 50
if not names or not all(
[np.isnan(self._chs[name]).all() for name in names]
):
Expand All @@ -670,7 +696,7 @@ def _auto_find_contacts(
v /= np.linalg.norm(v)
for i, tv in enumerate(
local_maxima[:check_nearest]
): # try sequentially based on closest
): # try neartest sequentially
# only try entry if given, otherwise try other local maxima as direction vectors
for tv2 in local_maxima[i + 1 :] if entry is None else [tv + v]:
# find specified direction vector/direction vector to next contact
Expand Down Expand Up @@ -705,7 +731,9 @@ def _auto_find_contacts(
# assign locations
for name, loc in zip(names, locs):
vox = apply_trans(self._ras_vox_scan_ras_t, loc)
self._chs[name][:] = apply_trans(self._scan_ras_mri_t, vox) # to surface RAS
self._chs[name][:] = apply_trans(
self._scan_ras_mri_t, vox
) # to surface RAS
self._color_list_item(name)
self._save_ch_coords()

Expand All @@ -717,11 +745,24 @@ def _auto_mark_group(self):
if self._groups[name] == self._groups[self._ch_names[self._ch_index]]
and not np.isnan(self._chs[name]).any()
]
names = [
name
for name in self._groups
if self._groups[name] == self._groups[self._ch_names[self._ch_index]]
]
if len(locs) > 1:
if self._ch_idx in self._ecog_idx:
self._auto_find_grid(locs[0], locs[1] - locs[0])
if self._ch_index in self._ecog_idx:
locs = self._auto_find_grid(locs[0], locs[1] - locs[0])
else:
self._auto_find_line(locs[0], locs[1] - locs[0])
locs = self._auto_find_line(locs[0], locs[1] - locs[0])
# assign locations
for name, loc in zip(names, locs):
vox = apply_trans(self._ras_vox_scan_ras_t, loc)
self._chs[name][:] = apply_trans(
self._scan_ras_mri_t, vox
) # to surface RAS
self._color_list_item(name)
self._save_ch_coords()
else:
QMessageBox.information(
self,
Expand Down Expand Up @@ -881,15 +922,11 @@ def mark_channel(self, ch=None):
self._scan_ras_mri_t, self._ras
) # stored as surface RAS
else:
shape = np.mean(self._voxel_sizes) # Freesurfer shape (256)
voxels_max = int(
4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3
)
neighbors = _voxel_neighbors(
self._vox,
self._ct_data,
thresh=0.5,
voxels_max=voxels_max,
voxels_max=self._radius**3,
use_relative=True,
)
self._chs[name][:] = apply_trans( # to surface RAS
Expand Down

0 comments on commit 3b822db

Please sign in to comment.