diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/test_multi_stage_preprocess_pipeline.py new file mode 100644 index 0000000..f752f9b --- /dev/null +++ b/tests/test_multi_stage_preprocess_pipeline.py @@ -0,0 +1,87 @@ +"""Tests a multi-stage pre-processing pipeline. Only checks the end result, not the intermediate files. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. + +In this test, the following stages are run: + - filter_patients + - add_time_derived_measurements + - fit_outlier_detection + - occlude_outliers + - fit_normalization + - fit_vocabulary_indices + - normalization + - tokenization + - tensorization + +The stage configuration arguments will be as given in the yaml block below: +""" + + +from .transform_tester_base import ( + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + FILTER_PATIENTS_SCRIPT, + FIT_VOCABULARY_INDICES_SCRIPT, + NORMALIZATION_SCRIPT, + OCCLUDE_OUTLIERS_SCRIPT, + TENSORIZATION_SCRIPT, + TOKENIZATION_SCRIPT, + multi_stage_transform_tester, +) + +STAGE_CONFIG_YAML = """ +filter_patients: + min_events_per_patient: 5 +add_time_derived_measurements: + age: + DOB_code: "DOB" + age_code: "AGE" + age_unit: "years" +fit_outlier_detection: + aggregations: + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" +occlude_outliers: + stddev_cutoff: 1 +fit_normalization: + aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" +fit_vocabulary_indices: + is_metadata: true + ordering_method: "lexicographic" + output_dir: "${cohort_dir}" +""" + + +def test_pipeline(): + multi_stage_transform_tester( + transform_scripts=[ + FILTER_PATIENTS_SCRIPT, + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + OCCLUDE_OUTLIERS_SCRIPT, + AGGREGATE_CODE_METADATA_SCRIPT, + FIT_VOCABULARY_INDICES_SCRIPT, + NORMALIZATION_SCRIPT, + TOKENIZATION_SCRIPT, + TENSORIZATION_SCRIPT, + ], + stage_names=[ + "filter_patients", + "add_time_derived_measurements", + "fit_outlier_detection", + "occlude_outliers", + "fit_normalization", + "fit_vocabulary_indices", + "normalization", + "tokenization", + "tensorization", + ], + stage_configs=STAGE_CONFIG_YAML, + ) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 7012020..afebf4e 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -4,6 +4,12 @@ scripts. """ +from yaml import load as load_yaml + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader import json import os import tempfile @@ -274,9 +280,7 @@ def check_df_output( assert_df_equal( want_df, got_df, - ( - f"Expected the dataframe at {output_fp} to be equal to the target.\n" - ), + (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), check_column_order=check_column_order, check_row_order=check_row_order, **kwargs, @@ -338,6 +342,7 @@ def input_MEDS_dataset( yield MEDS_dir, cohort_dir + def check_outputs( cohort_dir: Path, want_data: dict[str, pl.DataFrame] | None = None, @@ -416,3 +421,57 @@ def single_stage_transform_tester( f"Script stdout:\n{stdout}\n" f"Script stderr:\n{stderr}" ) from e + + +def multi_stage_transform_tester( + transform_scripts: list[str | Path], + stage_names: list[str], + stage_configs: dict[str, str] | str | None, + do_pass_stage_name: bool | dict[str, bool] = True, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: pl.DataFrame | None = None, + **input_data_kwargs, +): + with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): + match stage_configs: + case None: + stage_configs = {} + case str(): + stage_configs = load_yaml(stage_configs, Loader=Loader) + case dict(): + pass + case _: + raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") + + match do_pass_stage_name: + case True: + do_pass_stage_name = {stage_name: True for stage_name in stage_names} + case False: + do_pass_stage_name = {stage_name: False for stage_name in stage_names} + case dict(): + pass + case _: + raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") + + pipeline_config_kwargs = { + "input_dir": str(MEDS_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": stage_names, + "stage_configs": stage_configs, + "hydra.verbose": True, + } + + script_outputs = {} + n_stages = len(stage_names) + for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)): + script_outputs[stage] = run_command( + script=script, + hydra_kwargs=pipeline_config_kwargs, + do_use_config_yaml=True, + config_name="preprocess", + test_name=f"Multi stage transform {i}/{n_stages}: {stage}", + stage_name=stage, + do_pass_stage_name=do_pass_stage_name[stage], + ) + + check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata)