From c70e36d79d4a85f41ff78c80dc4c7cea99fc3985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Mon, 6 May 2024 18:55:13 -0400 Subject: [PATCH] ENH: Add DWI volume plot method Add DWI volume plot method. --- nireports/reportlets/modality/dwi.py | 56 ++++++++++++++++++++++++++++ nireports/tests/test_dwi.py | 29 +++++++++++++- 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/nireports/reportlets/modality/dwi.py b/nireports/reportlets/modality/dwi.py index d28727d4..c6dbdee2 100644 --- a/nireports/reportlets/modality/dwi.py +++ b/nireports/reportlets/modality/dwi.py @@ -21,10 +21,66 @@ # https://www.nipreps.org/community/licensing/ # """Visualizations for diffusion MRI data.""" +import nibabel as nb import numpy as np from matplotlib import pyplot as plt from matplotlib.pyplot import cm from mpl_toolkits.mplot3d import art3d +from nilearn.plotting import plot_anat + + +def plot_dwi(dataobj, affine, gradient=None, **kwargs): + """ + Plot orthogonal (axial, coronal, sagittal) slices of a given DWI volume. The + slices displayed are determined by a tuple contained in the ``cut_coords`` + keyword argument. + + Parameters + ---------- + dataobj : :obj:`numpy.ndarray` + DWI volume data: a single 3D volume from a given gradient direction. + affine : :obj:`numpy.ndarray` + Affine transformation matrix. + gradient : :obj:`numpy.ndarray` + Gradient values in RAS+b format at the chosen gradient direction. + kwargs : :obj:`dict` + Extra args given to :obj:`nilearn.plotting.plot_anat()`. + + Returns + ------- + :class:`nilearn.plotting.displays.OrthoSlicer` or None + An instance of the OrthoSlicer class. If ``output_file`` is defined, + None is returned. + + """ + + plt.rcParams.update( + { + "text.usetex": True, + "font.family": "sans-serif", + "font.sans-serif": ["Helvetica"], + } + ) + + affine = np.diag(nb.affines.voxel_sizes(affine).tolist() + [1]) + affine[:3, 3] = -1.0 * (affine[:3, :3] @ ((np.array(dataobj.shape) - 1) * 0.5)) + + vmax = kwargs.pop("vmax", None) or np.percentile(dataobj, 98) + cut_coords = kwargs.pop("cut_coords", None) or (0, 0, 0) + + return plot_anat( + nb.Nifti1Image(dataobj, affine, None), + vmax=vmax, + cut_coords=cut_coords, + title=( + r"Reference $b$=0" + if gradient is None + else f"""\ +$b$={gradient[3].astype(int)}, \ +$\\vec{{b}}$ = ({', '.join(str(v) for v in gradient[:3])})""" + ), + **kwargs, + ) def plot_heatmap( diff --git a/nireports/tests/test_dwi.py b/nireports/tests/test_dwi.py index 830e3bc0..84257532 100644 --- a/nireports/tests/test_dwi.py +++ b/nireports/tests/test_dwi.py @@ -23,11 +23,38 @@ """Test DWI reportlets.""" import pytest +from pathlib import Path +import nibabel as nb import numpy as np from matplotlib import pyplot as plt -from nireports.reportlets.modality.dwi import plot_gradients +from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients + + +@pytest.mark.parametrize( + 'dwi', 'dwi_btable', + ['ds000114_sub-01_ses-test_dwi.nii.gz', 'ds000114_singleshell'], +) +def test_plot_dwi(tmp_path, testdata_path, dwi, dwi_btable, outdir): + """Check the plot of DWI data.""" + + dwi_img = nb.load(testdata_path / f'{dwi}') + affine = dwi_img.affine + + bvecs = np.loadtxt(testdata_path / f'{dwi_btable}.bvec').T + bvals = np.loadtxt(testdata_path / f'{dwi_btable}.bval') + + gradients = np.hstack([bvecs, bvals[:, None]]) + + # Pick a random volume to show + rng = np.random.default_rng(1234) + idx = rng.integers(low=0, high=len(bvals), size=1).item() + + _ = plot_dwi(dwi_img.get_fdata()[..., idx], affine, gradient=gradients[idx]) + + if outdir is not None: + plt.savefig(outdir / f'{Path(dwi).with_suffix("").stem}.svg', bbox_inches='tight') @pytest.mark.parametrize(