Skip to content

Commit

Permalink
Added code metadata checking to the test.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jun 20, 2024
1 parent e70d26f commit 7150d5e
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions tests/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path

import polars as pl
from loguru import logger
from polars.testing import assert_frame_equal

pl.enable_string_cache()
Expand Down Expand Up @@ -203,6 +202,22 @@ def get_expected_output(df: str) -> pl.DataFrame:
1500733,"06/03/2010, 16:44:26",DISCHARGE,
"""

MEDS_OUTPUT_CODE_METADATA_FILE = """
code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd
,44,4,28,3198.8389005974336,382968.28937288234
ADMISSION//CARDIAC,2,2,0,,
ADMISSION//ORTHOPEDIC,1,1,0,,
ADMISSION//PULMONARY,1,1,0,,
DISCHARGE,4,4,0,,
DOB,4,4,0,,
EYE_COLOR//BLUE,1,1,0,,
EYE_COLOR//BROWN,1,1,0,,
EYE_COLOR//HAZEL,2,2,0,,
HEIGHT,4,4,4,656.8389005974336,108056.12937288235
HR,12,4,12,1360.5000000000002,158538.77
TEMP,12,4,12,1181.4999999999998,116373.38999999998
"""

SUB_SHARDED_OUTPUTS = {
"train/0": {
"subjects": MEDS_OUTPUT_TRAIN_0_SUBJECTS,
Expand Down Expand Up @@ -454,9 +469,6 @@ def test_extraction():
output_folder = MEDS_cohort_dir / "final_cohort"
try:
for split, expected_df_L in MEDS_OUTPUTS.items():
if expected_df_L is None:
continue

if not isinstance(expected_df_L, list):
expected_df_L = [expected_df_L]

Expand Down Expand Up @@ -487,8 +499,6 @@ def test_extraction():
print(f"stdout:\n{full_stdout}")
raise e

logger.warning("Only checked the train/0 split for now. TODO: add the rest of the splits.")

# Step 4: Merge to the final output
stderr, stdout = run_command(
extraction_root / "collect_code_metadata.py",
Expand All @@ -504,4 +514,21 @@ def test_extraction():
output_file = MEDS_cohort_dir / "code_metadata.parquet"
assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}"

logger.warning("Didn't check contents of code metadata!")
got_df = pl.read_parquet(output_file, glob=False)

want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns(
pl.col("code").cast(pl.Categorical),
pl.col("code/n_occurrences").cast(pl.UInt32),
pl.col("code/n_patients").cast(pl.UInt32),
pl.col("values/n_occurrences").cast(pl.UInt32),
pl.col("values/sum").cast(pl.Float64).fill_null(0),
pl.col("values/sum_sqd").cast(pl.Float64).fill_null(0),
)

assert_df_equal(
want=want_df,
got=got_df,
msg="Code metadata differs!",
check_column_order=False,
check_row_order=False,
)

0 comments on commit 7150d5e

Please sign in to comment.