Skip to content

Commit

Permalink
pytest fixes:
Browse files Browse the repository at this point in the history
fixing directory issues
  • Loading branch information
taha-abdullah committed Aug 28, 2024
1 parent e9e5779 commit 64b9c6c
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 33 deletions.
12 changes: 6 additions & 6 deletions test/quick_test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
logger = getLogger(__name__)


def load_test_files():
def load_test_subjects():
"""
Load the test files from the given file path.
Expand All @@ -14,19 +14,19 @@ def load_test_files():
"""

subjects_dir = os.environ["SUBJECTS_DIR"]
subject_id = os.environ["SUBJECTS_LIST"]
subjects_list = os.environ["SUBJECTS_LIST"]

test_files = []
test_subjects = []

# Load the reference and test files
with open(os.path.join(subjects_dir, subject_id), 'r') as file:
with open(os.path.join(subjects_dir, subjects_list), 'r') as file:
for line in file:
filename = line.strip()
logger.debug(filename)
# test_file = os.path.join(subjects_dir, filename)
test_files.append(filename)
test_subjects.append(filename)

return test_files
return test_subjects


def load_reference_file():
Expand Down
12 changes: 6 additions & 6 deletions test/quick_test/test_errors_in_logfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def load_errors():
return errors, whitelist


def load_log_files(test_dir: str):
def load_log_files(test_subject: str):
"""
Retrieve the log files in the given log directory.
Expand All @@ -51,14 +51,14 @@ def load_log_files(test_dir: str):

# Retrieve the log files in given log directory

log_directory = os.path.join(test_dir, "scripts")
log_directory = os.path.join(test_subject, "scripts")
log_files = [file for file in Path(log_directory).iterdir() if file.suffix == '.log']

return log_files


@pytest.mark.parametrize("test_file", load_test_files())
def test_errors(subjects_dir, test_file: str):
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_errors(subjects_dir, test_subject: str):
"""
Test if there are any errors in the log files.
Expand All @@ -75,8 +75,8 @@ def test_errors(subjects_dir, test_file: str):
If any of the keywords are in the log files.
"""

test_file = os.path.join(subjects_dir, test_file)
log_files = load_log_files(test_file)
test_subject = os.path.join(subjects_dir, test_subject)
log_files = load_log_files(test_subject)

error_flag = False

Expand Down
8 changes: 4 additions & 4 deletions test/quick_test/test_file_existence.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def get_files_from_folder(folder_path: str):
return filenames


@pytest.mark.parametrize("test_file", load_test_files())
def test_file_existence(subjects_dir, test_file: str,
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_file_existence(subjects_dir, test_subject: str,
reference_file: str):
"""
Test the existence of files in the folder.
Expand All @@ -78,10 +78,10 @@ def test_file_existence(subjects_dir, test_file: str,
"""

# Get reference files from the reference subject directory
reference_files = get_files_from_folder(test_file)
reference_files = get_files_from_folder(test_subject)

# Get test list of files in the test subject directory
test_file = os.path.join(subjects_dir, test_file)
test_file = os.path.join(subjects_dir, test_subject)
test_files = get_files_from_folder(reference_file)

# Check if each file in the reference list exists in the test list
Expand Down
18 changes: 9 additions & 9 deletions test/quick_test/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def compute_mean_square_error(test_data: np.ndarray,
return mse


@pytest.mark.parametrize("test_file", load_test_files())
def test_image_headers(subjects_dir, test_file, reference_file):
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_image_headers(subjects_dir, test_subject, reference_file):
"""
Test image headers using nibabel.
Expand All @@ -142,7 +142,7 @@ def test_image_headers(subjects_dir, test_file, reference_file):
"""

# Load images
test_file = os.path.join(subjects_dir, test_file)
test_file = os.path.join(subjects_dir, test_subject)
test_image = load_image(test_file, "brain.mgz")
reference_image = load_image(reference_file, "brain.mgz")

Expand All @@ -155,8 +155,8 @@ def test_image_headers(subjects_dir, test_file, reference_file):
logger.debug("Image headers are correct")


@pytest.mark.parametrize("test_file", load_test_files())
def test_seg_data(subjects_dir, test_file: str,
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_seg_data(subjects_dir, test_subject: str,
reference_file: str):
"""
Test the segmentation data by calculating and comparing dice scores.
Expand All @@ -178,7 +178,7 @@ def test_seg_data(subjects_dir, test_file: str,

labels = load_labels()

test_file = os.path.join(subjects_dir, test_file)
test_file = os.path.join(subjects_dir, test_subject)
test_image = load_image(test_file, "brain.mgz")
reference_image = load_image(reference_file, "brain.mgz")

Expand All @@ -197,8 +197,8 @@ def test_seg_data(subjects_dir, test_file: str,
logger.debug("Dice scores are within range for all classes")


@pytest.mark.parametrize("test_file", load_test_files())
def test_int_data(subjects_dir, test_file: str,
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_int_data(subjects_dir, test_subject: str,
reference_file: str):
"""
Test the intensity data by calculating and comparing mean square errors.
Expand All @@ -218,7 +218,7 @@ def test_int_data(subjects_dir, test_file: str,
If the mean square error is not 0
"""

test_file = os.path.join(subjects_dir, test_file)
test_file = os.path.join(subjects_dir, test_subject)
test_image = load_image(test_file, "brain.mgz")
reference_image = load_image(reference_file, "brain.mgz")

Expand Down
27 changes: 19 additions & 8 deletions test/quick_test/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ def thresholds():
return default_threshold, thresholds


def load_stats_file(test_subject: str):

files = os.listdir(os.path.join(test_subject, "stats"))

if "aseg.stats" in files:
return os.path.join(test_subject, "stats", "aseg.stats")
elif "aparc+DKT.stats" in files:
return os.path.join(test_subject, "stats", "aparc+DKT.stats")
else:
raise ValueError("Unknown stats file")


# def load_structs():
# """
# Load the structs from the given file path.
Expand Down Expand Up @@ -103,8 +115,6 @@ def read_measure_stats(file_path: str):
measure = []
measurements = {}

file_path = os.path.join(file_path, "stats", "aseg.stats")

# Retrieve lines starting with "# Measure" from the stats file
with open(file_path, 'r') as file:
# Read each line in the file
Expand Down Expand Up @@ -164,10 +174,10 @@ def read_table(file_path: str):
return table


@pytest.mark.parametrize("test_file", load_test_files())
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_measure_exists(
subjects_dir: str,
test_file: str,
test_subject: str,
):
"""
Test if the measure exists in the stats file.
Expand All @@ -185,7 +195,8 @@ def test_measure_exists(
If the measure does not exist in the stats file.
"""

test_file = os.path.join(subjects_dir, test_file)
test_subject = os.path.join(subjects_dir, test_subject)
test_file = load_stats_file(test_subject)
data = read_measure_stats(test_file)
ref_data = read_measure_stats(test_file)
errors = []
Expand Down Expand Up @@ -237,8 +248,8 @@ def test_measure_exists(
# if threshold := thresholds.get(struct):
# assert variation <= threshold, f"Variation of {struct} is greater than threshold."

@pytest.mark.parametrize("test_file", load_test_files())
def test_tables(subjects_dir, test_file, reference_file, thresholds):
@pytest.mark.parametrize("test_subject", load_test_subjects())
def test_tables(subjects_dir, test_subject, reference_file, thresholds):
"""
Test if the table values are within the threshold.
Expand All @@ -261,7 +272,7 @@ def test_tables(subjects_dir, test_file, reference_file, thresholds):

# Read the reference and test files

test_file = os.path.join(subjects_dir, test_file)
test_file = os.path.join(subjects_dir, test_subject)
test_table = read_table(test_file)
ref_table = read_table(reference_file)

Expand Down

0 comments on commit 64b9c6c

Please sign in to comment.