Commit 7a3c2d9

Added an integration test for filter measurements and further improved the interface.
mmcdermott committed Jul 24, 2024
1 parent 4d452d4 commit 7a3c2d9
Showing 5 changed files with 71 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/MEDS_polars_functions/configs/preprocess.yaml
@@ -4,7 +4,7 @@ defaults:
   - filter_patients
   - add_time_derived_measurements
   - count_code_occurrences
-  - filter_codes
+  - filter_measurements
   - fit_outlier_detection
   - filter_outliers
   - fit_normalization
@@ -26,7 +26,7 @@ stages:
   - filter_patients
   - add_time_derived_measurements
   - preliminary_counts
-  - filter_codes
+  - filter_measurements
   - fit_outlier_detection
   - filter_outliers
   - fit_normalization
2 changes: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-filter_codes:
+filter_measurements:
   min_patients_per_code: null
   min_occurrences_per_code: null
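As the doctests in filter_measurements.py below confirm, a null threshold disables that criterion, so the defaults above apply no filtering at all. A minimal sketch of how such thresholds could become Polars filter expressions (the count column names here are assumptions for illustration, not taken from this diff):

import polars as pl


def threshold_exprs(min_patients: int | None, min_occurrences: int | None) -> list[pl.Expr]:
    # Hypothetical count columns; the real stage reads counts from code_metadata.
    exprs = []
    if min_patients is not None:
        exprs.append(pl.col("code/n_patients") >= min_patients)
    if min_occurrences is not None:
        exprs.append(pl.col("code/n_occurrences") >= min_occurrences)
    # Both thresholds null (the defaults above) -> empty list -> nothing is filtered.
    return exprs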
18 changes: 9 additions & 9 deletions src/MEDS_polars_functions/filters/filter_measurements.py
@@ -13,7 +13,7 @@
 pl.enable_string_cache()


-def filter_codes_fntr(
+def filter_measurements_fntr(
     stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None
 ) -> Callable[[pl.LazyFrame], pl.LazyFrame]:
     """Returns a function that filters patient events to only encompass those with a set of permissible codes.
@@ -38,7 +38,7 @@ def filter_codes_fntr(
     ...     "modifier1": [1, 1, 2, 2],
     ... }).lazy()
     >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3})
-    >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"])
+    >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"])
     >>> fn(data).collect()
     shape: (2, 3)
     ┌────────────┬──────┬───────────┐
@@ -50,7 +50,7 @@ def filter_codes_fntr(
     │ 1          ┆ B    ┆ 1         │
     └────────────┴──────┴───────────┘
     >>> stage_cfg = DictConfig({"min_patients_per_code": 1, "min_occurrences_per_code": 4})
-    >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"])
+    >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"])
     >>> fn(data).collect()
     shape: (2, 3)
     ┌────────────┬──────┬───────────┐
@@ -62,7 +62,7 @@ def filter_codes_fntr(
     │ 2          ┆ A    ┆ 2         │
     └────────────┴──────┴───────────┘
     >>> stage_cfg = DictConfig({"min_patients_per_code": 1})
-    >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"])
+    >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"])
     >>> fn(data).collect()
     shape: (4, 3)
     ┌────────────┬──────┬───────────┐
@@ -76,7 +76,7 @@ def filter_codes_fntr(
     │ 2          ┆ C    ┆ 2         │
     └────────────┴──────┴───────────┘
     >>> stage_cfg = DictConfig({"min_patients_per_code": None, "min_occurrences_per_code": None})
-    >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"])
+    >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"])
     >>> fn(data).collect()
     shape: (4, 3)
     ┌────────────┬──────┬───────────┐
@@ -90,7 +90,7 @@ def filter_codes_fntr(
     │ 2          ┆ C    ┆ 2         │
     └────────────┴──────┴───────────┘
     >>> stage_cfg = DictConfig({"min_occurrences_per_code": 5})
-    >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"])
+    >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"])
     >>> fn(data).collect()
     shape: (1, 3)
     ┌────────────┬──────┬───────────┐
@@ -120,7 +120,7 @@ def filter_codes_fntr(

     allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy()

-    def filter_codes_fn(df: pl.LazyFrame) -> pl.LazyFrame:
+    def filter_measurements_fn(df: pl.LazyFrame) -> pl.LazyFrame:
         f"""Filters patient events to only encompass those with a set of permissible codes.

         In particular, this function filters the DataFrame to only include (code, modifier) pairs that have
@@ -139,7 +139,7 @@ def filter_codes_fn(df: pl.LazyFrame) -> pl.LazyFrame:
             .drop(idx_col)
         )

-    return filter_codes_fn
+    return filter_measurements_fn


 @hydra.main(
@@ -151,7 +151,7 @@ def main(cfg: DictConfig):
     code_metadata = pl.read_parquet(
         Path(cfg.stage_cfg.metadata_input_dir) / "code_metadata.parquet", use_pyarrow=True
     )
-    compute_fn = filter_codes_fntr(cfg.stage_cfg, code_metadata)
+    compute_fn = filter_measurements_fntr(cfg.stage_cfg, code_metadata)

     map_over(cfg, compute_fn=compute_fn)
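The visible fragments above show the overall pattern: code_metadata is filtered down to the permissible (code, modifier) rows, and the returned closure keeps only events matching those rows. A rough sketch of that pattern, under stated assumptions (the real closure uses an index column plus a join and .drop(idx_col) to preserve row order; a semi-join is shown here for brevity):

import polars as pl


def keep_allowed(df: pl.LazyFrame, allowed: pl.LazyFrame, join_cols: list[str]) -> pl.LazyFrame:
    # Keep only rows of df whose join keys appear in `allowed`.
    return df.join(allowed, on=join_cols, how="semi")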
59 changes: 26 additions & 33 deletions tests/transform_tester_base.py
@@ -4,10 +4,17 @@
 scripts.
 """

+import json
 import os
+import tempfile
+from io import StringIO
+from pathlib import Path

+import polars as pl
 import rootutils

+from .utils import assert_df_equal, parse_meds_csvs, run_command
+
 root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

 code_root = root / "src" / "MEDS_polars_functions"
@@ -43,18 +50,17 @@
 TENSORIZE_SCRIPT = "MEDS_transform-tensorize"
 TOKENIZE_SCRIPT = "MEDS_transform-tokenize"

-import tempfile
-from io import StringIO
-from pathlib import Path
-
-import polars as pl
-
-from .utils import assert_df_equal, run_command
-
 pl.enable_string_cache()

 # Test MEDS data (inputs)

+SPLITS = {
+    "train/0": [239684, 1195293],
+    "train/1": [68729, 814703],
+    "tuning/0": [754281],
+    "held_out/0": [1500733],
+}
+
 MEDS_TRAIN_0 = """
 patient_id,timestamp,code,numerical_value
 239684,,EYE_COLOR//BROWN,
@@ -133,31 +139,14 @@
 1500733,"06/03/2010, 16:44:26",DISCHARGE,
 """

-# TODO: Make use meds library
-MEDS_PL_SCHEMA = {
-    "patient_id": pl.UInt32,
-    "timestamp": pl.Datetime("us"),
-    "code": pl.Categorical,
-    "numerical_value": pl.Float32,
-}
-MEDS_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S"
-
-
-def parse_meds_csv(csv_str: str) -> pl.DataFrame:
-    read_schema = {**MEDS_PL_SCHEMA}
-    read_schema["timestamp"] = pl.Utf8
-
-    return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns(
-        pl.col("timestamp").str.strptime(MEDS_PL_SCHEMA["timestamp"], MEDS_CSV_TS_FORMAT)
-    )
-
-
-MEDS_SHARDS = {
-    "train/0": parse_meds_csv(MEDS_TRAIN_0),
-    "train/1": parse_meds_csv(MEDS_TRAIN_1),
-    "tuning/0": parse_meds_csv(MEDS_TUNING_0),
-    "held_out/0": parse_meds_csv(MEDS_HELD_OUT_0),
-}
+MEDS_SHARDS = parse_meds_csvs(
+    {
+        "train/0": MEDS_TRAIN_0,
+        "train/1": MEDS_TRAIN_1,
+        "tuning/0": MEDS_TUNING_0,
+        "held_out/0": MEDS_HELD_OUT_0,
+    }
+)


 MEDS_CODE_METADATA_CSV = """
@@ -225,6 +214,10 @@ def single_stage_transform_tester(
     MEDS_dir.mkdir()
     cohort_dir.mkdir()

+    # Write the splits
+    splits_fp = MEDS_dir / "splits.json"
+    splits_fp.write_text(json.dumps(SPLITS))
+
     # Write the shards
     for shard_name, df in MEDS_SHARDS.items():
         fp = MEDS_dir / f"{shard_name}.parquet"
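Given the layout written above (a splits.json plus one Parquet file per shard under MEDS_dir), a test could read everything back roughly as follows; this sketch is illustrative and not part of the diff:

import json
from pathlib import Path

import polars as pl


def read_meds_input(MEDS_dir: Path) -> tuple[dict, dict[str, pl.DataFrame]]:
    # splits.json maps split names (e.g. "train/0") to lists of patient IDs.
    splits = json.loads((MEDS_dir / "splits.json").read_text())
    # Shards are written as f"{shard_name}.parquet", so recover the shard
    # names from the paths relative to MEDS_dir.
    shards = {
        str(fp.relative_to(MEDS_dir).with_suffix("")): pl.read_parquet(fp)
        for fp in sorted(MEDS_dir.rglob("*.parquet"))
    }
    return splits, shards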
33 changes: 33 additions & 0 deletions tests/utils.py
@@ -1,9 +1,42 @@
 import subprocess
+from io import StringIO
 from pathlib import Path

 import polars as pl
 from polars.testing import assert_frame_equal

+DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S"
+
+# TODO: Make use meds library
+MEDS_PL_SCHEMA = {
+    "patient_id": pl.UInt32,
+    "timestamp": pl.Datetime("us"),
+    "code": pl.Categorical,
+    "numerical_value": pl.Float32,
+}
+
+
+def parse_meds_csvs(
+    csvs: str | dict[str, str], schema: dict[str, pl.DataType] = MEDS_PL_SCHEMA
+) -> pl.DataFrame | dict[str, pl.DataFrame]:
+    """Converts a string or dict of named strings to a MEDS DataFrame by interpreting them as CSVs.
+
+    TODO: doctests.
+    """
+
+    read_schema = {**schema}
+    read_schema["timestamp"] = pl.Utf8
+
+    def reader(csv_str: str) -> pl.DataFrame:
+        return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns(
+            pl.col("timestamp").str.strptime(MEDS_PL_SCHEMA["timestamp"], DEFAULT_CSV_TS_FORMAT)
+        )
+
+    if isinstance(csvs, str):
+        return reader(csvs)
+    else:
+        return {k: reader(v) for k, v in csvs.items()}
+

 def dict_to_hydra_kwargs(d: dict[str, str]) -> str:
     """Converts a dictionary to a hydra kwargs string for testing purposes.
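The function above still carries a "TODO: doctests" note. One possible usage example (illustrative values, not from the repository's tests): parse_meds_csvs turns a single CSV string into one DataFrame and a dict of strings into a dict of DataFrames, coercing columns to MEDS_PL_SCHEMA.

import polars as pl

from tests.utils import parse_meds_csvs

shards = parse_meds_csvs(
    {
        "train/0": """
patient_id,timestamp,code,numerical_value
239684,"05/11/2010, 17:41:51",HR,102.6
"""
    }
)
assert shards["train/0"].schema["timestamp"] == pl.Datetime("us")
assert shards["train/0"].schema["numerical_value"] == pl.Float32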
