Merge pull request #121 from Oufattole/aces_store_windows
added storage functionality for raw windows
mmcdermott authored Aug 28, 2024
2 parents 28bf935 + 41f8bde commit 3006041
Showing 3 changed files with 273 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/aces/__main__.py
@@ -173,7 +173,9 @@ def main(cfg: DictConfig):
            logger.warning("Output dataframe is empty; adding an empty patient ID column.")
            result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias("patient_id"))
            result = result.head(0)

        if cfg.window_stats_dir:
            Path(cfg.window_stats_filepath).parent.mkdir(exist_ok=True, parents=True)
            result.write_parquet(cfg.window_stats_filepath)
        result = get_and_validate_label_schema(result)
        pq.write_table(result, cfg.output_filepath)
    else:
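For context, a minimal sketch (not code from this commit) of how the stored window data could be read back downstream; the path and task name are hypothetical, and the struct layout follows the expected schema added to the tests below:

import polars as pl

# Hypothetical location; the actual path is governed by `window_stats_dir` and
# `window_stats_filepath` in aces.yaml (see the config change below).
window_stats = pl.read_parquet("sample_cohort/window_stats/my_task/train/1.parquet")

# Each window summary (e.g. "input.end_summary") is stored as a struct column;
# unnesting it yields flat columns for the window's timestamps and per-predicate counts.
input_end = window_stats.select("patient_id", "prediction_time", "input.end_summary").unnest(
    "input.end_summary"
)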
3 changes: 3 additions & 0 deletions src/aces/configs/aces.yaml
@@ -12,6 +12,9 @@ config_path: ${cohort_dir}/${cohort_name}.yaml
# Path to store the output file. The `${data._prefix}` addition allows us to add shard specific prefixes in a
# sharded data mode.
output_filepath: ${cohort_dir}/${cohort_name}${data._prefix}.parquet
# Optional directory in which to store output files containing the raw window data.
window_stats_dir: null
window_stats_filepath: ${window_stats_dir}/${cohort_name}${data._prefix}.parquet

log_dir: ${cohort_dir}/${cohort_name}/.logs

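A minimal sketch of turning this on, mirroring the Hydra overrides passed to `aces-cli` in the new `test_meds_window_storage` test below; the directory and task name values are hypothetical:

# Overrides as they would be handed to the aces-cli invocation (cf. `extraction_config_kwargs`
# in the new test). Setting `window_stats_dir` enables raw-window storage; the per-shard
# `window_stats_filepath` is then interpolated from `window_stats_dir`, `cohort_name`, and
# `data._prefix`, so it does not need to be set explicitly.
overrides = {
    "cohort_dir": "sample_cohort",      # hypothetical cohort directory
    "cohort_name": "my_task",           # hypothetical task name
    "data.standard": "meds",
    "window_stats_dir": "sample_cohort/window_stats",
}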
268 changes: 267 additions & 1 deletion tests/test_meds.py
@@ -5,15 +5,23 @@

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

import tempfile
from io import StringIO
from pathlib import Path

import polars as pl
import pyarrow as pa
from loguru import logger
from meds import label_schema
from yaml import load as load_yaml

from .utils import cli_test
from .utils import (
    assert_df_equal,
    cli_test,
    run_command,
    write_input_files,
    write_task_configs,
)

try:
from yaml import CLoader as Loader
@@ -360,6 +368,175 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
"""
)

WANT_EMPTY_WINDOW_SCHEMA = {"patient_id": pl.Int64}
WANT_NON_EMPTY_WINDOW_SCHEMA = {
    "patient_id": pl.UInt32,
    "prediction_time": pl.Datetime,
    "boolean_value": pl.Int64,
    "trigger": pl.Datetime,
    "input.end_summary": pl.Struct(
        [
            pl.Field("window_name", pl.Utf8),
            pl.Field("timestamp_at_start", pl.Datetime),
            pl.Field("timestamp_at_end", pl.Datetime),
            pl.Field("admission", pl.Int64),
            pl.Field("discharge", pl.Int64),
            pl.Field("death", pl.Int64),
            pl.Field("discharge_or_death", pl.Int64),
            pl.Field("_ANY_EVENT", pl.Int64),
        ]
    ),
    "input.start_summary": pl.Struct(
        [
            pl.Field("window_name", pl.Utf8),
            pl.Field("timestamp_at_start", pl.Datetime),
            pl.Field("timestamp_at_end", pl.Datetime),
            pl.Field("admission", pl.Int64),
            pl.Field("discharge", pl.Int64),
            pl.Field("death", pl.Int64),
            pl.Field("discharge_or_death", pl.Int64),
            pl.Field("_ANY_EVENT", pl.Int64),
        ]
    ),
    "gap.end_summary": pl.Struct(
        [
            pl.Field("window_name", pl.Utf8),
            pl.Field("timestamp_at_start", pl.Datetime),
            pl.Field("timestamp_at_end", pl.Datetime),
            pl.Field("admission", pl.Int64),
            pl.Field("discharge", pl.Int64),
            pl.Field("death", pl.Int64),
            pl.Field("discharge_or_death", pl.Int64),
            pl.Field("_ANY_EVENT", pl.Int64),
        ]
    ),
    "target.end_summary": pl.Struct(
        [
            pl.Field("window_name", pl.Utf8),
            pl.Field("timestamp_at_start", pl.Datetime),
            pl.Field("timestamp_at_end", pl.Datetime),
            pl.Field("admission", pl.Int64),
            pl.Field("discharge", pl.Int64),
            pl.Field("death", pl.Int64),
            pl.Field("discharge_or_death", pl.Int64),
            pl.Field("_ANY_EVENT", pl.Int64),
        ]
    ),
}

WANT_TRAIN_WINDOW_DATA = """
[
    {
        "patient_id": 4,
        "prediction_time": "1991-01-28 23:32:00",
        "boolean_value": 0,
        "trigger": "1991-01-27 23:32:00",
        "input.end_summary": {
            "window_name": "input.end",
            "timestamp_at_start": "1991-01-27 23:32:00",
            "timestamp_at_end": "1991-01-28 23:32:00",
            "admission": 0,
            "discharge": 0,
            "death": 0,
            "discharge_or_death": 0,
            "_ANY_EVENT": 4
        },
        "input.start_summary": {
            "window_name": "input.start",
            "timestamp_at_start": "1989-12-01 12:03:00",
            "timestamp_at_end": "1991-01-28 23:32:00",
            "admission": 2,
            "discharge": 1,
            "death": 0,
            "discharge_or_death": 1,
            "_ANY_EVENT": 16
        },
        "gap.end_summary": {
            "window_name": "gap.end",
            "timestamp_at_start": "1991-01-27 23:32:00",
            "timestamp_at_end": "1991-01-29 23:32:00",
            "admission": 0,
            "discharge": 0,
            "death": 0,
            "discharge_or_death": 0,
            "_ANY_EVENT": 5
        },
        "target.end_summary": {
            "window_name": "target.end",
            "timestamp_at_start": "1991-01-29 23:32:00",
            "timestamp_at_end": "1991-01-31 02:15:00",
            "admission": 0,
            "discharge": 1,
            "death": 0,
            "discharge_or_death": 1,
            "_ANY_EVENT": 7
        }
    }
]
"""

WANT_HELD_OUT_WINDOW_DATA = """
[
    {
        "patient_id": 1,
        "prediction_time": "1991-01-28 23:32:00",
        "boolean_value": 0,
        "trigger": "1991-01-27 23:32:00",
        "input.end_summary": {
            "window_name": "input.end",
            "timestamp_at_start": "1991-01-27 23:32:00",
            "timestamp_at_end": "1991-01-28 23:32:00",
            "admission": 0,
            "discharge": 0,
            "death": 0,
            "discharge_or_death": 0,
            "_ANY_EVENT": 4
        },
        "input.start_summary": {
            "window_name": "input.start",
            "timestamp_at_start": "1989-12-01 12:03:00",
            "timestamp_at_end": "1991-01-28 23:32:00",
            "admission": 2,
            "discharge": 1,
            "death": 0,
            "discharge_or_death": 1,
            "_ANY_EVENT": 16
        },
        "gap.end_summary": {
            "window_name": "gap.end",
            "timestamp_at_start": "1991-01-27 23:32:00",
            "timestamp_at_end": "1991-01-29 23:32:00",
            "admission": 0,
            "discharge": 0,
            "death": 0,
            "discharge_or_death": 0,
            "_ANY_EVENT": 5
        },
        "target.end_summary": {
            "window_name": "target.end",
            "timestamp_at_start": "1991-01-29 23:32:00",
            "timestamp_at_end": "1991-01-31 02:15:00",
            "admission": 0,
            "discharge": 1,
            "death": 0,
            "discharge_or_death": 1,
            "_ANY_EVENT": 7
        }
    }
]
"""


WANT_WINDOW_SHARDS = {
    "train/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
    "train/1.parquet": pl.read_json(StringIO(WANT_TRAIN_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA),
    "held_out/0/0.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
    "empty_shard.parquet": pl.DataFrame({}, schema=WANT_EMPTY_WINDOW_SCHEMA),
    "held_out.parquet": pl.read_json(
        StringIO(WANT_HELD_OUT_WINDOW_DATA), schema=WANT_NON_EMPTY_WINDOW_SCHEMA
    ),
}


def test_meds():
    cli_test(
@@ -368,3 +545,92 @@ def test_meds():
        want_outputs_by_task={TASK_NAME: WANT_SHARDS},
        data_standard="meds",
    )


def test_meds_window_storage():
    input_files = MEDS_SHARDS
    task = TASK_NAME
    want_outputs_by_task = {TASK_NAME: WANT_SHARDS}
    data_standard = "meds"

    with tempfile.TemporaryDirectory() as root_dir:
        root_dir = Path(root_dir)
        data_dir = root_dir / "sample_data" / "data"
        cohort_dir = root_dir / "sample_cohort"

        wrote_files = write_input_files(data_dir, input_files)
        assert len(wrote_files) > 1, "No input files were written."
        sharded = True
        command = "aces-cli --multirun"

        wrote_configs = write_task_configs(cohort_dir, {TASK_NAME: TASK_CFG})
        if len(wrote_configs) == 0:
            raise ValueError("No task configs were written.")

        want_outputs = {
            cohort_dir / task / f"{n}.parquet": df for n, df in want_outputs_by_task[task].items()
        }
        window_dir = Path(cohort_dir / "window_stats")
        want_window_outputs = {
            window_dir / task / filename: want_df for filename, want_df in WANT_WINDOW_SHARDS.items()
        }

        extraction_config_kwargs = {
            "cohort_dir": str(cohort_dir.resolve()),
            "cohort_name": task,
            "hydra.verbose": True,
            "data.standard": data_standard,
            "window_stats_dir": str(window_dir.resolve()),
        }

        if len(wrote_files) > 1:
            extraction_config_kwargs["data"] = "sharded"
            extraction_config_kwargs["data.root"] = str(data_dir.resolve())
            extraction_config_kwargs['"data.shard'] = f'$(expand_shards {str(data_dir.resolve())})"'
        else:
            extraction_config_kwargs["data.path"] = str(list(wrote_files.values())[0].resolve())

        stderr, stdout = run_command(command, extraction_config_kwargs, f"CLI should run for {task}")

        try:
            if sharded:
                out_dir = cohort_dir / task
                all_out_fps = list(out_dir.glob("**/*.parquet"))
                all_out_fps_str = ", ".join(str(x.relative_to(out_dir)) for x in all_out_fps)
                if len(all_out_fps) == 0 and len(want_outputs) > 0:
                    all_directory_contents = ", ".join(
                        str(x.relative_to(cohort_dir)) for x in cohort_dir.glob("**/*")
                    )

                    raise AssertionError(
                        f"No output files found for task '{task}'. Found files: {all_directory_contents}"
                    )

                assert len(all_out_fps) == len(
                    want_outputs
                ), f"Expected {len(want_outputs)} outputs, got {len(all_out_fps)}: {all_out_fps_str}"

            for want_fp, want_df in want_outputs.items():
                out_shard = want_fp.relative_to(cohort_dir)
                assert want_fp.is_file(), f"Expected {out_shard} to exist."

                got_df = pl.read_parquet(want_fp)
                assert_df_equal(
                    want_df, got_df, f"Data mismatch for shard '{out_shard}':\n{want_df}\n{got_df}"
                )
            assert window_dir.exists(), f"Expected window stats directory {window_dir} to exist."
            out_fps = list(window_dir.glob("**/*.parquet"))
            assert len(out_fps) == len(
                want_window_outputs
            ), f"Expected {len(want_window_outputs)} window output files, got {len(out_fps)}"

            for want_fp, want_df in want_window_outputs.items():
                out_shard = want_fp.relative_to(window_dir)
                assert want_fp.is_file(), f"Expected {out_shard} to exist."
                got_df = pl.read_parquet(want_fp)
                assert_df_equal(
                    want_df, got_df, f"Data mismatch for window shard '{out_shard}':\n{want_df}\n{got_df}"
                )
        except AssertionError as e:
            logger.error(f"{stderr}\n{stdout}")
            raise AssertionError(f"Error running task '{task}': {e}") from e

