Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the ability to specialize existing transformations to be applied with separate configurations to different portions of the data based on configurable filters. #119

Merged
merged 14 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,4 @@ repos:
- id: nbqa-isort
args: ["--profile=black"]
- id: nbqa-flake8
args:
[
"--extend-ignore=E203,E402,E501,F401,F841",
"--exclude=logs/*,data/*",
]
args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"]
3 changes: 1 addition & 2 deletions eICU_Example/configs/table_preprocessors.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ admissiondx:
offset_col: "admitdxenteredoffset"
pseudotime_col: "admitDxEnteredTimestamp"
output_data_cols: ["admitdxname", "admitdxid"]
warning_items:
["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"]
warning_items: ["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"]

allergy:
offset_col: "allergyenteredoffset"
Expand Down
7 changes: 1 addition & 6 deletions src/MEDS_transforms/extract/convert_to_sharded_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
from MEDS_transforms.extract import CONFIG_YAML
from MEDS_transforms.extract.shard_events import META_KEYS
from MEDS_transforms.mapreduce.mapper import rwlock_wrap
from MEDS_transforms.utils import (
is_col_field,
parse_col_field,
stage_init,
write_lazyframe,
)
from MEDS_transforms.utils import is_col_field, parse_col_field, stage_init, write_lazyframe


def in_format(fmt: str, ts_name: str) -> pl.Expr:
Expand Down
464 changes: 431 additions & 33 deletions src/MEDS_transforms/mapreduce/mapper.py

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions tests/test_add_time_derived_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
"""


from .transform_tester_base import (
ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT,
single_stage_transform_tester,
)
from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester
from .utils import parse_meds_csvs

AGE_CALCULATION_STR = """
Expand Down
116 changes: 112 additions & 4 deletions tests/test_filter_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
"""


from .transform_tester_base import (
FILTER_MEASUREMENTS_SCRIPT,
single_stage_transform_tester,
)
from .transform_tester_base import FILTER_MEASUREMENTS_SCRIPT, single_stage_transform_tester
from .utils import parse_meds_csvs

# This is the code metadata
Expand Down Expand Up @@ -118,3 +115,114 @@ def test_filter_measurements():
transform_stage_kwargs={"min_patients_per_code": 2},
want_outputs=WANT_SHARDS,
)


# This is the code metadata
# MEDS_CODE_METADATA_CSV = """
# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code
# ,44,4,28,3198.8389005974336,382968.28937288234,,
# ADMISSION//CARDIAC,2,2,0,,,,
# ADMISSION//ORTHOPEDIC,1,1,0,,,,
# ADMISSION//PULMONARY,1,1,0,,,,
# DISCHARGE,4,4,0,,,,
# DOB,4,4,0,,,,
# EYE_COLOR//BLUE,1,1,0,,,"Blue Eyes. Less common than brown.",
# EYE_COLOR//BROWN,1,1,0,,,"Brown Eyes. The most common eye color.",
# EYE_COLOR//HAZEL,2,2,0,,,"Hazel eyes. These are uncommon",
# HEIGHT,4,4,4,656.8389005974336,108056.12937288235,,
# HR,12,4,12,1360.5000000000002,158538.77,"Heart Rate",LOINC/8867-4
# TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5
# """
#
# In the test that applies to the match and revise framework, we'll filter codes in the following manner
# (thresholds are `min_patients_per_code`, i.e., they are compared against the `code/n_patients` column):
# - Codes that start with ADMISSION// must occur in at least 2 patients; the codes retained are:
# ADMISSION//CARDIAC
# - Codes in [HR] must occur in at least 15 patients; the codes retained are:
# (no codes)
# - Codes that start with EYE_COLOR// must occur in at least 4 patients; the codes retained are:
# (no codes)
# - Other codes won't be filtered, so we will retain HEIGHT, DISCHARGE, DOB, TEMP

MR_WANT_TRAIN_0 = """
patient_id,time,code,numeric_value
239684,,HEIGHT,175.271115221764
239684,"12/28/1980, 00:00:00",DOB,
239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC,
239684,"05/11/2010, 17:41:51",TEMP,96.0
239684,"05/11/2010, 17:48:48",TEMP,96.2
239684,"05/11/2010, 18:25:35",TEMP,95.8
239684,"05/11/2010, 18:57:18",TEMP,95.5
239684,"05/11/2010, 19:27:19",DISCHARGE,
1195293,,HEIGHT,164.6868838269085
1195293,"06/20/1978, 00:00:00",DOB,
1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC,
1195293,"06/20/2010, 19:23:52",TEMP,100.0
1195293,"06/20/2010, 19:25:32",TEMP,100.0
1195293,"06/20/2010, 19:45:19",TEMP,99.9
1195293,"06/20/2010, 20:12:31",TEMP,99.8
1195293,"06/20/2010, 20:24:44",TEMP,100.0
1195293,"06/20/2010, 20:41:33",TEMP,100.4
1195293,"06/20/2010, 20:50:04",DISCHARGE,
"""

