ENH: Assert bundle tractograms and data in bundling script #58

Open · wants to merge 1 commit into base: main
39 changes: 24 additions & 15 deletions scripts/ae_bundle_streamlines.py
@@ -29,6 +29,9 @@
from tractolearn.models.autoencoding_utils import encode_data
from tractolearn.models.model_pool import get_model
from tractolearn.tractoio.utils import (
assert_bundle_datum_exists,
assert_tractogram_exists,
filter_filenames,
load_ref_anat_image,
load_streamlines,
read_data_from_json_file,
@@ -331,15 +334,26 @@ def _build_arg_parser():
add_overwrite_arg(parser)
add_verbose_arg(parser)

return parser.parse_args()
return parser


def main():
args = _build_arg_parser()
device = torch.device(args.device)

parser = _build_arg_parser()
args = parser.parse_args()

print(args)

streamline_classes = read_data_from_json_file(args.anatomy_file)

# Get the bundles of interest
boi = list(streamline_classes.keys())

assert_tractogram_exists(parser, args.atlas_path, boi)

thresholds = read_data_from_json_file(args.thresholds_file)
assert_bundle_datum_exists(parser, thresholds, boi)

if exists(args.output):
if not args.overwrite:
print(
@@ -366,6 +380,8 @@ def main():
f"Please specify a number between 1 and 30. Got {args.num_neighbors}. "
)

device = torch.device(args.device)

logging.info(args)

_set_up_logger(pjoin(args.output, LoggerKeys.logger_file_basename.name))
@@ -379,35 +395,28 @@ def main():
model.load_state_dict(state_dict)
model.eval()

streamline_classes = read_data_from_json_file(args.anatomy_file)

thresholds = read_data_from_json_file(args.thresholds_file)

latent_atlas_all = np.empty((0, 32))
y_latent_atlas_all = np.empty((0,))

atlas_file = os.listdir(args.atlas_path)
# Filter the atlas filenames according to the bundles of interest
foi = filter_filenames(args.atlas_path, boi)

logger.info("Loading atlas files ...")

for f in tqdm(atlas_file):
for f in tqdm(foi):

key = f.split(".")[-2]

assert (
key in thresholds.keys()
), f"[!] Threshold: {key} not in threshold file"

X_a_not_flipped, y_a_not_flipped = load_streamlines(
pjoin(args.atlas_path, f),
f,
args.common_space_reference,
streamline_classes[key],
resample=True,
num_points=256,
)

X_a_flipped, y_a_flipped = load_streamlines(
pjoin(args.atlas_path, f),
f,
args.common_space_reference,
streamline_classes[key],
resample=True,
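
The validation itself is delegated to two helpers imported at the top of the script, assert_tractogram_exists and assert_bundle_datum_exists, whose implementations are not part of this diff. A minimal sketch of what they might look like, inferred from the call sites above and from the identify_missing_* utilities exercised in the new tests below (function bodies and error wording are assumptions, not the actual code in tractolearn.tractoio.utils):

from tractolearn.tractoio.utils import (
    identify_missing_bundle,
    identify_missing_tractogram,
)


def assert_tractogram_exists(parser, path, bundle_names):
    # Abort through the argument parser if any bundle of interest has no
    # tractogram file in the given directory.
    missing = identify_missing_tractogram(path, bundle_names)
    if missing:
        parser.error(f"Missing tractogram file(s) for bundle(s): {missing}")


def assert_bundle_datum_exists(parser, data, bundle_names):
    # Abort through the argument parser if any bundle of interest has no
    # entry in the given dictionary (e.g. the thresholds data).
    missing = identify_missing_bundle(data, bundle_names)
    if missing:
        parser.error(f"Missing data entry(ies) for bundle(s): {missing}")

Passing the parser is also why _build_arg_parser now returns the parser itself rather than the parsed arguments: the helpers can call parser.error() to exit with a usage message when an input is missing.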
14 changes: 14 additions & 0 deletions tractolearn/tractoio/file_extensions.py
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

import enum

fname_sep = "."


class DictDataExtensions(enum.Enum):
JSON = "json"


class TractogramExtensions(enum.Enum):
TCK = "tck"
TRK = "trk"
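
These enums centralize the file extensions used by the tractoio helpers. As an illustration, compose_filename (whose call signature appears in the new tests below) presumably joins a directory, a bundle root name and one of these extension values with fname_sep; the following is a hypothetical sketch and usage, not the actual implementation in tractolearn.tractoio.utils:

import os

from tractolearn.tractoio.file_extensions import (
    TractogramExtensions,
    fname_sep,
)


def compose_filename(path, rootname, extension):
    # Build e.g. "<path>/CST_L.trk" from a directory, a bundle root name and
    # an extension value taken from the enums above.
    return os.path.join(path, rootname + fname_sep + extension)


# compose_filename("/tmp/atlas", "CST_L", TractogramExtensions.TRK.value)
# would return "/tmp/atlas/CST_L.trk".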
152 changes: 152 additions & 0 deletions tractolearn/tractoio/tests/test_utils.py
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import tempfile

from tractolearn.tractoio.file_extensions import (
DictDataExtensions,
TractogramExtensions,
fname_sep,
)
from tractolearn.tractoio.utils import (
compose_filename,
filter_filenames,
identify_missing_bundle,
identify_missing_tractogram,
read_data_from_json_file,
save_data_to_json_file,
)


def test_identify_missing_bundle(tmp_path):

with tempfile.NamedTemporaryFile(
suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
) as f:

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

bundle_data = dict({"CC_Fr_1": 1.0, "CST_L": 2.0, "AC": 3.0})
expected = sorted(set(bundle_names).difference(bundle_data.keys()))

save_data_to_json_file(bundle_data, f.name)

data = read_data_from_json_file(
os.path.join(tmp_path, os.listdir(tmp_path)[0])
)

obtained = identify_missing_bundle(data, bundle_names)

assert obtained == expected

with tempfile.NamedTemporaryFile(
suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
) as f:

# Target bundle names
bundle_names = ["Cu", "PrCu"]

bundle_data = dict({"Cu": 2.0})
expected = sorted(set(bundle_names).difference(bundle_data.keys()))

save_data_to_json_file(bundle_data, f.name)

data = read_data_from_json_file(
os.path.join(tmp_path, os.listdir(tmp_path)[0])
)

obtained = identify_missing_bundle(data, bundle_names)

assert obtained == expected


def test_identify_missing_tractogram(tmp_path):

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

# Create some files in the temporary path
file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected = sorted(set(bundle_names).difference(file_rootnames))

obtained = identify_missing_tractogram(tmp_path, bundle_names)

assert obtained == expected

# Target bundle names
bundle_names = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in bundle_names
]

# Create some files in the temporary path
file_rootnames = ["Cu", "PrCu"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected = sorted(set(bundle_names).difference(file_rootnames))

obtained = identify_missing_tractogram(tmp_path, bundle_names)

assert obtained == expected


def test_filter_fnames(tmp_path):

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

# Create some files in the temporary path
file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected_rootnames = ["AC", "CC_Fr_1"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in expected_rootnames
]

obtained = filter_filenames(tmp_path, bundle_names)

assert obtained == expected

# Target bundle names
bundle_names = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in bundle_names
]

# Create some files in the temporary path
file_rootnames = ["Cu", "PrCu"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected_rootnames = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in expected_rootnames
]

obtained = filter_filenames(tmp_path, bundle_names)

assert obtained == expected
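
For reference, the behaviour these tests pin down for filter_filenames (full paths of the files whose root name is a bundle of interest, returned in sorted order) could be implemented roughly as follows; this is a sketch inferred from the expected values above, not the code shipped in tractolearn.tractoio.utils:

import os

from tractolearn.tractoio.file_extensions import fname_sep


def filter_filenames(path, bundle_names):
    # Keep only the files whose root name (the part before the first ".") is
    # one of the bundles of interest; sort so the result is deterministic.
    return sorted(
        os.path.join(path, fname)
        for fname in os.listdir(path)
        if fname.split(fname_sep)[0] in bundle_names
    )

The new tests can be run with, e.g., pytest tractolearn/tractoio/tests/test_utils.py.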