diff --git a/src/MEDS_transforms/configs/stage_configs/fit_vocabulary_indices.yaml b/src/MEDS_transforms/configs/stage_configs/fit_vocabulary_indices.yaml index 8725010..41596f0 100644 --- a/src/MEDS_transforms/configs/stage_configs/fit_vocabulary_indices.yaml +++ b/src/MEDS_transforms/configs/stage_configs/fit_vocabulary_indices.yaml @@ -1,4 +1,3 @@ fit_vocabulary_indices: is_metadata: true ordering_method: "lexicographic" - output_dir: "${cohort_dir}" diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 2ec3ff6..0266ce2 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -5,6 +5,7 @@ import hydra import polars as pl +from loguru import logger from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from omegaconf import DictConfig @@ -87,9 +88,17 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: time_delta_col = time_delta_cols[0] - return JointNestedRaggedTensorDict( - df.select(time_delta_col, "code", "numeric_value").collect().to_dict(as_series=False) - ) + tensors_dict = df.select(time_delta_col, "code", "numeric_value").collect().to_dict(as_series=False) + + if all((not v) for v in tensors_dict.values()): + logger.warning("All columns are empty. Returning an empty tensor dict.") + return JointNestedRaggedTensorDict({}) + + for k, v in tensors_dict.items(): + if not v: + raise ValueError(f"Column {k} is empty") + + return JointNestedRaggedTensorDict(tensors_dict) @hydra.main( diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index c4f3d26..d6f5003 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -184,27 +184,27 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], ... "code": [100, 101, 102, 103, 200, 201, 202], - ... "numeric_value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] + ... "numeric_value": pl.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=pl.Float32) ... 
}).lazy() >>> extract_seq_of_patient_events(df).collect() shape: (2, 4) - ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ - ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ - │ 1 ┆ [NaN, 12.0] ┆ [[101.0, 102.0], [103.0]] ┆ [[2.0, 3.0], [4.0]] │ - │ 2 ┆ [NaN] ┆ [[201.0, 202.0]] ┆ [[6.0, 7.0]] │ - └────────────┴─────────────────┴───────────────────────────┴─────────────────────┘ + ┌────────────┬─────────────────┬─────────────────────┬─────────────────────┐ + │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ list[f32] ┆ list[list[i64]] ┆ list[list[f32]] │ + ╞════════════╪═════════════════╪═════════════════════╪═════════════════════╡ + │ 1 ┆ [NaN, 12.0] ┆ [[101, 102], [103]] ┆ [[2.0, 3.0], [4.0]] │ + │ 2 ┆ [NaN] ┆ [[201, 202]] ┆ [[6.0, 7.0]] │ + └────────────┴─────────────────┴─────────────────────┴─────────────────────┘ """ _, dynamic = split_static_and_dynamic(df) - time_delta_days_expr = (pl.col("time").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float64) + time_delta_days_expr = (pl.col("time").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float32) return ( dynamic.group_by("patient_id", "time", maintain_order=True) - .agg(fill_to_nans("code").name.keep(), fill_to_nans("numeric_value").name.keep()) + .agg(pl.col("code").name.keep(), fill_to_nans("numeric_value").name.keep()) .group_by("patient_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta_days"), diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 0c4c20a..59a7cd6 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -244,7 +244,7 @@ def populate_stage( 'reducer_output_dir': '/c/d/metadata'} >>> populate_stage("stage6", *args) # doctest: +NORMALIZE_WHITESPACE {'is_metadata': False, 'data_input_dir': '/c/d/stage4', - 'metadata_input_dir': '/c/d/stage5', 'output_dir': '/c/d/data', 'reducer_output_dir': None} + 'metadata_input_dir': '/c/d/metadata', 'output_dir': '/c/d/data', 'reducer_output_dir': None} >>> populate_stage("stage7", *args) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... @@ -304,7 +304,7 @@ def populate_stage( if is_first_metadata_stage: default_metadata_input_dir = pipeline_input_metadata_dir else: - default_metadata_input_dir = prior_metadata_stage["output_dir"] + default_metadata_input_dir = prior_metadata_stage["reducer_output_dir"] # Now, we need to set output directories. 
The output directory for the stage will either be a stage # specific output directory, or, for the last data or metadata stages, respectively, will be the global diff --git a/tests/test_add_time_derived_measurements.py b/tests/test_add_time_derived_measurements.py index 87138c8..e5653a1 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/test_add_time_derived_measurements.py @@ -240,5 +240,5 @@ def test_add_time_derived_measurements(): transform_script=ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, stage_name="add_time_derived_measurements", transform_stage_kwargs={"age": {"DOB_code": "DOB"}}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, ) diff --git a/tests/test_aggregate_code_metadata.py b/tests/test_aggregate_code_metadata.py index 2a8f78c..21698cb 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/test_aggregate_code_metadata.py @@ -180,7 +180,7 @@ def test_aggregate_code_metadata(): transform_script=AGGREGATE_CODE_METADATA_SCRIPT, stage_name="aggregate_code_metadata", transform_stage_kwargs={"aggregations": AGGREGATIONS, "do_summarize_over_all_codes": True}, - want_outputs=WANT_OUTPUT_CODE_METADATA_FILE, - code_metadata=MEDS_CODE_METADATA_FILE, + want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, + input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, ) diff --git a/tests/test_filter_measurements.py b/tests/test_filter_measurements.py index 2243320..cb919d1 100644 --- a/tests/test_filter_measurements.py +++ b/tests/test_filter_measurements.py @@ -113,7 +113,7 @@ def test_filter_measurements(): transform_script=FILTER_MEASUREMENTS_SCRIPT, stage_name="filter_measurements", transform_stage_kwargs={"min_patients_per_code": 2}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, ) @@ -223,6 +223,6 @@ def test_match_revise_filter_measurements(): {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4}, ], }, - want_outputs=MR_WANT_SHARDS, + want_data=MR_WANT_SHARDS, do_use_config_yaml=True, ) diff --git a/tests/test_filter_patients.py b/tests/test_filter_patients.py index e936875..0b07836 100644 --- a/tests/test_filter_patients.py +++ b/tests/test_filter_patients.py @@ -82,5 +82,5 @@ def test_filter_patients(): transform_script=FILTER_PATIENTS_SCRIPT, stage_name="filter_patients", transform_stage_kwargs={"min_events_per_patient": 5}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, ) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/test_fit_vocabulary_indices.py index 312e65f..ce7c40a 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/test_fit_vocabulary_indices.py @@ -33,5 +33,5 @@ def test_fit_vocabulary_indices_with_default_stage_config(): transform_script=FIT_VOCABULARY_INDICES_SCRIPT, stage_name="fit_vocabulary_indices", transform_stage_kwargs=None, - want_outputs=parse_code_metadata_csv(WANT_CSV), + want_metadata=parse_code_metadata_csv(WANT_CSV), ) diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/test_multi_stage_preprocess_pipeline.py new file mode 100644 index 0000000..eda9060 --- /dev/null +++ b/tests/test_multi_stage_preprocess_pipeline.py @@ -0,0 +1,1098 @@ +"""Tests a multi-stage pre-processing pipeline. Only checks the end result, not the intermediate files. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+ +In this test, the following stages are run: + - filter_patients + - add_time_derived_measurements + - fit_outlier_detection + - occlude_outliers + - fit_normalization + - fit_vocabulary_indices + - normalization + - tokenization + - tensorization + +The stage configuration arguments will be as given in the yaml block below: +""" + +from datetime import datetime + +import polars as pl +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict + +from .transform_tester_base import ( + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + FILTER_PATIENTS_SCRIPT, + FIT_VOCABULARY_INDICES_SCRIPT, + NORMALIZATION_SCRIPT, + OCCLUDE_OUTLIERS_SCRIPT, + TENSORIZATION_SCRIPT, + TOKENIZATION_SCRIPT, + multi_stage_transform_tester, + parse_shards_yaml, +) + +MEDS_CODE_METADATA = pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + }, + schema={"code": pl.String, "description": pl.String, "parent_codes": pl.List(pl.String)}, +) + +STAGE_CONFIG_YAML = """ +filter_patients: + min_events_per_patient: 5 +add_time_derived_measurements: + age: + DOB_code: "DOB" # This is the MEDS official code for BIRTH + age_code: "AGE" + age_unit: "years" + time_of_day: + time_of_day_code: "TIME_OF_DAY" + endpoints: [6, 12, 18, 24] +fit_outlier_detection: + aggregations: + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" +occlude_outliers: + stddev_cutoff: 1 +fit_normalization: + aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" +""" + +# After filtering out patients with fewer than 5 events: +WANT_FILTER = parse_shards_yaml( + """ + "filter_patients/train/0": |-2 + patient_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + "filter_patients/train/1": |-2 + patient_id,time,code,numeric_value + + "filter_patients/tuning/0": |-2 + patient_id,time,code,numeric_value + + "filter_patients/held_out/0": |-2 + patient_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 
1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" +) + +WANT_TIME_DERIVED = parse_shards_yaml( + """ + "add_time_derived_measurements/train/0": |-2 + patient_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)", + 239684,"05/11/2010, 17:41:51",AGE,29.36883360091833 + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)", + 239684,"05/11/2010, 17:48:48",AGE,29.36884681513314 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 18:25:35",AGE,29.36891675223647 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 18:57:18",AGE,29.36897705595538 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)", + 239684,"05/11/2010, 19:27:19",AGE,29.369034127420306 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)", + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265 + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + "add_time_derived_measurements/train/1": |-2 + patient_id,time,code,numeric_value + + "add_time_derived_measurements/tuning/0": |-2 + patient_id,time,code,numeric_value + + "add_time_derived_measurements/held_out/0": |-2 + patient_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 
1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 14:54:38",AGE,23.873531791091356 + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 15:39:49",AGE,23.873617699332012 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 16:20:49",AGE,23.873695653692767 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)", + 1500733,"06/03/2010, 16:44:26",AGE,23.873740556672114 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" +) + +# Fit outliers python code +FIT_OUTLIERS_CODE = """ +```python +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_TIME_DERIVED +>>> import polars as pl +>>> VALS = pl.col("numeric_value").drop_nulls().drop_nans() +>>> post_outliers = ( +... WANT_TIME_DERIVED['add_time_derived_measurements/train/0'] +... .group_by("code") +... .agg( +... VALS.len().alias("values/n_occurrences"), +... VALS.sum().alias("values/sum"), +... (VALS**2).sum().alias("values/sum_sqd") +... ) +... .filter(pl.col("values/n_occurrences") > 0) +... ) +>>> post_outliers +shape: (4, 4) +┌────────┬──────────────────────┬─────────────┬────────────────┐ +│ code ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ u32 ┆ f32 ┆ f32 │ +╞════════╪══════════════════════╪═════════════╪════════════════╡ +│ HR ┆ 10 ┆ 1104.300049 ┆ 122174.726562 │ +│ AGE ┆ 12 ┆ 370.865448 ┆ 11482.001953 │ +│ TEMP ┆ 10 ┆ 983.600037 ┆ 96788.53125 │ +│ HEIGHT ┆ 2 ┆ 339.958008 ┆ 57841.734375 │ +└────────┴──────────────────────┴─────────────┴────────────────┘ +>>> print(post_outliers.to_dict(as_series=False)) +{'code': ['HR', 'AGE', 'TEMP', 'HEIGHT'], + 'values/n_occurrences': [10, 12, 10, 2], + 'values/sum': [1104.300048828125, 370.8654479980469, 983.6000366210938, 339.9580078125], + 'values/sum_sqd': [122174.7265625, 11482.001953125, 96788.53125, 57841.734375]} + + +``` +""" + +# Input: +# code,description,parent_codes +# EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.", +# EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.", +# EYE_COLOR//HAZEL,"Hazel eyes. These are uncommon", +# HR,"Heart Rate",LOINC/8867-4 +# TEMP,"Body Temperature",LOINC/8310-5 + +WANT_FIT_OUTLIERS = { + "fit_outlier_detection/codes.parquet": pl.DataFrame( + { + "code": [ + "EYE_COLOR//BLUE", + "EYE_COLOR//BROWN", + "HR", + "TEMP", + "AGE", + "HEIGHT", + "TIME_OF_DAY//[18,24)", + "TIME_OF_DAY//[12,18)", + "TIME_OF_DAY//[00,06)", + "ADMISSION//CARDIAC", + "DISCHARGE", + "DOB", + ], + "values/n_occurrences": [0, 0, 10, 10, 12, 2, 0, 0, 0, 0, 0, 0], + "values/sum": [ + 0.0, + 0.0, + 1104.300048828125, + 983.6000366210938, + 370.8654479980469, + 339.9580078125, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "values/sum_sqd": [ + 0.0, + 0.0, + 122174.7265625, + 96788.53125, + 11482.001953125, + 57841.734375, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. 
The most common eye color.", + "Heart Rate", + "Body Temperature", + None, + None, + None, + None, + None, + None, + None, + None, + ], + "parent_codes": [ + None, + None, + ["LOINC/8867-4"], + ["LOINC/8310-5"], + None, + None, + None, + None, + None, + None, + None, + None, + ], + }, + schema={ + "code": pl.String, + "description": pl.String, + "parent_codes": pl.List(pl.String), + "values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex. + "values/sum": pl.Float32, + "values/sum_sqd": pl.Float32, + }, + ).sort(by="code") +} + +# For occluding outliers +OCCLUDE_OUTLIERS_CODE = """ +```python +# This implies the following means and standard deviations +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_FIT_OUTLIERS as metadata_df +>>> mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") +>>> stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 +>>> metadata_df.select( +... "code", +... (mean_col - stddev_col).alias("values/inlier_lower_bound"), +... (mean_col + stddev_col).alias("values/inlier_upper_bound") +... ) +shape: (4, 3) +┌────────┬───────────────────────────┬───────────────────────────┐ +│ code ┆ values/inlier_lower_bound ┆ values/inlier_upper_bound │ +│ --- ┆ --- ┆ --- │ +│ str ┆ f64 ┆ f64 │ +╞════════╪═══════════════════════════╪═══════════════════════════╡ +│ HR ┆ 105.666951 ┆ 115.193058 │ +│ AGE ┆ 29.606836 ┆ 32.204072 │ +│ TEMP ┆ 96.319708 ┆ 100.400299 │ +│ HEIGHT ┆ 164.686989 ┆ 175.271019 │ +└────────┴───────────────────────────┴───────────────────────────┘ + +``` +""" + +WANT_OCCLUDE_OUTLIERS = parse_shards_yaml( + """ + "occlude_outliers/train/0": |-2 + patient_id,time,code,numeric_value,numeric_value/is_inlier + 239684,,EYE_COLOR//BROWN,, + 239684,,HEIGHT,,false + 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, + 239684,"12/28/1980, 00:00:00",DOB,, + 239684,"05/11/2010, 17:41:51","TIME_OF_DAY//[12,18)",, + 239684,"05/11/2010, 17:41:51",AGE,,false + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,, + 239684,"05/11/2010, 17:41:51",HR,,false + 239684,"05/11/2010, 17:41:51",TEMP,,false + 239684,"05/11/2010, 17:48:48","TIME_OF_DAY//[12,18)",, + 239684,"05/11/2010, 17:48:48",AGE,,false + 239684,"05/11/2010, 17:48:48",HR,,false + 239684,"05/11/2010, 17:48:48",TEMP,,false + 239684,"05/11/2010, 18:25:35","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 18:25:35",AGE,,false + 239684,"05/11/2010, 18:25:35",HR,113.4,true + 239684,"05/11/2010, 18:25:35",TEMP,,false + 239684,"05/11/2010, 18:57:18","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 18:57:18",AGE,,false + 239684,"05/11/2010, 18:57:18",HR,112.6,true + 239684,"05/11/2010, 18:57:18",TEMP,,false + 239684,"05/11/2010, 19:27:19","TIME_OF_DAY//[18,24)",, + 239684,"05/11/2010, 19:27:19",AGE,,false + 239684,"05/11/2010, 19:27:19",DISCHARGE,, + 1195293,,EYE_COLOR//BLUE,, + 1195293,,HEIGHT,,false + 1195293,"06/20/1978, 00:00:00","TIME_OF_DAY//[00,06)",, + 1195293,"06/20/1978, 00:00:00",DOB,, + 1195293,"06/20/2010, 19:23:52","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 19:23:52",AGE,32.002896271955265,true + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,, + 1195293,"06/20/2010, 19:23:52",HR,109.0,true + 1195293,"06/20/2010, 19:23:52",TEMP,100.0,true + 1195293,"06/20/2010, 19:25:32","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 19:25:32",AGE,32.00289944083172,true + 1195293,"06/20/2010, 19:25:32",HR,114.1,true + 1195293,"06/20/2010, 19:25:32",TEMP,100.0,true + 1195293,"06/20/2010, 19:45:19","TIME_OF_DAY//[18,24)",, + 
1195293,"06/20/2010, 19:45:19",AGE,32.00293705539522,true + 1195293,"06/20/2010, 19:45:19",HR,,false + 1195293,"06/20/2010, 19:45:19",TEMP,99.9,true + 1195293,"06/20/2010, 20:12:31","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:12:31",AGE,32.002988771458945,true + 1195293,"06/20/2010, 20:12:31",HR,112.5,true + 1195293,"06/20/2010, 20:12:31",TEMP,99.8,true + 1195293,"06/20/2010, 20:24:44","TIME_OF_DAY//[18,24)", + 1195293,"06/20/2010, 20:24:44",AGE,32.00301199932335,true + 1195293,"06/20/2010, 20:24:44",HR,107.7,true + 1195293,"06/20/2010, 20:24:44",TEMP,100.0,true + 1195293,"06/20/2010, 20:41:33","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:41:33",AGE,32.003043973286765,true + 1195293,"06/20/2010, 20:41:33",HR,107.5,true + 1195293,"06/20/2010, 20:41:33",TEMP,100.4,true + 1195293,"06/20/2010, 20:50:04","TIME_OF_DAY//[18,24)",, + 1195293,"06/20/2010, 20:50:04",AGE,32.00306016624544,true + 1195293,"06/20/2010, 20:50:04",DISCHARGE,, + + "occlude_outliers/train/1": |-2 + patient_id,time,code,numeric_value,numeric_value/is_inlier + + "occlude_outliers/tuning/0": |-2 + patient_id,time,code,numeric_value,numeric_value/is_inlier + + "occlude_outliers/held_out/0": |-2 + patient_id,time,code,numeric_value,numeric_value/is_inlier + 1500733,,EYE_COLOR//BROWN,, + 1500733,,HEIGHT,,false + 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, + 1500733,"07/20/1986, 00:00:00",DOB,, + 1500733,"06/03/2010, 14:54:38","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 14:54:38",AGE,,false + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC,, + 1500733,"06/03/2010, 14:54:38",HR,,false + 1500733,"06/03/2010, 14:54:38",TEMP,100.0,true + 1500733,"06/03/2010, 15:39:49","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 15:39:49",AGE,,false + 1500733,"06/03/2010, 15:39:49",HR,,false + 1500733,"06/03/2010, 15:39:49",TEMP,100.3,true + 1500733,"06/03/2010, 16:20:49","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 16:20:49",AGE,,false + 1500733,"06/03/2010, 16:20:49",HR,,false + 1500733,"06/03/2010, 16:20:49",TEMP,100.1,true + 1500733,"06/03/2010, 16:44:26","TIME_OF_DAY//[12,18)",, + 1500733,"06/03/2010, 16:44:26",AGE,,false + 1500733,"06/03/2010, 16:44:26",DISCHARGE,, +""" +) + +FIT_NORMALIZATION_CODE = """ +```python +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_OCCLUDE_OUTLIERS as dfs +>>> import polars as pl +>>> VALS = pl.col("numeric_value").drop_nulls().drop_nans() +>>> post_transform = ( +... dfs[next(k for k in dfs.keys() if k.endswith("/train/0"))] +... .group_by("code") +... .agg( +... pl.len().alias("code/n_occurrences"), +... pl.col("patient_id").n_unique().alias("code/n_patients"), +... VALS.len().alias("values/n_occurrences"), +... VALS.sum().alias("values/sum"), +... (VALS**2).sum().alias("values/sum_sqd") +... ) +... 
) +>>> post_transform.filter(pl.col("values/n_occurrences") > 0) +shape: (3, 6) +┌──────┬────────────────────┬─────────────────┬──────────────────────┬────────────┬────────────────┐ +│ code ┆ code/n_occurrences ┆ code/n_patients ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ u32 ┆ u32 ┆ u32 ┆ f32 ┆ f32 │ +╞══════╪════════════════════╪═════════════════╪══════════════════════╪════════════╪════════════════╡ +│ HR ┆ 10 ┆ 2 ┆ 7 ┆ 776.799988 ┆ 86249.921875 │ +│ TEMP ┆ 10 ┆ 2 ┆ 6 ┆ 600.100037 ┆ 60020.214844 │ +│ AGE ┆ 12 ┆ 2 ┆ 7 ┆ 224.020844 ┆ 7169.333496 │ +└──────┴────────────────────┴─────────────────┴──────────────────────┴────────────┴────────────────┘ +>>> print(post_transform.filter(pl.col("values/n_occurrences") > 0).to_dict(as_series=False)) +{'code': ['HR', 'TEMP', 'AGE'], + 'code/n_occurrences': [10, 10, 12], + 'code/n_patients': [2, 2, 2], + 'values/n_occurrences': [7, 6, 7], + 'values/sum': [776.7999877929688, 600.1000366210938, 224.02084350585938], + 'values/sum_sqd': [86249.921875, 60020.21484375, 7169.33349609375]} + +""" + +WANT_FIT_NORMALIZATION = { + "fit_normalization/codes.parquet": pl.DataFrame( + { + "code": [ + "EYE_COLOR//BLUE", + "EYE_COLOR//BROWN", + "HR", + "TEMP", + "AGE", + "HEIGHT", + "TIME_OF_DAY//[18,24)", + "TIME_OF_DAY//[12,18)", + "TIME_OF_DAY//[00,06)", + "ADMISSION//CARDIAC", + "DISCHARGE", + "DOB", + ], + "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], + "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], + "values/sum": [ + 0.0, + 0.0, + 776.7999877929688, + 600.1000366210938, + 224.0208376784967, + 0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "values/sum_sqd": [ + 0.0, + 0.0, + 86249.921875, + 60020.21484375, + 7169.33349609375, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Heart Rate", + "Body Temperature", + None, + None, + None, + None, + None, + None, + None, + None, + ], + "parent_codes": [ + None, + None, + ["LOINC/8867-4"], + ["LOINC/8310-5"], + None, + None, + None, + None, + None, + None, + None, + None, + ], + }, + schema={ + "code": pl.String, + "description": pl.String, + "parent_codes": pl.List(pl.String), + "code/n_occurrences": pl.UInt8, + "code/n_patients": pl.UInt8, + "values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex. + "values/sum": pl.Float32, + "values/sum_sqd": pl.Float32, + }, + ).sort(by="code") +} + +# As the last metadata stage, this gets a special directory. +WANT_FIT_VOCABULARY_INDICES = { + "metadata/codes.parquet": pl.DataFrame( + { + "code": [ + "EYE_COLOR//BLUE", + "EYE_COLOR//BROWN", + "HR", + "TEMP", + "AGE", + "HEIGHT", + "TIME_OF_DAY//[18,24)", + "TIME_OF_DAY//[12,18)", + "TIME_OF_DAY//[00,06)", + "ADMISSION//CARDIAC", + "DISCHARGE", + "DOB", + ], + "code/vocab_index": [5, 6, 8, 9, 2, 7, 12, 11, 10, 1, 3, 4], + "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], + "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], + "values/sum": [ + 0.0, + 0.0, + 776.7999877929688, + 600.1000366210938, + 224.0208376784967, + 0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "values/sum_sqd": [ + 0.0, + 0.0, + 86249.921875, + 60020.21484375, + 7169.33349609375, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + "description": [ + "Blue Eyes. 
Less common than brown.", + "Brown Eyes. The most common eye color.", + "Heart Rate", + "Body Temperature", + None, + None, + None, + None, + None, + None, + None, + None, + ], + "parent_codes": [ + None, + None, + ["LOINC/8867-4"], + ["LOINC/8310-5"], + None, + None, + None, + None, + None, + None, + None, + None, + ], + }, + schema={ + "code": pl.String, + "description": pl.String, + "parent_codes": pl.List(pl.String), + "code/n_occurrences": pl.UInt8, + "code/n_patients": pl.UInt8, + "code/vocab_index": pl.UInt8, + "values/n_occurrences": pl.UInt8, + "values/sum": pl.Float32, + "values/sum_sqd": pl.Float32, + }, + ).sort(by="code") +} + + +NORMALIZATION_CODE = """ +```python +# This implies the following means and standard deviations +>>> import polars as pl +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_FIT_VOCABULARY_INDICES as metadata_df +>>> metadata_df = list(metadata_df.values())[0] +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_OCCLUDE_OUTLIERS as dfs +>>> mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") +>>> stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 +>>> metadata_df = metadata_df.select( +... "code", +... "code/vocab_index", +... mean_col.alias("values/mean"), +... stddev_col.alias("values/stddev"), +... ) +>>> metadata_df +shape: (12, 4) +┌──────────────────────┬──────────────────┬─────────────┬───────────────┐ +│ code ┆ code/vocab_index ┆ values/mean ┆ values/stddev │ +│ --- ┆ --- ┆ --- ┆ --- │ +│ str ┆ u8 ┆ f32 ┆ f32 │ +╞══════════════════════╪══════════════════╪═════════════╪═══════════════╡ +│ ADMISSION//CARDIAC ┆ 1 ┆ NaN ┆ NaN │ +│ AGE ┆ 2 ┆ 32.002979 ┆ NaN │ +│ DISCHARGE ┆ 3 ┆ NaN ┆ NaN │ +│ DOB ┆ 4 ┆ NaN ┆ NaN │ +│ EYE_COLOR//BLUE ┆ 5 ┆ NaN ┆ NaN │ +│ … ┆ … ┆ … ┆ … │ +│ HR ┆ 8 ┆ 110.971428 ┆ 2.599767 │ +│ TEMP ┆ 9 ┆ 100.01667 ┆ 0.1875 │ +│ TIME_OF_DAY//[00,06) ┆ 10 ┆ NaN ┆ NaN │ +│ TIME_OF_DAY//[12,18) ┆ 11 ┆ NaN ┆ NaN │ +│ TIME_OF_DAY//[18,24) ┆ 12 ┆ NaN ┆ NaN │ +└──────────────────────┴──────────────────┴─────────────┴───────────────┘ +>>> import pprint +>>> pp = pprint.PrettyPrinter(width=80, compact=True) +>>> for k, df in dfs.items(): +... df = df.join(metadata_df, on="code").select( +... "code/vocab_index", +... (pl.col("numeric_value") - pl.col("values/mean")) / pl.col("values/stddev") +... ) +... print("/".join(k.split("/")[1:])) +... 
pp.pprint(df.to_dict(as_series=False)) +train/0 +{'code/vocab_index': [6, 7, 10, 4, 11, 2, 1, 8, 9, 11, 2, 8, 9, 12, 2, 8, 9, 12, + 2, 8, 9, 12, 2, 3, 5, 7, 10, 4, 12, 2, 1, 8, 9, 12, 2, 8, + 9, 12, 2, 8, 9, 12, 2, 8, 9, 12, 2, 8, 9, 12, 2, 8, 9, 12, + 2, 3], + 'numeric_value': [None, None, None, None, None, None, None, None, None, None, + None, None, None, None, None, 0.9341503977775574, None, None, + None, 0.6264293789863586, None, None, None, None, None, None, + None, None, None, nan, None, -0.7583094239234924, + -0.0889078751206398, None, nan, 1.2034040689468384, + -0.0889078751206398, None, nan, None, -0.6222330927848816, + None, nan, 0.5879650115966797, -1.1555582284927368, None, + nan, -1.2583553791046143, -0.0889078751206398, None, nan, + -1.3352841138839722, 2.04443359375, None, nan, None]} +train/1 +{'code/vocab_index': [], 'numeric_value': []} +tuning/0 +{'code/vocab_index': [], 'numeric_value': []} +held_out/0 +{'code/vocab_index': [6, 7, 10, 4, 11, 2, 8, 9, 11, 2, 8, 9, 11, 2, 8, 9, 11, 2, + 3], + 'numeric_value': [None, None, None, None, None, None, None, + -0.0889078751206398, None, None, None, 1.5111083984375, None, + None, None, 0.4444173276424408, None, None, None]} + +``` +""" + +# Note we have dropped the row in the held out shard that doesn't have a code in the vocabulary! +WANT_NORMALIZATION = parse_shards_yaml( + """ + "normalization/train/0": |-2 + patient_id,time,code,numeric_value + 239684,,6, + 239684,,7, + 239684,"12/28/1980, 00:00:00",10, + 239684,"12/28/1980, 00:00:00",4, + 239684,"05/11/2010, 17:41:51",11, + 239684,"05/11/2010, 17:41:51",2, + 239684,"05/11/2010, 17:41:51",1, + 239684,"05/11/2010, 17:41:51",8, + 239684,"05/11/2010, 17:41:51",9, + 239684,"05/11/2010, 17:48:48",11, + 239684,"05/11/2010, 17:48:48",2, + 239684,"05/11/2010, 17:48:48",8, + 239684,"05/11/2010, 17:48:48",9, + 239684,"05/11/2010, 18:25:35",12, + 239684,"05/11/2010, 18:25:35",2, + 239684,"05/11/2010, 18:25:35",8,0.9341503977775574 + 239684,"05/11/2010, 18:25:35",9, + 239684,"05/11/2010, 18:57:18",12, + 239684,"05/11/2010, 18:57:18",2, + 239684,"05/11/2010, 18:57:18",8,0.6264293789863586 + 239684,"05/11/2010, 18:57:18",9, + 239684,"05/11/2010, 19:27:19",12, + 239684,"05/11/2010, 19:27:19",2, + 239684,"05/11/2010, 19:27:19",3, + 1195293,,5, + 1195293,,7, + 1195293,"06/20/1978, 00:00:00",10, + 1195293,"06/20/1978, 00:00:00",4, + 1195293,"06/20/2010, 19:23:52",12, + 1195293,"06/20/2010, 19:23:52",2,nan + 1195293,"06/20/2010, 19:23:52",1, + 1195293,"06/20/2010, 19:23:52",8,-0.7583094239234924 + 1195293,"06/20/2010, 19:23:52",9,-0.0889078751206398 + 1195293,"06/20/2010, 19:25:32",12, + 1195293,"06/20/2010, 19:25:32",2,nan + 1195293,"06/20/2010, 19:25:32",8,1.2034040689468384 + 1195293,"06/20/2010, 19:25:32",9,-0.0889078751206398 + 1195293,"06/20/2010, 19:45:19",12, + 1195293,"06/20/2010, 19:45:19",2,nan + 1195293,"06/20/2010, 19:45:19",8, + 1195293,"06/20/2010, 19:45:19",9,-0.6222330927848816 + 1195293,"06/20/2010, 20:12:31",12, + 1195293,"06/20/2010, 20:12:31",2,nan + 1195293,"06/20/2010, 20:12:31",8,0.5879650115966797 + 1195293,"06/20/2010, 20:12:31",9,-1.1555582284927368 + 1195293,"06/20/2010, 20:24:44",12 + 1195293,"06/20/2010, 20:24:44",2,nan + 1195293,"06/20/2010, 20:24:44",8,-1.2583553791046143 + 1195293,"06/20/2010, 20:24:44",9,-0.0889078751206398 + 1195293,"06/20/2010, 20:41:33",12, + 1195293,"06/20/2010, 20:41:33",2,nan + 1195293,"06/20/2010, 20:41:33",8,-1.3352841138839722 + 1195293,"06/20/2010, 20:41:33",9,2.04443359375 + 1195293,"06/20/2010, 20:50:04",12, + 1195293,"06/20/2010, 
20:50:04",2,nan + 1195293,"06/20/2010, 20:50:04",3, + + "normalization/train/1": |-2 + patient_id,time,code,numeric_value + + "normalization/tuning/0": |-2 + patient_id,time,code,numeric_value + + "normalization/held_out/0": |-2 + patient_id,time,code,numeric_value + 1500733,,6, + 1500733,,7, + 1500733,"07/20/1986, 00:00:00",10, + 1500733,"07/20/1986, 00:00:00",4, + 1500733,"06/03/2010, 14:54:38",11, + 1500733,"06/03/2010, 14:54:38",2, + 1500733,"06/03/2010, 14:54:38",8, + 1500733,"06/03/2010, 14:54:38",9,-0.0889078751206398 + 1500733,"06/03/2010, 15:39:49",11, + 1500733,"06/03/2010, 15:39:49",2, + 1500733,"06/03/2010, 15:39:49",8, + 1500733,"06/03/2010, 15:39:49",9,1.5111083984375 + 1500733,"06/03/2010, 16:20:49",11, + 1500733,"06/03/2010, 16:20:49",2, + 1500733,"06/03/2010, 16:20:49",8, + 1500733,"06/03/2010, 16:20:49",9,0.4444173276424408 + 1500733,"06/03/2010, 16:44:26",11, + 1500733,"06/03/2010, 16:44:26",2, + 1500733,"06/03/2010, 16:44:26",3, + """, + code=pl.UInt8, +) + +TOKENIZATION_SCHEMA_DF_SCHEMA = { + "patient_id": pl.UInt32, + "code": pl.List(pl.UInt8), + "numeric_value": pl.List(pl.Float32), + "start_time": pl.Datetime("us"), + "time": pl.List(pl.Datetime("us")), +} +WANT_TOKENIZATION_SCHEMAS = { + "tokenization/schemas/train/0": pl.DataFrame( + { + "patient_id": [239684, 1195293], + "code": [[6, 7], [5, 7]], + "numeric_value": [[None, None], [None, None]], + "start_time": [datetime(1980, 12, 28), datetime(1978, 6, 20)], + "time": [ + [ + datetime(1980, 12, 28, 0, 0, 0), + datetime(2010, 5, 11, 17, 41, 51), + datetime(2010, 5, 11, 17, 48, 48), + datetime(2010, 5, 11, 18, 25, 35), + datetime(2010, 5, 11, 18, 57, 18), + datetime(2010, 5, 11, 19, 27, 19), + ], + [ + datetime(1978, 6, 20, 0, 0, 0), + datetime(2010, 6, 20, 19, 23, 52), + datetime(2010, 6, 20, 19, 25, 32), + datetime(2010, 6, 20, 19, 45, 19), + datetime(2010, 6, 20, 20, 12, 31), + datetime(2010, 6, 20, 20, 24, 44), + datetime(2010, 6, 20, 20, 41, 33), + datetime(2010, 6, 20, 20, 50, 4), + ], + ], + }, + schema=TOKENIZATION_SCHEMA_DF_SCHEMA, + ), + "tokenization/schemas/train/1": pl.DataFrame( + {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + schema=TOKENIZATION_SCHEMA_DF_SCHEMA, + ), + "tokenization/schemas/tuning/0": pl.DataFrame( + {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + schema=TOKENIZATION_SCHEMA_DF_SCHEMA, + ), + "tokenization/schemas/held_out/0": pl.DataFrame( + { + "patient_id": [1500733], + "code": [[6, 7]], + "numeric_value": [[None, None]], + "start_time": [datetime(1986, 7, 20)], + "time": [ + [ + datetime(1986, 7, 20, 0, 0, 0), + datetime(2010, 6, 3, 14, 54, 38), + datetime(2010, 6, 3, 15, 39, 49), + datetime(2010, 6, 3, 16, 20, 49), + datetime(2010, 6, 3, 16, 44, 26), + ] + ], + }, + schema=TOKENIZATION_SCHEMA_DF_SCHEMA, + ), +} + +TOKENIZATION_CODE = """ +```python + +>>> import polars as pl +>>> from tests.test_multi_stage_preprocess_pipeline import WANT_NORMALIZATION as dfs +>>> + +``` +""" + +TOKENIZATION_EVENT_SEQS_DF_SCHEMA = { + "patient_id": pl.UInt32, + "code": pl.List(pl.List(pl.UInt8)), + "numeric_value": pl.List(pl.List(pl.Float32)), + "time_delta_days": pl.List(pl.Float32), +} + +WANT_TOKENIZATION_EVENT_SEQS = { + "tokenization/event_seqs/train/0": pl.DataFrame( + { + "patient_id": [239684, 1195293], + "code": [ + [[10, 4], [11, 2, 1, 8, 9], [11, 2, 8, 9], [12, 2, 8, 9], [12, 2, 8, 9], [12, 2, 3]], + [ + [10, 4], + [12, 2, 1, 8, 9], + [12, 2, 8, 9], + [12, 2, 8, 9], + [12, 2, 8, 9], + [12, 2, 8, 9], + [12, 2, 8, 9], 
+ [12, 2, 3], + ], + ], + "numeric_value": [ + [ + [float("nan"), float("nan")], + [float("nan"), float("nan"), float("nan"), float("nan"), float("nan")], + [float("nan"), float("nan"), float("nan"), float("nan")], + [float("nan"), float("nan"), 0.9341503977775574, float("nan")], + [float("nan"), float("nan"), 0.6264293789863586, float("nan")], + [float("nan"), float("nan"), float("nan")], + ], + [ + [float("nan"), float("nan")], + [float("nan"), float("nan"), float("nan"), -0.7583094239234924, -0.0889078751206398], + [float("nan"), float("nan"), 1.2034040689468384, -0.0889078751206398], + [float("nan"), float("nan"), float("nan"), -0.6222330927848816], + [float("nan"), float("nan"), 0.5879650115966797, -1.1555582284927368], + [float("nan"), float("nan"), -1.2583553791046143, -0.0889078751206398], + [float("nan"), float("nan"), -1.3352841138839722, 2.04443359375], + [float("nan"), float("nan"), float("nan")], + ], + ], + "time_delta_days": ( + WANT_TOKENIZATION_SCHEMAS["tokenization/schemas/train/0"] + .select( + pl.col("time") + .list.diff() + .list.eval((pl.element().dt.total_seconds() / 86400).fill_null(float("nan"))) + )["time"] + .to_list() + ), + }, + schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, + ), + "tokenization/event_seqs/train/1": pl.DataFrame( + {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, + ), + "tokenization/event_seqs/tuning/0": pl.DataFrame( + {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, + ), + "tokenization/event_seqs/held_out/0": pl.DataFrame( + { + "patient_id": [1500733], + "code": [ + [ + [10, 4], + [11, 2, 8, 9], + [11, 2, 8, 9], + [11, 2, 8, 9], + [11, 2, 3], + ] + ], + "numeric_value": [ + [ + [float("nan"), float("nan")], + [float("nan"), float("nan"), float("nan"), -0.0889078751206398], + [float("nan"), float("nan"), float("nan"), 1.5111083984375], + [float("nan"), float("nan"), float("nan"), 0.4444173276424408], + [float("nan"), float("nan"), float("nan")], + ] + ], + "time_delta_days": ( + WANT_TOKENIZATION_SCHEMAS["tokenization/schemas/held_out/0"] + .select( + pl.col("time") + .list.diff() + .list.eval((pl.element().dt.total_seconds() / 86400).fill_null(float("nan"))) + )["time"] + .to_list() + ), + }, + schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, + ), +} + + +WANT_NRTs = { + "data/train/0.nrt": JointNestedRaggedTensorDict( + WANT_TOKENIZATION_EVENT_SEQS["tokenization/event_seqs/train/0"] + .select("time_delta_days", "code", "numeric_value") + .to_dict(as_series=False) + ), + "data/train/1.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out. + "data/tuning/0.nrt": JointNestedRaggedTensorDict({}), # this shard was fully filtered out. 
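+    # NOTE: fully empty shards are expected to tensorize to an empty JointNestedRaggedTensorDict,
+    # matching the empty-input handling added to convert_to_NRT above (which logs a warning and
+    # returns an empty tensor dict rather than erroring).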
+ "data/held_out/0.nrt": JointNestedRaggedTensorDict( + WANT_TOKENIZATION_EVENT_SEQS["tokenization/event_seqs/held_out/0"] + .select("time_delta_days", "code", "numeric_value") + .to_dict(as_series=False) + ), +} + + +def test_pipeline(): + multi_stage_transform_tester( + transform_scripts=[ + FILTER_PATIENTS_SCRIPT, + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + OCCLUDE_OUTLIERS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + FIT_VOCABULARY_INDICES_SCRIPT, + NORMALIZATION_SCRIPT, + TOKENIZATION_SCRIPT, + TENSORIZATION_SCRIPT, + ], + stage_names=[ + "filter_patients", + "add_time_derived_measurements", + "fit_outlier_detection", + "occlude_outliers", + "fit_normalization", + "fit_vocabulary_indices", + "normalization", + "tokenization", + "tensorization", + ], + stage_configs=STAGE_CONFIG_YAML, + want_metadata={ + **WANT_FIT_OUTLIERS, + **WANT_FIT_NORMALIZATION, + **WANT_FIT_VOCABULARY_INDICES, + }, + want_data={ + **WANT_FILTER, + **WANT_TIME_DERIVED, + **WANT_OCCLUDE_OUTLIERS, + **WANT_NORMALIZATION, + **WANT_TOKENIZATION_SCHEMAS, + **WANT_TOKENIZATION_EVENT_SEQS, + **WANT_NRTs, + }, + outputs_from_cohort_dir=True, + input_code_metadata=MEDS_CODE_METADATA, + ) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 12332e2..46992ed 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -224,6 +224,6 @@ def test_normalization(): transform_script=NORMALIZATION_SCRIPT, stage_name="normalization", transform_stage_kwargs=None, - code_metadata=MEDS_CODE_METADATA_CSV, - want_outputs=WANT_SHARDS, + input_code_metadata=MEDS_CODE_METADATA_CSV, + want_data=WANT_SHARDS, ) diff --git a/tests/test_occlude_outliers.py b/tests/test_occlude_outliers.py index 2060bb7..63e9376 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/test_occlude_outliers.py @@ -171,5 +171,5 @@ def test_occlude_outliers(): transform_script=OCCLUDE_OUTLIERS_SCRIPT, stage_name="occlude_outliers", transform_stage_kwargs={"stddev_cutoff": 1}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, ) diff --git a/tests/test_reorder_measurements.py b/tests/test_reorder_measurements.py index 305b85b..c90dee4 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/test_reorder_measurements.py @@ -111,5 +111,5 @@ def test_reorder_measurements(): transform_script=REORDER_MEASUREMENTS_SCRIPT, stage_name="reorder_measurements", transform_stage_kwargs={"ordered_code_patterns": ORDERED_CODE_PATTERNS}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, ) diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 3eeecfe..65056e5 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -195,7 +195,7 @@ def test_reshard_to_split(): transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", transform_stage_kwargs={"n_patients_per_shard": 2}, - want_outputs=WANT_SHARDS, + want_data=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, input_splits_map=SPLITS, diff --git a/tests/test_tensorization.py b/tests/test_tensorization.py index 789dcc1..0337155 100644 --- a/tests/test_tensorization.py +++ b/tests/test_tensorization.py @@ -13,7 +13,7 @@ from .transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester WANT_NRTS = { - k.replace("event_seqs/", ""): JointNestedRaggedTensorDict( + f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( v.select("time_delta_days", "code", "numeric_value").to_dict(as_series=False) ) for k, v in TOKENIZED_SHARDS.items() @@ -26,6 
+26,5 @@ def test_tensorization(): stage_name="tensorization", transform_stage_kwargs=None, input_shards=TOKENIZED_SHARDS, - want_outputs=WANT_NRTS, - file_suffix=".nrt", + want_data=WANT_NRTS, ) diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 58db843..693add1 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -39,9 +39,9 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: SEQ_SCHEMA = { "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], - "code": pl.List(pl.List(pl.Float64)), + "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"])), - "time_delta_days": pl.List(pl.Float64), + "time_delta_days": pl.List(pl.Float32), } TRAIN_0_TIMES = [ @@ -223,5 +223,5 @@ def test_tokenization(): stage_name="tokenization", transform_stage_kwargs=None, input_shards=NORMALIZED_SHARDS, - want_outputs={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, ) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 8dbd637..6845e0c 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -4,10 +4,18 @@ scripts. """ +from yaml import load as load_yaml + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + import json import os import tempfile from collections import defaultdict +from contextlib import contextmanager from io import StringIO from pathlib import Path @@ -16,7 +24,7 @@ import rootutils from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .utils import assert_df_equal, parse_meds_csvs, run_command +from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) @@ -195,6 +203,11 @@ } +def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: + schema = {**MEDS_PL_SCHEMA, **schema_updates} + return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) + + def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: cols = csv_str.strip().split("\n")[0].split(",") schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols} @@ -210,8 +223,6 @@ def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: def check_NRT_output( output_fp: Path, want_nrt: JointNestedRaggedTensorDict, - stderr: str, - stdout: str, ): assert output_fp.is_file(), f"Expected {output_fp} to exist." 
@@ -219,8 +230,6 @@ def check_NRT_output( # assert got_nrt.schema == want_nrt.schema, ( # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" - # f"Script stdout:\n{stdout}\n" - # f"Script stderr:\n{stderr}\n" # f"Wanted:\n{want_nrt.schema}\n" # f"Got:\n{got_nrt.schema}" # ) @@ -230,8 +239,6 @@ def check_NRT_output( assert got_tensors.keys() == want_tensors.keys(), ( f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}\n" f"Wanted:\n{list(want_tensors.keys())}\n" f"Got:\n{list(got_tensors.keys())}" ) @@ -242,8 +249,6 @@ def check_NRT_output( assert type(want_v) is type(got_v), ( f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}\n" f"Wanted:\n{type(want_v)}\n" f"Got:\n{type(got_v)}" ) @@ -251,24 +256,18 @@ def check_NRT_output( if isinstance(want_v, list): assert len(want_v) == len(got_v), ( f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}\n" f"Wanted:\n{len(want_v)}\n" f"Got:\n{len(got_v)}" ) for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): assert np.array_equal(want_i, got_i, equal_nan=True), ( f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}\n" f"Wanted:\n{want_i}\n" f"Got:\n{got_i}" ) else: assert np.array_equal(want_v, got_v, equal_nan=True), ( f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}\n" f"Wanted:\n{want_v}\n" f"Got:\n{got_v}" ) @@ -277,8 +276,6 @@ def check_NRT_output( def check_df_output( output_fp: Path, want_df: pl.DataFrame, - stderr: str, - stdout: str, check_column_order: bool = False, check_row_order: bool = True, **kwargs, @@ -289,30 +286,19 @@ def check_df_output( assert_df_equal( want_df, got_df, - ( - f"Expected the dataframe at {output_fp} to be equal to the target.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" - ), + (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), check_column_order=check_column_order, check_row_order=check_row_order, **kwargs, ) -def single_stage_transform_tester( - transform_script: str | Path, - stage_name: str, - transform_stage_kwargs: dict[str, str] | None, - want_outputs: pl.DataFrame | dict[str, pl.DataFrame], - code_metadata: pl.DataFrame | str | None = None, +@contextmanager +def input_MEDS_dataset( + input_code_metadata: pl.DataFrame | str | None = None, input_shards: dict[str, pl.DataFrame] | None = None, - do_pass_stage_name: bool = False, - file_suffix: str = ".parquet", - do_use_config_yaml: bool = False, input_shards_map: dict[str, list[int]] | None = None, input_splits_map: dict[str, list[int]] | None = None, - assert_no_other_outputs: bool = True, ): with tempfile.TemporaryDirectory() as d: MEDS_dir = Path(d) / "MEDS_cohort" @@ -320,7 +306,6 @@ def single_stage_transform_tester( MEDS_data_dir = MEDS_dir / "data" MEDS_metadata_dir = MEDS_dir / "metadata" - cohort_metadata_dir = cohort_dir / "metadata" # Create the directories MEDS_data_dir.mkdir(parents=True) @@ -355,12 +340,71 @@ def single_stage_transform_tester( df.write_parquet(fp, use_pyarrow=True) code_metadata_fp = MEDS_metadata_dir / "codes.parquet" - if code_metadata is None: - code_metadata = MEDS_CODE_METADATA - elif 
isinstance(code_metadata, str): - code_metadata = parse_code_metadata_csv(code_metadata) - code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True) + if input_code_metadata is None: + input_code_metadata = MEDS_CODE_METADATA + elif isinstance(input_code_metadata, str): + input_code_metadata = parse_code_metadata_csv(input_code_metadata) + input_code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True) + + yield MEDS_dir, cohort_dir + +def check_outputs( + cohort_dir: Path, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: dict[str, pl.DataFrame] | pl.DataFrame | None = None, + assert_no_other_outputs: bool = True, + outputs_from_cohort_dir: bool = False, +): + if want_metadata is not None: + if isinstance(want_metadata, pl.DataFrame): + want_metadata = {"codes.parquet": want_metadata} + metadata_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "metadata" + for shard_name, want in want_metadata.items(): + if Path(shard_name).suffix == "": + shard_name = f"{shard_name}.parquet" + check_df_output(metadata_root / shard_name, want) + + if want_data: + data_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "data" + all_file_suffixes = set() + for shard_name, want in want_data.items(): + if Path(shard_name).suffix == "": + shard_name = f"{shard_name}.parquet" + + file_suffix = Path(shard_name).suffix + all_file_suffixes.add(file_suffix) + + output_fp = data_root / f"{shard_name}" + if file_suffix == ".parquet": + check_df_output(output_fp, want) + elif file_suffix == ".nrt": + check_NRT_output(output_fp, want) + else: + raise ValueError(f"Unknown file suffix: {file_suffix}") + + if assert_no_other_outputs: + all_outputs = [] + for suffix in all_file_suffixes: + all_outputs.extend(list((data_root).glob(f"**/*{suffix}"))) + assert len(want_data) == len(all_outputs), ( + f"Want {len(want_data)} outputs, but found {len(all_outputs)}.\n" + f"Found outputs: {[fp.relative_to(data_root) for fp in all_outputs]}\n" + ) + + +def single_stage_transform_tester( + transform_script: str | Path, + stage_name: str, + transform_stage_kwargs: dict[str, str] | None, + do_pass_stage_name: bool = False, + do_use_config_yaml: bool = False, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: pl.DataFrame | None = None, + assert_no_other_outputs: bool = True, + **input_data_kwargs, +): + with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): pipeline_config_kwargs = { "input_dir": str(MEDS_dir.resolve()), "cohort_dir": str(cohort_dir.resolve()), @@ -368,8 +412,6 @@ def single_stage_transform_tester( "hydra.verbose": True, } - if do_pass_stage_name: - pipeline_config_kwargs["stage"] = stage_name if transform_stage_kwargs: pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs} @@ -381,29 +423,79 @@ def single_stage_transform_tester( if do_use_config_yaml: run_command_kwargs["do_use_config_yaml"] = True run_command_kwargs["config_name"] = "preprocess" + if do_pass_stage_name: + run_command_kwargs["stage"] = stage_name + run_command_kwargs["do_pass_stage_name"] = True # Run the transform stderr, stdout = run_command(**run_command_kwargs) - # Check the output - if isinstance(want_outputs, pl.DataFrame): - # The want output is a code_metadata file in the root directory in this case. 
- check_df_output(cohort_metadata_dir / "codes.parquet", want_outputs, stderr, stdout) - else: - for shard_name, want in want_outputs.items(): - output_fp = cohort_dir / "data" / f"{shard_name}{file_suffix}" - if file_suffix == ".parquet": - check_df_output(output_fp, want, stderr, stdout) - elif file_suffix == ".nrt": - check_NRT_output(output_fp, want, stderr, stdout) - else: - raise ValueError(f"Unknown file suffix: {file_suffix}") - - if assert_no_other_outputs: - all_outputs = list((cohort_dir / "data").glob(f"**/*{file_suffix}")) - assert len(want_outputs) == len(all_outputs), ( - f"Expected {len(want_outputs)} outputs, but found {len(all_outputs)}.\n" - f"Found outputs: {[fp.relative_to(cohort_dir/'data') for fp in all_outputs]}\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" - ) + try: + check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) + except Exception as e: + raise AssertionError( + f"Single stage transform {stage_name} failed.\n" + f"Script stdout:\n{stdout}\n" + f"Script stderr:\n{stderr}" + ) from e + + +def multi_stage_transform_tester( + transform_scripts: list[str | Path], + stage_names: list[str], + stage_configs: dict[str, str] | str | None, + do_pass_stage_name: bool | dict[str, bool] = True, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: pl.DataFrame | None = None, + outputs_from_cohort_dir: bool = True, + **input_data_kwargs, +): + with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): + match stage_configs: + case None: + stage_configs = {} + case str(): + stage_configs = load_yaml(stage_configs, Loader=Loader) + case dict(): + pass + case _: + raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") + + match do_pass_stage_name: + case True: + do_pass_stage_name = {stage_name: True for stage_name in stage_names} + case False: + do_pass_stage_name = {stage_name: False for stage_name in stage_names} + case dict(): + pass + case _: + raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") + + pipeline_config_kwargs = { + "input_dir": str(MEDS_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": stage_names, + "stage_configs": stage_configs, + "hydra.verbose": True, + } + + script_outputs = {} + n_stages = len(stage_names) + for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)): + script_outputs[stage] = run_command( + script=script, + hydra_kwargs=pipeline_config_kwargs, + do_use_config_yaml=True, + config_name="preprocess", + test_name=f"Multi stage transform {i}/{n_stages}: {stage}", + stage_name=stage, + do_pass_stage_name=do_pass_stage_name[stage], + ) + + check_outputs( + cohort_dir, + want_data=want_data, + want_metadata=want_metadata, + outputs_from_cohort_dir=outputs_from_cohort_dir, + assert_no_other_outputs=False, # this currently doesn't work due to metadata / data confusions. + ) diff --git a/tests/utils.py b/tests/utils.py index 5fca87e..e6c9d3f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ "time": pl.Datetime("us"), "code": pl.Utf8, "numeric_value": pl.Float32, + "numeric_value/is_inlier": pl.Boolean, } @@ -26,10 +27,12 @@ def parse_meds_csvs( TODO: doctests. 
""" - read_schema = {**schema} - read_schema["time"] = pl.Utf8 + default_read_schema = {**schema} + default_read_schema["time"] = pl.Utf8 def reader(csv_str: str) -> pl.DataFrame: + cols = csv_str.strip().split("\n")[0].split(",") + read_schema = {k: v for k, v in default_read_schema.items() if k in cols} return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns( pl.col("time").str.strptime(MEDS_PL_SCHEMA["time"], DEFAULT_CSV_TS_FORMAT) ) @@ -105,6 +108,8 @@ def run_command( config_name: str | None = None, should_error: bool = False, do_use_config_yaml: bool = False, + stage_name: str | None = None, + do_pass_stage_name: bool = False, ): script = ["python", str(script.resolve())] if isinstance(script, Path) else [script] command_parts = script @@ -139,6 +144,11 @@ def run_command( command_parts.append(f"--config-name={config_name}") command_parts.append(" ".join(dict_to_hydra_kwargs(hydra_kwargs))) + if do_pass_stage_name: + if stage_name is None: + raise ValueError("stage_name must be provided if do_pass_stage_name is True.") + command_parts.append(f"stage={stage_name}") + full_cmd = " ".join(command_parts) err_cmd_lines.append(f"Running command: {full_cmd}") command_out = subprocess.run(full_cmd, shell=True, capture_output=True)