MR_WANT_TRAIN_1 = """
patient_id,time,code,numeric_value
68729,,HEIGHT,160.3953106166676
68729,"03/09/1978, 00:00:00",DOB,
68729,"05/26/2010, 02:30:56",TEMP,97.8
68729,"05/26/2010, 04:51:52",DISCHARGE,
814703,,HEIGHT,156.48559093209357
814703,"03/28/1976, 00:00:00",DOB,
814703,"02/05/2010, 05:55:39",TEMP,100.1
814703,"02/05/2010, 07:02:30",DISCHARGE,
"""

MR_WANT_TUNING_0 = """
patient_id,time,code,numeric_value
754281,,HEIGHT,166.22261567137025
754281,"12/19/1988, 00:00:00",DOB,
754281,"01/03/2010, 06:27:59",TEMP,99.8
754281,"01/03/2010, 08:22:13",DISCHARGE,
"""

MR_WANT_HELD_OUT_0 = """
patient_id,time,code,numeric_value
1500733,,HEIGHT,158.60131573580904
1500733,"07/20/1986, 00:00:00",DOB,
1500733,"06/03/2010, 14:54:38",TEMP,100.0
1500733,"06/03/2010, 15:39:49",TEMP,100.3
1500733,"06/03/2010, 16:20:49",TEMP,100.1
1500733,"06/03/2010, 16:44:26",DISCHARGE,
"""

# Expected per-shard outputs for the match-and-revise filter test, keyed by shard name
# ("<split>/<index>"). Each value is the expected CSV text defined above; parse_meds_csvs
# presumably converts each CSV string into a parsed table -- confirm against tests/utils.py.
MR_WANT_SHARDS = parse_meds_csvs(
{
"train/0": MR_WANT_TRAIN_0,
"train/1": MR_WANT_TRAIN_1,
"tuning/0": MR_WANT_TUNING_0,
"held_out/0": MR_WANT_HELD_OUT_0,
}
)

# Stage-config keys used to request match-and-revise behaviour: the top-level key holds a list of
# per-subset configurations, and each entry carries a matcher selecting the rows it applies to.
MATCH_REVISE_KEY = "_match_revise"
MATCHER_KEY = "_matcher"


