Skip to content

Commit

Permalink
Added a multi-stage test which currently, appropriately, fails due to #…
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 14, 2024
1 parent d5b0782 commit 138eb1e
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 3 deletions.
87 changes: 87 additions & 0 deletions tests/test_multi_stage_preprocess_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Tests a multi-stage pre-processing pipeline. Only checks the end result, not the intermediate files.
Set the environment variable `DO_USE_LOCAL_SCRIPTS=1` to run the local Python files rather than the
installed scripts.
In this test, the following stages are run:
- filter_patients
- add_time_derived_measurements
- fit_outlier_detection
- occlude_outliers
- fit_normalization
- fit_vocabulary_indices
- normalization
- tokenization
- tensorization
The stage configuration arguments will be as given in the yaml block below:
"""


from .transform_tester_base import (
ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT,
AGGREGATE_CODE_METADATA_SCRIPT,
FILTER_PATIENTS_SCRIPT,
FIT_VOCABULARY_INDICES_SCRIPT,
NORMALIZATION_SCRIPT,
OCCLUDE_OUTLIERS_SCRIPT,
TENSORIZATION_SCRIPT,
TOKENIZATION_SCRIPT,
multi_stage_transform_tester,
)

# Per-stage configuration, keyed by stage name. Each stage's options are nested under
# its own key so that the two aggregation stages (`fit_outlier_detection` and
# `fit_normalization`) keep separate `aggregations` lists instead of colliding at the
# top level.
STAGE_CONFIG_YAML = """
filter_patients:
  min_events_per_patient: 5
add_time_derived_measurements:
  age:
    DOB_code: "DOB"
    age_code: "AGE"
    age_unit: "years"
fit_outlier_detection:
  aggregations:
    - "values/n_occurrences"
    - "values/sum"
    - "values/sum_sqd"
occlude_outliers:
  stddev_cutoff: 1
fit_normalization:
  aggregations:
    - "code/n_occurrences"
    - "code/n_patients"
    - "values/n_occurrences"
    - "values/sum"
    - "values/sum_sqd"
fit_vocabulary_indices:
  is_metadata: true
  ordering_method: "lexicographic"
  output_dir: "${cohort_dir}"
"""


def test_pipeline():
    """Run the full multi-stage pre-processing pipeline end to end.

    The pipeline is described as an ordered list of ``(stage name, script)`` pairs so
    that each stage sits next to the script that implements it. Both aggregation
    stages (``fit_outlier_detection`` and ``fit_normalization``) are served by the
    same aggregate-code-metadata script.
    """
    pipeline_plan = [
        ("filter_patients", FILTER_PATIENTS_SCRIPT),
        ("add_time_derived_measurements", ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT),
        ("fit_outlier_detection", AGGREGATE_CODE_METADATA_SCRIPT),
        ("occlude_outliers", OCCLUDE_OUTLIERS_SCRIPT),
        ("fit_normalization", AGGREGATE_CODE_METADATA_SCRIPT),
        ("fit_vocabulary_indices", FIT_VOCABULARY_INDICES_SCRIPT),
        ("normalization", NORMALIZATION_SCRIPT),
        ("tokenization", TOKENIZATION_SCRIPT),
        ("tensorization", TENSORIZATION_SCRIPT),
    ]

    # Split the plan back into the two parallel lists the tester expects.
    ordered_stage_names = [stage_name for stage_name, _ in pipeline_plan]
    ordered_scripts = [script for _, script in pipeline_plan]

    multi_stage_transform_tester(
        transform_scripts=ordered_scripts,
        stage_names=ordered_stage_names,
        stage_configs=STAGE_CONFIG_YAML,
    )
65 changes: 62 additions & 3 deletions tests/transform_tester_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
scripts.
"""

from yaml import load as load_yaml

try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
import json
import os
import tempfile
Expand Down Expand Up @@ -274,9 +280,7 @@ def check_df_output(
assert_df_equal(
want_df,
got_df,
(
f"Expected the dataframe at {output_fp} to be equal to the target.\n"
),
(f"Expected the dataframe at {output_fp} to be equal to the target.\n"),
check_column_order=check_column_order,
check_row_order=check_row_order,
**kwargs,
Expand Down Expand Up @@ -338,6 +342,7 @@ def input_MEDS_dataset(

yield MEDS_dir, cohort_dir


def check_outputs(
cohort_dir: Path,
want_data: dict[str, pl.DataFrame] | None = None,
Expand Down Expand Up @@ -416,3 +421,57 @@ def single_stage_transform_tester(
f"Script stdout:\n{stdout}\n"
f"Script stderr:\n{stderr}"
) from e


def multi_stage_transform_tester(
transform_scripts: list[str | Path],
stage_names: list[str],
stage_configs: dict[str, str] | str | None,
do_pass_stage_name: bool | dict[str, bool] = True,
want_data: dict[str, pl.DataFrame] | None = None,
want_metadata: pl.DataFrame | None = None,
**input_data_kwargs,
):
with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir):
match stage_configs:
case None:
stage_configs = {}
case str():
stage_configs = load_yaml(stage_configs, Loader=Loader)
case dict():
pass
case _:
raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}")

match do_pass_stage_name:
case True:
do_pass_stage_name = {stage_name: True for stage_name in stage_names}
case False:
do_pass_stage_name = {stage_name: False for stage_name in stage_names}
case dict():
pass
case _:
raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}")

pipeline_config_kwargs = {
"input_dir": str(MEDS_dir.resolve()),
"cohort_dir": str(cohort_dir.resolve()),
"stages": stage_names,
"stage_configs": stage_configs,
"hydra.verbose": True,
}

script_outputs = {}
n_stages = len(stage_names)
for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)):
script_outputs[stage] = run_command(
script=script,
hydra_kwargs=pipeline_config_kwargs,
do_use_config_yaml=True,
config_name="preprocess",
test_name=f"Multi stage transform {i}/{n_stages}: {stage}",
stage_name=stage,
do_pass_stage_name=do_pass_stage_name[stage],
)

check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata)

0 comments on commit 138eb1e

Please sign in to comment.