diff --git a/scripts/ae_bundle_streamlines.py b/scripts/ae_bundle_streamlines.py
index 8e1eef5..3762254 100644
--- a/scripts/ae_bundle_streamlines.py
+++ b/scripts/ae_bundle_streamlines.py
@@ -29,6 +29,9 @@ from tractolearn.models.autoencoding_utils import encode_data
 from tractolearn.models.model_pool import get_model
 from tractolearn.tractoio.utils import (
+    assert_bundle_datum_exists,
+    assert_tractogram_exists,
+    filter_filenames,
     load_ref_anat_image,
     load_streamlines,
     read_data_from_json_file,
@@ -331,15 +334,26 @@ def _build_arg_parser():
     add_overwrite_arg(parser)
     add_verbose_arg(parser)
 
-    return parser.parse_args()
+    return parser
 
 
 def main():
-    args = _build_arg_parser()
-    device = torch.device(args.device)
+
+    parser = _build_arg_parser()
+    args = parser.parse_args()
 
     print(args)
 
+    streamline_classes = read_data_from_json_file(args.anatomy_file)
+
+    # Get the bundles of interest
+    boi = list(streamline_classes.keys())
+
+    assert_tractogram_exists(parser, args.atlas_path, boi)
+
+    thresholds = read_data_from_json_file(args.thresholds_file)
+    assert_bundle_datum_exists(parser, thresholds, boi)
+
     if exists(args.output):
         if not args.overwrite:
             print(
@@ -366,6 +380,8 @@ def main():
             f"Please specify a number between 1 and 30. Got {args.num_neighbors}. "
         )
 
+    device = torch.device(args.device)
+
     logging.info(args)
     _set_up_logger(pjoin(args.output, LoggerKeys.logger_file_basename.name))
 
@@ -379,27 +395,20 @@ def main():
     model.load_state_dict(state_dict)
     model.eval()
 
-    streamline_classes = read_data_from_json_file(args.anatomy_file)
-
-    thresholds = read_data_from_json_file(args.thresholds_file)
-
     latent_atlas_all = np.empty((0, 32))
     y_latent_atlas_all = np.empty((0,))
 
-    atlas_file = os.listdir(args.atlas_path)
+    # Filter the atlas filenames according to the bundles of interest
+    foi = filter_filenames(args.atlas_path, boi)
 
     logger.info("Loading atlas files ...")
-    for f in tqdm(atlas_file):
+    for f in tqdm(foi):
         key = f.split(".")[-2]
 
-        assert (
-            key in thresholds.keys()
-        ), f"[!] Threshold: {key} not in threshold file"
-
         X_a_not_flipped, y_a_not_flipped = load_streamlines(
-            pjoin(args.atlas_path, f),
+            f,
             args.common_space_reference,
             streamline_classes[key],
             resample=True,
@@ -407,7 +416,7 @@ def main():
 
         X_a_flipped, y_a_flipped = load_streamlines(
-            pjoin(args.atlas_path, f),
+            f,
             args.common_space_reference,
             streamline_classes[key],
             resample=True,
diff --git a/tractolearn/tractoio/file_extensions.py b/tractolearn/tractoio/file_extensions.py
new file mode 100644
index 0000000..46f7b13
--- /dev/null
+++ b/tractolearn/tractoio/file_extensions.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+import enum
+
+fname_sep = "."
+
+
+class DictDataExtensions(enum.Enum):
+    JSON = "json"
+
+
+class TractogramExtensions(enum.Enum):
+    TCK = "tck"
+    TRK = "trk"
diff --git a/tractolearn/tractoio/tests/test_utils.py b/tractolearn/tractoio/tests/test_utils.py
new file mode 100644
index 0000000..35a390a
--- /dev/null
+++ b/tractolearn/tractoio/tests/test_utils.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from tractolearn.tractoio.file_extensions import (
+    DictDataExtensions,
+    TractogramExtensions,
+    fname_sep,
+)
+from tractolearn.tractoio.utils import (
+    compose_filename,
+    filter_filenames,
+    identify_missing_bundle,
+    identify_missing_tractogram,
+    read_data_from_json_file,
+    save_data_to_json_file,
+)
+
+
+def test_identify_missing_bundle(tmp_path):
+
+    with tempfile.NamedTemporaryFile(
+        suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
+    ) as f:
+
+        # Target bundle names
+        bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+        bundle_data = dict({"CC_Fr_1": 1.0, "CST_L": 2.0, "AC": 3.0})
+        expected = sorted(set(bundle_names).difference(bundle_data.keys()))
+
+        save_data_to_json_file(bundle_data, f.name)
+
+        data = read_data_from_json_file(
+            os.path.join(tmp_path, os.listdir(tmp_path)[0])
+        )
+
+        obtained = identify_missing_bundle(data, bundle_names)
+
+        assert obtained == expected
+
+    with tempfile.NamedTemporaryFile(
+        suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
+    ) as f:
+
+        # Target bundle names
+        bundle_names = ["Cu", "PrCu"]
+
+        bundle_data = dict({"Cu": 2.0})
+        expected = sorted(set(bundle_names).difference(bundle_data.keys()))
+
+        save_data_to_json_file(bundle_data, f.name)
+
+        data = read_data_from_json_file(
+            os.path.join(tmp_path, os.listdir(tmp_path)[0])
+        )
+
+        obtained = identify_missing_bundle(data, bundle_names)
+
+        assert obtained == expected
+
+
+def test_identify_missing_tractogram(tmp_path):
+
+    # Target bundle names
+    bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+    # Create some files in the temporary path
+    file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w").close() for val in fnames]
+
+    expected = sorted(set(bundle_names).difference(file_rootnames))
+
+    obtained = identify_missing_tractogram(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+    # Target bundle names
+    bundle_names = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in bundle_names
+    ]
+
+    # Create some files in the temporary path
+    file_rootnames = ["Cu", "PrCu"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w").close() for val in fnames]
+
+    expected = sorted(set(bundle_names).difference(file_rootnames))
+
+    obtained = identify_missing_tractogram(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+
+def test_filter_fnames(tmp_path):
+
+    # Target bundle names
+    bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+    # Create some files in the temporary path
+    file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w").close() for val in fnames]
+
+    expected_rootnames = ["AC", "CC_Fr_1"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in expected_rootnames
+    ]
+
+    obtained = filter_filenames(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+    # Target bundle names
+    bundle_names = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in bundle_names
+    ]
+
+    # Create some files in the temporary path
+    file_rootnames = ["Cu", "PrCu"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w").close() for val in fnames]
+
+    expected_rootnames = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in expected_rootnames
+    ]
+
+    obtained = filter_filenames(tmp_path, bundle_names)
+
+    assert obtained == expected
diff --git a/tractolearn/tractoio/utils.py b/tractolearn/tractoio/utils.py
index c426fa4..2b9554d 100644
--- a/tractolearn/tractoio/utils.py
+++ b/tractolearn/tractoio/utils.py
@@ -17,6 +17,7 @@ from tractolearn.anatomy.bundles_additional_labels import (
     BundlesAdditionalLabels,
 )
+from tractolearn.tractoio.file_extensions import fname_sep
 from tractolearn.transformation.streamline_transformation import (
     flip_random_streamlines,
     flip_streamlines,
@@ -65,6 +66,177 @@ transformation_data_dir_label = "transformation"
 
 
+def identify_missing_tractogram(dirname, bundle_names):
+    """Identify bundles whose tractograms are missing. Assumes exact matches
+    in the filename root.
+
+    Parameters
+    ----------
+    dirname : Path
+        Dirname
+    bundle_names : list
+        Bundle names.
+
+    Returns
+    -------
+    list
+        Missing bundle names.
+    """
+
+    fname_comps = [
+        os.path.splitext(os.path.basename(fname))
+        for fname in sorted(os.listdir(dirname))
+    ]
+
+    return sorted(set(bundle_names).difference(list(zip(*fname_comps))[0]))
+
+
+def identify_missing_bundle(data, bundle_names):
+    """Identifies bundles whose data are missing. Assumes exact matches in the
+    filename root.
+
+    Parameters
+    ----------
+    data : dict
+        Data.
+    bundle_names : list
+        Bundle names.
+
+    Returns
+    -------
+    list
+        Missing bundle names.
+    """
+
+    return sorted(set(bundle_names).difference(data.keys()))
+
+
+def assert_tractogram_exists(parser, dirname, bundle_names):
+    """Assert that all tractograms corresponding to the given bundles exist.
+    Assumes exact matches in the filename root.
+
+    Parameters
+    ----------
+    parser : ArgumentParser
+        Parser.
+    dirname : Path
+        Dirname.
+    bundle_names : list
+        Bundle names.
+
+    Returns
+    -------
+    Raises an error if missing bundles.
+    """
+
+    missing = identify_missing_tractogram(dirname, bundle_names)
+    if missing:
+        parser.error(
+            "Tractograms corresponding to the following bundles are missing:\n"
+            + ", ".join(missing)
+        )
+
+
+def assert_bundle_datum_exists(parser, data, bundle_names):
+    """Assert that all data corresponding to the given bundles exist. Assumes
+    exact matches in the filename root.
+
+    Parameters
+    ----------
+    parser : ArgumentParser
+        Parser.
+    data : dict
+        Data.
+    bundle_names : list
+        Bundle names.
+
+    Returns
+    -------
+    Raises an error if missing bundles.
+    """
+
+    missing = identify_missing_bundle(data, bundle_names)
+    if missing:
+        parser.error(
+            "Data corresponding to the following bundles are missing:\n"
+            + ", ".join(missing)
+        )
+
+
+def compose_filename(dirname, file_rootname, file_ext):
+    """Compose filename given a dirname, a rootname and an extension.
+
+    Parameters
+    ----------
+    dirname : Path
+        Dirname.
+    file_rootname : str
+        File rootname.
+    file_ext : str
+        File extension.
+
+    Returns
+    -------
+    str
+        Filename.
+    """
+
+    return os.path.join(dirname, file_rootname + fname_sep + file_ext)
+
+
+def retrieve_matching_indices(a, b):
+    """Retrieve indices of the candidate items that match exactly with any of
+    the query items.
+
+    Parameters
+    ----------
+    a : list
+        Candidate items.
+    b : list
+        Query items.
+
+    Returns
+    -------
+    list
+        Indices of the relevant matches in a.
+    """
+
+    b_set = set(b)
+    return [i for i, v in enumerate(a) if v in b_set]
+
+
+def filter_filenames(path, bundle_names):
+    """Filter the relevant filenames in a path according to the provided bundle
+    names of interest. Assumes exact matches in the filename root. Filenames
+    are returned in sorted order.
+
+    Parameters
+    ----------
+    path : Path
+        Folder name.
+    bundle_names : list
+        Bundle names.
+
+    Returns
+    -------
+    list
+        Filenames corresponding to the bundles of interest.
+    """
+
+    fname_comps = [
+        os.path.splitext(os.path.basename(fname))
+        for fname in sorted(os.listdir(path))
+    ]
+
+    match_idx = retrieve_matching_indices(
+        list(zip(*fname_comps))[0], bundle_names
+    )
+    fname_comp_match = [fname_comps[i] for i in match_idx]
+
+    # Compose the relevant filenames back
+    return [os.path.join(path, item[0] + item[1]) for item in fname_comp_match]
+
+
 def load_bundles_dict(dataset_name):
     """Loads the bundle dictionary containing their names and classes
     corresponding to a dataset.