diff --git a/test/unittests/utils/test_atlas.py b/test/unittests/utils/test_atlas.py index aa4288c1f..ccb74538d 100644 --- a/test/unittests/utils/test_atlas.py +++ b/test/unittests/utils/test_atlas.py @@ -1,4 +1,6 @@ import re +from pathlib import Path +from typing import List import nibabel as nib import numpy as np @@ -110,6 +112,7 @@ def test_atlas_factory(tmp_path, monkeypatch, atlas_name, atlas): ), ) def test_atlases( + tmp_path, atlas, expected_name, expected_checksum, @@ -117,7 +120,16 @@ def test_atlases( expected_roi_filename, expected_resolution, expected_size, + mocker, ): + # The following atlases are supposed to be downloaded from a remote server + # For unit testing, we mock the get_file_from_server function to return + # a nifti image with fake data realistic enough to make the tests pass + if expected_name in ("Hammers", "LPBA40", "Neuromorphometrics"): + mocker.patch( + "clinica.utils.inputs.get_file_from_server", + return_value=get_mocked_atlas(tmp_path, atlas), + ) assert atlas.name == expected_name assert atlas.expected_checksum == expected_checksum assert atlas.atlas_filename == expected_atlas_filename @@ -128,6 +140,107 @@ def test_atlases( assert_array_equal(atlas.get_index(), np.arange(expected_size)) +def get_mocked_atlas(tmp_path, atlas) -> Path: + """Return the path to the mocked atlas label image. + + We need the mocked image to have: + - the right resolution (same for all mocked atlases) + - the right data shape (same for all mocked atlases) + - the right label values + + To ensure the last point, this function generate a data array + composed of the cycled expected labels. + """ + from itertools import cycle + + data_shape = (121, 145, 121) + gen = cycle(get_labels(atlas)) + mocked_data = [] + while len(mocked_data) < np.prod(data_shape): + mocked_data.append(next(gen)) + mocked_data = np.array(mocked_data, dtype="float") + mocked_data = mocked_data.reshape(data_shape) + mocked_image = nib.Nifti1Image(mocked_data, np.diag([-1.5, 1.5, 1.5, 1.0])) + nib.save(mocked_image, tmp_path / "mocked_atlas.nii.gz") + + return tmp_path / "mocked_atlas.nii.gz" + + +def get_labels(atlas) -> List[int]: + """Return the expected label values for the mocked atlases.""" + if atlas.name == "Neuromorphometrics": + labels = list(range(143)) + labels.remove(35) + labels.remove(23) + elif atlas.name == "LPBA40": + labels = [ + 0, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 101, + 102, + 121, + 122, + 161, + 162, + 163, + 164, + 165, + 166, + 181, + 182, + ] + elif atlas.name == "Hammers": + labels = list(range(69)) + else: + raise ValueError( + f"Atlas {atlas.name} is not supposed to be mocked in this test." + ) + return labels + + @pytest.fixture def atlas(expected_name, tmp_path, monkeypatch): from clinica.utils.atlas import atlas_factory