From 7150d5ea85760fd23d96988a649e2048a9295bc3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 22:03:42 -0400 Subject: [PATCH] Added code metadata checking to the test. --- tests/test_extraction.py | 41 +++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/test_extraction.py b/tests/test_extraction.py index b5bb85e..75a7404 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -11,7 +11,6 @@ from pathlib import Path import polars as pl -from loguru import logger from polars.testing import assert_frame_equal pl.enable_string_cache() @@ -203,6 +202,22 @@ def get_expected_output(df: str) -> pl.DataFrame: 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ +MEDS_OUTPUT_CODE_METADATA_FILE = """ +code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd +,44,4,28,3198.8389005974336,382968.28937288234 +ADMISSION//CARDIAC,2,2,0,, +ADMISSION//ORTHOPEDIC,1,1,0,, +ADMISSION//PULMONARY,1,1,0,, +DISCHARGE,4,4,0,, +DOB,4,4,0,, +EYE_COLOR//BLUE,1,1,0,, +EYE_COLOR//BROWN,1,1,0,, +EYE_COLOR//HAZEL,2,2,0,, +HEIGHT,4,4,4,656.8389005974336,108056.12937288235 +HR,12,4,12,1360.5000000000002,158538.77 +TEMP,12,4,12,1181.4999999999998,116373.38999999998 +""" + SUB_SHARDED_OUTPUTS = { "train/0": { "subjects": MEDS_OUTPUT_TRAIN_0_SUBJECTS, @@ -454,9 +469,6 @@ def test_extraction(): output_folder = MEDS_cohort_dir / "final_cohort" try: for split, expected_df_L in MEDS_OUTPUTS.items(): - if expected_df_L is None: - continue - if not isinstance(expected_df_L, list): expected_df_L = [expected_df_L] @@ -487,8 +499,6 @@ def test_extraction(): print(f"stdout:\n{full_stdout}") raise e - logger.warning("Only checked the train/0 split for now. TODO: add the rest of the splits.") - # Step 4: Merge to the final output stderr, stdout = run_command( extraction_root / "collect_code_metadata.py", @@ -504,4 +514,21 @@ def test_extraction(): output_file = MEDS_cohort_dir / "code_metadata.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" - logger.warning("Didn't check contents of code metadata!") + got_df = pl.read_parquet(output_file, glob=False) + + want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( + pl.col("code").cast(pl.Categorical), + pl.col("code/n_occurrences").cast(pl.UInt32), + pl.col("code/n_patients").cast(pl.UInt32), + pl.col("values/n_occurrences").cast(pl.UInt32), + pl.col("values/sum").cast(pl.Float64).fill_null(0), + pl.col("values/sum_sqd").cast(pl.Float64).fill_null(0), + ) + + assert_df_equal( + want=want_df, + got=got_df, + msg="Code metadata differs!", + check_column_order=False, + check_row_order=False, + )