def test_match_revise_filter_measurements():
    """Run the filter-measurements stage with a separate patient-count threshold per matched code.

    Each entry under ``MATCH_REVISE_KEY`` pairs a ``MATCHER_KEY`` filter (an exact ``code`` value)
    with the ``min_patients_per_code`` threshold applied to the matched rows. Codes not named by
    any matcher are expected to pass through unfiltered, as encoded in ``MR_WANT_SHARDS``.

    The stage configuration is passed with ``do_use_config_yaml=True`` rather than as command-line
    overrides (presumably because the nested list-of-dicts is awkward to express as Hydra
    overrides -- confirm against ``tests/utils.py``).
    """
    single_stage_transform_tester(
        transform_script=FILTER_MEASUREMENTS_SCRIPT,
        stage_name="filter_measurements",
        transform_stage_kwargs={
            # Use the shared key constants so the test cannot drift from the declared key names.
            MATCH_REVISE_KEY: [
                {MATCHER_KEY: {"code": "ADMISSION//CARDIAC"}, "min_patients_per_code": 2},
                {MATCHER_KEY: {"code": "ADMISSION//ORTHOPEDIC"}, "min_patients_per_code": 2},
                {MATCHER_KEY: {"code": "ADMISSION//PULMONARY"}, "min_patients_per_code": 2},
                {MATCHER_KEY: {"code": "HR"}, "min_patients_per_code": 15},
                {MATCHER_KEY: {"code": "EYE_COLOR//BLUE"}, "min_patients_per_code": 4},
                {MATCHER_KEY: {"code": "EYE_COLOR//BROWN"}, "min_patients_per_code": 4},
                {MATCHER_KEY: {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4},
            ],
        },
        want_outputs=MR_WANT_SHARDS,
        do_use_config_yaml=True,
    )
5 changes: 1 addition & 4 deletions tests/test_occlude_outliers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@

import polars as pl

from .transform_tester_base import (
OCCLUDE_OUTLIERS_SCRIPT,
single_stage_transform_tester,
)
from .transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester
from .utils import MEDS_PL_SCHEMA, parse_meds_csvs

# This is the code metadata
Expand Down
5 changes: 1 addition & 4 deletions tests/test_reorder_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
"""


from .transform_tester_base import (
REORDER_MEASUREMENTS_SCRIPT,
single_stage_transform_tester,
)
from .transform_tester_base import REORDER_MEASUREMENTS_SCRIPT, single_stage_transform_tester
from .utils import parse_meds_csvs

ORDERED_CODE_PATTERNS = [
Expand Down
16 changes: 11 additions & 5 deletions tests/transform_tester_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def single_stage_transform_tester(
input_shards: dict[str, pl.DataFrame] | None = None,
do_pass_stage_name: bool = False,
file_suffix: str = ".parquet",
do_use_config_yaml: bool = False,
):
with tempfile.TemporaryDirectory() as d:
MEDS_dir = Path(d) / "MEDS_cohort"
Expand Down Expand Up @@ -337,12 +338,17 @@ def single_stage_transform_tester(
if transform_stage_kwargs:
pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs}

run_command_kwargs = {
"script": transform_script,
"hydra_kwargs": pipeline_config_kwargs,
"test_name": f"Single stage transform: {stage_name}",
}
if do_use_config_yaml:
run_command_kwargs["do_use_config_yaml"] = True
run_command_kwargs["config_name"] = "preprocess"

# Run the transform
stderr, stdout = run_command(
transform_script,
pipeline_config_kwargs,
f"Single stage transform: {stage_name}",
)
stderr, stdout = run_command(**run_command_kwargs)

# Check the output
if isinstance(want_outputs, pl.DataFrame):
Expand Down
51 changes: 44 additions & 7 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import subprocess
import tempfile
from io import StringIO
from pathlib import Path

import polars as pl
from omegaconf import OmegaConf
from polars.testing import assert_frame_equal

DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S"
Expand Down Expand Up @@ -102,31 +104,66 @@ def run_command(
test_name: str,
config_name: str | None = None,
should_error: bool = False,
do_use_config_yaml: bool = False,
):
script = ["python", str(script.resolve())] if isinstance(script, Path) else [script]
command_parts = script
if config_name is not None:
command_parts.append(f"--config-name={config_name}")
command_parts.append(" ".join(dict_to_hydra_kwargs(hydra_kwargs)))

err_cmd_lines = []

if do_use_config_yaml:
if config_name is None:
raise ValueError("config_name must be provided if do_use_config_yaml is True.")

conf = OmegaConf.create(
{
"defaults": [config_name],
**hydra_kwargs,
}
)

conf_dir = tempfile.TemporaryDirectory()
conf_path = Path(conf_dir.name) / "config.yaml"
OmegaConf.save(conf, conf_path)

command_parts.extend(
[
f"--config-path={str(conf_path.parent.resolve())}",
"--config-name=config",
"'hydra.searchpath=[pkg://MEDS_transforms.configs]'",
]
)
err_cmd_lines.append(f"Using config yaml:\n{OmegaConf.to_yaml(conf)}")
else:
if config_name is not None:
command_parts.append(f"--config-name={config_name}")
command_parts.append(" ".join(dict_to_hydra_kwargs(hydra_kwargs)))

full_cmd = " ".join(command_parts)
err_cmd_lines.append(f"Running command: {full_cmd}")
command_out = subprocess.run(full_cmd, shell=True, capture_output=True)

command_errored = command_out.returncode != 0

stderr = command_out.stderr.decode()
err_cmd_lines.append(f"stderr:\n{stderr}")
stdout = command_out.stdout.decode()
err_cmd_lines.append(f"stdout:\n{stdout}")

if should_error and not command_errored:
if do_use_config_yaml:
conf_dir.cleanup()
raise AssertionError(
f"{test_name} failed as command did not error when expected!\n"
f"command:{full_cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}"
f"{test_name} failed as command did not error when expected!\n" + "\n".join(err_cmd_lines)
)
elif not should_error and command_errored:
if do_use_config_yaml:
conf_dir.cleanup()
raise AssertionError(
f"{test_name} failed as command errored when not expected!"
f"\ncommand:{full_cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}"
f"{test_name} failed as command errored when not expected!\n" + "\n".join(err_cmd_lines)
)
if do_use_config_yaml:
conf_dir.cleanup()
return stderr, stdout


Expand Down
Loading