diff --git a/mne_gui_addons/_core.py b/mne_gui_addons/_core.py index 37ac09e..57b6451 100644 --- a/mne_gui_addons/_core.py +++ b/mne_gui_addons/_core.py @@ -379,10 +379,11 @@ def _plot_images(self): plot_x_idx, plot_y_idx = self._xy_idx[axis] fig = self._figs[axis] ax = fig.axes[0] - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T self._images["base"].append( ax.imshow( - img_data, + self._base_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, cmap="gray", aspect="auto", zorder=1, @@ -739,8 +740,9 @@ def _draw(self, axis=None): def _update_base_images(self, axis=None, draw=False): """Update the base images.""" for axis in range(3) if axis is None else [axis]: - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T - self._images["base"][axis].set_data(img_data) + self._images["base"][axis].set_data( + self._base_data[(slice(None),) * axis + (self._current_slice[axis],)].T + ) if draw: self._draw(axis) diff --git a/mne_gui_addons/_ieeg_locate.py b/mne_gui_addons/_ieeg_locate.py index 6be3bbb..c372abf 100644 --- a/mne_gui_addons/_ieeg_locate.py +++ b/mne_gui_addons/_ieeg_locate.py @@ -973,22 +973,29 @@ def _update_ch_images(self, axis=None, draw=False): def _update_ct_images(self, axis=None, draw=False): """Update the CT image(s).""" for axis in range(3) if axis is None else [axis]: - ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T + ct_data = ( + self._ct_data[(slice(None),) * axis + (self._current_slice[axis],)] + .copy() + .T + ) # Threshold the CT so only bright objects (electrodes) are visible ct_data[ct_data < self._ct_min_slider.value()] = np.nan ct_data[ct_data > self._ct_max_slider.value()] = np.nan self._images["ct"][axis].set_data(ct_data) if "local_max" in self._images: - ct_max_data = np.take( - self._ct_maxima, self._current_slice[axis], axis=axis - ).T - self._images["local_max"][axis].set_data(ct_max_data) + self._images["local_max"][axis].set_data( + self._ct_maxima[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T + ) if draw: self._draw(axis) def _get_mr_slice(self, axis): """Get the current MR slice.""" - mri_data = np.take(self._mr_data, self._current_slice[axis], axis=axis).T + mri_data = self._mr_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T if self._using_atlas: mri_slice = mri_data.copy().astype(int) mri_data = np.zeros(mri_slice.shape + (3,), dtype=int) @@ -1135,14 +1142,13 @@ def _toggle_show_max(self): self._update_ct_maxima() self._images["local_max"] = list() for axis in range(3): - ct_max_data = np.take( - self._ct_maxima, self._current_slice[axis], axis=axis - ).T self._images["local_max"].append( self._figs[axis] .axes[0] .imshow( - ct_max_data, + self._ct_maxima[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, cmap="autumn", aspect="auto", vmin=0, @@ -1166,11 +1172,16 @@ def _toggle_show_brain(self): self._images["mri"] = list() for axis in range(3): cmap = None if self._using_atlas else "hot" - mri_data = self._get_mr_slice(axis) self._images["mri"].append( self._figs[axis] .axes[0] - .imshow(mri_data, cmap=cmap, aspect="auto", alpha=0.25, zorder=2) + .imshow( + self._get_mr_slice(axis), + cmap="hot", + aspect="auto", + alpha=0.25, + zorder=2, + ) ) self._draw() diff --git a/mne_gui_addons/_segment.py b/mne_gui_addons/_segment.py index d6c06fa..332eb25 100644 --- a/mne_gui_addons/_segment.py +++ b/mne_gui_addons/_segment.py @@ -320,7 +320,9 @@ def _update_img_scale(self): def _update_base_images(self, axis=None, draw=False): """Update the CT image(s).""" for axis in range(3) if axis is None else [axis]: - img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T + img_data = self._base_data[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T.copy() img_data[img_data < self._img_min_slider.value()] = np.nan img_data[img_data > self._img_max_slider.value()] = np.nan self._images["base"][axis].set_data(img_data) @@ -335,10 +337,11 @@ def _plot_vol_images(self): for axis in range(3): fig = self._figs[axis] ax = fig.axes[0] - vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T self._images["vol"].append( ax.imshow( - vol_data, + self._vol_img[ + (slice(None),) * axis + (self._current_slice[axis],) + ].T, aspect="auto", zorder=3, cmap=_CMAP, @@ -438,8 +441,9 @@ def _mark_all(self): def _update_vol_images(self, axis=None, draw=False): """Update the volume image(s).""" for axis in range(3) if axis is None else [axis]: - vol_data = np.take(self._vol_img, self._current_slice[axis], axis=axis).T - self._images["vol"][axis].set_data(vol_data) + self._images["vol"][axis].set_data( + self._vol_img[(slice(None),) * axis + (self._current_slice[axis],)].T + ) if draw: self._draw(axis) diff --git a/mne_gui_addons/_vol_stc.py b/mne_gui_addons/_vol_stc.py index c3da680..d57b097 100644 --- a/mne_gui_addons/_vol_stc.py +++ b/mne_gui_addons/_vol_stc.py @@ -306,7 +306,6 @@ def __init__( ] src_coord = self._get_src_coord() for axis in range(3): - stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T x_idx, y_idx = self._xy_idx[axis] extent = [ corners[0][x_idx], @@ -318,7 +317,7 @@ def __init__( self._figs[axis] .axes[0] .imshow( - stc_slice, + self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T, aspect="auto", extent=extent, cmap=self._cmap, @@ -507,7 +506,7 @@ def _apply_vector_norm(self, stc_data, axis=1): # if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE): # stc_data = stc_data.round().astype(BASE_INT_DTYPE) else: - stc_data = np.take(stc_data, 0, axis=axis) + stc_data = stc_data[(slice(None),) * axis + (0,)] return stc_data def _apply_baseline_correction(self, stc_data): @@ -541,9 +540,9 @@ def _pick_stc_vertex(self, stc_data): def _pick_stc_tfr(self, stc_data): """Select the frequency and time based on GUI values.""" - stc_data = np.take(stc_data, self._t_idx, axis=-1) + stc_data = stc_data[..., self._t_idx] f_idx = 0 if self._f_idx is None else self._f_idx - stc_data = np.take(stc_data, f_idx, axis=-1) + stc_data = stc_data[..., f_idx] return stc_data def _configure_ui(self): @@ -1380,10 +1379,13 @@ def _plot_stc_images(self, axis=None, draw=True): for axis in range(3): # ensure in bounds if src_coord[axis] >= 0 and src_coord[axis] < self._stc_img.shape[axis]: - stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T + self._images["stc"][axis].set_data( + self._stc_img[(slice(None),) * axis + (src_coord[axis],)].T + ) else: - stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan - self._images["stc"][axis].set_data(stc_slice) + self._images["stc"][axis].set_data( + self._stc_img[(slice(None),) * axis + (0,)].copy().T * np.nan + ) if draw and self._update: self._draw(axis) diff --git a/mne_gui_addons/tests/test_segment.py b/mne_gui_addons/tests/test_segment.py index 32c8836..e2f7041 100644 --- a/mne_gui_addons/tests/test_segment.py +++ b/mne_gui_addons/tests/test_segment.py @@ -37,9 +37,10 @@ def test_segment_display(renderer_interactive_pyvistaqt): # test no seghead, fsaverage doesn't have seghead with pytest.warns(RuntimeWarning, match="`seghead` not found"): - gui = VolumeSegmenter( - subject="fsaverage", subjects_dir=subjects_dir, verbose=True - ) + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + gui = VolumeSegmenter( + subject="fsaverage", subjects_dir=subjects_dir, verbose=True + ) # test functions gui.set_RAS([25.37, 0.00, 34.18]) diff --git a/mne_gui_addons/tests/test_vol_stc.py b/mne_gui_addons/tests/test_vol_stc.py index 044a437..23c8147 100644 --- a/mne_gui_addons/tests/test_vol_stc.py +++ b/mne_gui_addons/tests/test_vol_stc.py @@ -57,7 +57,7 @@ def _fake_stc(src_type="vol"): ) + 1j * rng.integers( -1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size) ) - epochs_tfr = mne.time_frequency.EpochsTFR(info, data, times=times, freqs=freqs) + epochs_tfr = mne.time_frequency.EpochsTFRArray(info, data, times=times, freqs=freqs) nuse = sum([this_src["nuse"] for this_src in src]) stc_data = rng.integers( -1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size)