From d4f66c53c40a635e73a376e74c31faffe9e5ccca Mon Sep 17 00:00:00 2001
From: Justin Xu <52216145+justin13601@users.noreply.github.com>
Date: Sun, 28 Jul 2024 02:30:40 -0700
Subject: [PATCH 01/11] Added more flexibility for expand shards, as well as
 reading dir as glob, closes #77 (#81)

---
 src/aces/expand_shards.py | 48 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/src/aces/expand_shards.py b/src/aces/expand_shards.py
index cf0e5cd..9638092 100755
--- a/src/aces/expand_shards.py
+++ b/src/aces/expand_shards.py
@@ -1,30 +1,66 @@
 #!/usr/bin/env python

+import glob
+import os
+import re
 import sys


 def expand_shards(*shards: str) -> str:
-    """This function expands a set of shard prefixes and number of shards into a list of all shards.
+    """This function expands a set of shard prefixes and number of shards into a list of all shards, or
+    expands a directory into a list of all files within it.

     This can be useful with Hydra applications where you wish to expand a list of options for the sweeper
     to sweep over but can't use an OmegaConf resolver as those are evaluated after the sweep has been
     initialized.

     Args:
-        shards: A list of shard prefixes and number of shards to expand.
+        shards: A list of shard prefixes and number of shards to expand, or a directory to list all files.

-    Returns: A comma-separated list of all shards, expanded to the specified number.
+    Returns: A comma-separated list of all shards, expanded to the specified number, or all files in the
+        directory.

     Examples:
+        >>> import polars as pl
+        >>> import tempfile
+        >>> from pathlib import Path
+
         >>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
         'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
+        >>> expand_shards("data/data_4", "data/test_4")
+        'data/data_0,data/data_1,data/data_2,data/data_3,data/test_0,data/test_1,data/test_2,data/test_3'
+
+        >>> parquet_data = pl.DataFrame({
+        ...     "patient_id": [1, 1, 1, 2, 3],
+        ...     "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
+        ...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
+        ... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
+
+        >>> with tempfile.TemporaryDirectory() as tmpdirname:
+        ...     for i in range(4):
+        ...         data_path = Path(tmpdirname + f"/file_{i}")
+        ...         parquet_data.write_parquet(data_path)
+        ...     result = expand_shards(tmpdirname)
+        ...     ','.join(sorted(os.path.basename(f) for f in result.split(',')))
+        'file_0,file_1,file_2,file_3'
     """
     result = []
     for arg in shards:
-        prefix = arg[: arg.rfind("/")]
-        num = int(arg[arg.rfind("/") + 1 :])
-        result.extend(f"{prefix}/{i}" for i in range(num))
+        if os.path.isdir(arg):
+            # If the argument is a directory, list all files in the directory
+            files = glob.glob(os.path.join(arg, "*"))
+            result.extend(files)
+        else:
+            # Otherwise, treat it as a shard prefix and number of shards
+            match = re.match(r"(.+)([/_])(\d+)$", arg)
+            if match:
+                prefix = match.group(1)
+                delimiter = match.group(2)
+                num = int(match.group(3))
+                result.extend(f"{prefix}{delimiter}{i}" for i in range(num))
+            else:
+                raise ValueError(f"Invalid shard format: {arg}")

     return ",".join(result)


From 3817bd47cfadfb455c2e4d58b9461c107af143ae Mon Sep 17 00:00:00 2001
From: Justin Xu <52216145+justin13601@users.noreply.github.com>
Date: Sun, 28 Jul 2024 02:32:28 -0700
Subject: [PATCH 02/11] Sample hypothesis testing (#64)

Co-authored-by: Matthew McDermott
---
 pyproject.toml                     |   2 +-
 tests/test_aggregate_hypothesis.py | 117 +++++++++++++++++++++++++++++
 2 files changed, 118 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_aggregate_hypothesis.py

diff --git a/pyproject.toml b/pyproject.toml
index 5973c33..0612da6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ build-backend = "setuptools.build_meta"

 [project.optional-dependencies]
 dev = [
-    "pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils"
+    "pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis"
 ]
 profiling = ["psutil"]

diff --git a/tests/test_aggregate_hypothesis.py b/tests/test_aggregate_hypothesis.py
new file mode 100644
index 0000000..2c4828e
--- /dev/null
+++ b/tests/test_aggregate_hypothesis.py
@@ -0,0 +1,117 @@
+import rootutils
+
+root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)
+
+from datetime import datetime, timedelta
+
+import polars as pl
+import polars.selectors as cs
+from hypothesis import given, settings
+from hypothesis import strategies as st
+from polars.testing import assert_series_equal
+from polars.testing.parametric import column, dataframes
+
+from aces.aggregate import aggregate_temporal_window
+from aces.types import TemporalWindowBounds
+
+datetime_st = st.datetimes(min_value=datetime(1989, 12, 1), max_value=datetime(1999, 12, 31))
+
+N_PREDICATES = 5
+PREDICATE_DATAFRAMES = dataframes(
+    cols=[
+        column("subject_id", allow_null=False, dtype=pl.UInt32),
+        column("timestamp", allow_null=False, dtype=pl.Datetime("ms"), strategy=datetime_st),
+        *[column(f"predicate_{i}", allow_null=False, dtype=pl.UInt8) for i in range(1, N_PREDICATES + 1)],
+    ],
+    min_size=1,
+    max_size=50,
+)
+
+
+@given(
+    df=PREDICATE_DATAFRAMES,
+    left_inclusive=st.booleans(),
+    right_inclusive=st.booleans(),
+    window_size=st.timedeltas(min_value=timedelta(days=1), max_value=timedelta(days=365 * 5)),
+    offset=st.timedeltas(min_value=timedelta(days=0), max_value=timedelta(days=365)),
+)
+@settings(max_examples=50)
+def test_aggregate_temporal_window(
+    df: pl.DataFrame, left_inclusive: bool, right_inclusive: bool, window_size: timedelta, offset: timedelta
+):
+    """Tests whether calling the `aggregate_temporal_window` function produces a consistent output."""
+
+    max_N_subjects = 3
+    df = df.with_columns(
+        (pl.col("subject_id") % max_N_subjects).alias("subject_id"),
+        cs.starts_with("predicate_").cast(pl.Int32).name.keep(),
+    ).sort("subject_id", "timestamp")
+
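+    # `TemporalWindowBounds` (from aces.types) bundles the four window parameters
+    # generated by the hypothesis strategies above: endpoint closedness
+    # (`left_inclusive` / `right_inclusive`), the window's duration (`window_size`),
+    # and the shift applied to each row's anchoring timestamp (`offset`).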
endpoint_expr = TemporalWindowBounds( + left_inclusive=left_inclusive, right_inclusive=right_inclusive, window_size=window_size, offset=offset + ) + + # Should run: + agg_df = aggregate_temporal_window(df.lazy(), endpoint_expr) + assert agg_df is not None + agg_df = agg_df.collect() + + # This will return something of the below form: + # + # shape: (6, 7) + # ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐ + # │ subject_id ┆ timestamp ┆ timestamp_at_start ┆ timestamp_at_end ┆ is_A ┆ is_B ┆ is_C │ + # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + # │ i64 ┆ datetime[μs] ┆ datetime[μs] ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 │ + # ╞════════════╪═════════════════════╪═════════════════════╪═════════════════════╪══════╪══════╪══════╡ + # │ 1 ┆ 1989-12-01 12:03:00 ┆ 1989-12-02 12:03:00 ┆ 1989-12-01 12:03:00 ┆ 1 ┆ 1 ┆ 2 │ + # │ 1 ┆ 1989-12-02 05:17:00 ┆ 1989-12-03 05:17:00 ┆ 1989-12-02 05:17:00 ┆ 1 ┆ 1 ┆ 1 │ + # │ 1 ┆ 1989-12-02 12:03:00 ┆ 1989-12-03 12:03:00 ┆ 1989-12-02 12:03:00 ┆ 1 ┆ 0 ┆ 0 │ + # │ 1 ┆ 1989-12-06 11:00:00 ┆ 1989-12-07 11:00:00 ┆ 1989-12-06 11:00:00 ┆ 0 ┆ 1 ┆ 0 │ + # │ 2 ┆ 1989-12-01 13:14:00 ┆ 1989-12-02 13:14:00 ┆ 1989-12-01 13:14:00 ┆ 0 ┆ 1 ┆ 1 │ + # │ 2 ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 0 ┆ 0 ┆ 0 │ + # └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘ + # + # We're going to validate this by asserting that the sums of the predicate columns between the rows + # for a given subject are consistent. + + assert set(df.columns).issubset(set(agg_df.columns)) + assert len(agg_df.columns) == len(df.columns) + 2 + assert "timestamp_at_start" in agg_df.columns + assert "timestamp_at_end" in agg_df.columns + assert_series_equal(agg_df["subject_id"], df["subject_id"]) + assert_series_equal(agg_df["timestamp"], df["timestamp"]) + + # Now we're going to validate the sums of the predicate columns between the rows for a given subject are + # consistent. + for subject_id in range(max_N_subjects): + if subject_id not in df["subject_id"]: + assert subject_id not in agg_df["subject_id"] + continue + + raw_subj = df.filter(pl.col("subject_id") == subject_id) + agg_subj = agg_df.filter(pl.col("subject_id") == subject_id) + + for row in agg_subj.iter_rows(named=True): + start = row["timestamp_at_start"] + end = row["timestamp_at_end"] + + if left_inclusive: + st_filter = pl.col("timestamp") >= start + else: + st_filter = pl.col("timestamp") > start + + if right_inclusive: + et_filter = pl.col("timestamp") <= end + else: + et_filter = pl.col("timestamp") < end + + raw_filtered = raw_subj.filter(st_filter & et_filter) + if len(raw_filtered) == 0: + for i in range(1, N_PREDICATES + 1): + # TODO: Is this right? Or should it always be one or the other? 
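+                    # When no raw events fall inside the window, the aggregated value
+                    # may surface either as a null or as a zero count; the check below
+                    # deliberately accepts both pending resolution of the TODO above.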
+ assert (row[f"predicate_{i}"] is None) or (row[f"predicate_{i}"] == 0) + else: + raw_sums = raw_filtered.select(cs.starts_with("predicate_")).sum() + for i in range(1, N_PREDICATES + 1): + assert raw_sums[f"predicate_{i}"].item() == row[f"predicate_{i}"] From 74d0abff9f8a51a2f0bd4e2469aafa673a4781af Mon Sep 17 00:00:00 2001 From: Justin Xu <52216145+justin13601@users.noreply.github.com> Date: Mon, 5 Aug 2024 03:48:21 -0700 Subject: [PATCH 03/11] =?UTF-8?q?TaskExtractorConfig=20accepts=20optional?= =?UTF-8?q?=20predicates=5Fpath=20to=20split=20predica=E2=80=A6=20(#82)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * TaskExtractorConfig accepts optional predicates_path to split predicates from main config_path (#42) * Extract trigger only (#85) * Make windows struct optional when loading config * Update spacing * Added patient_demographics and parsing of static variables (#48) (#86) * Added patient_demographics and parsing of static variables (#48) * Added filtering based on static predicates (#48) * Fix tests and column renames to match MEDS * Update e2e test * Added tests to boost coverage * Added more tests for coverage --- sample_configs/inhospital_mortality.yaml | 4 + sample_data/sample_data.csv | 107 +++---- src/aces/config.py | 294 ++++++++++++++++++-- src/aces/constraints.py | 67 +++++ src/aces/predicates.py | 337 ++++++++++++++--------- src/aces/query.py | 11 +- tests/test_check_static_variables.py | 55 ++++ tests/test_e2e.py | 168 +++++------ 8 files changed, 736 insertions(+), 307 deletions(-) create mode 100644 tests/test_check_static_variables.py diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 0695f60..a8f7eae 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -9,6 +9,10 @@ predicates: discharge_or_death: expr: or(discharge, death) +patient_demographics: + male: + code: SEX//male + trigger: admission windows: diff --git a/sample_data/sample_data.csv b/sample_data/sample_data.csv index 2cd8bc9..33e4d05 100644 --- a/sample_data/sample_data.csv +++ b/sample_data/sample_data.csv @@ -1,52 +1,55 @@ -subject_id,timestamp,admission,death,discharge,lab,spo2,normal_spo2,abnormally_low_spo2,abnormally_high_spo2,procedure_start,procedure_end,ventilation,diagnosis_ICD9CM_41071,diagnosis_ICD10CM_I214 -1,12/1/1989 12:03,1,0,0,0,0,0,0,0,0,0,0,0,0 -1,12/1/1989 13:14,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 15:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 16:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 20:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/2/1989 3:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/2/1989 9:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -1,12/2/1989 10:00,0,0,0,0,0,0,0,0,1,0,1,0,0 -1,12/2/1989 14:22,0,0,0,0,0,0,0,0,0,1,1,0,0 -1,12/2/1989 15:00,0,0,1,0,0,0,0,0,0,0,0,0,0 -1,1/21/1991 11:59,0,0,0,0,0,0,0,0,0,0,0,1,0 -1,1/27/1991 23:32,1,0,0,0,0,0,0,0,0,0,0,0,0 -1,1/27/1991 23:46,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/28/1991 3:18,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/28/1991 3:28,0,0,0,0,0,0,0,0,1,0,1,0,0 -1,1/28/1991 4:36,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/29/1991 23:32,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/30/1991 5:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/30/1991 8:00,0,0,0,1,1,0,0,1,0,0,0,0,0 -1,1/30/1991 11:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/30/1991 14:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -1,1/30/1991 14:15,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/31/1991 1:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -1,1/31/1991 2:15,0,0,1,0,0,0,0,0,0,0,0,0,0 -1,2/8/1991 8:15,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,3/3/1991 19:33,1,0,0,0,0,0,0,0,0,0,0,0,0 
-1,3/3/1991 20:33,0,0,0,1,1,0,1,0,0,0,0,0,0 -1,3/3/1991 21:38,0,1,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 2:24,1,0,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 2:35,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,3/8/1996 4:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,3/8/1996 10:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 16:00,0,0,1,0,0,0,0,0,0,0,0,0,0 -2,6/5/1996 0:32,1,0,0,0,0,0,0,0,0,0,0,0,0 -2,6/5/1996 0:48,0,0,0,0,0,0,0,0,0,0,0,0,1 -2,6/5/1996 1:59,0,0,0,0,0,0,0,0,1,0,1,0,0 -2,6/7/1996 6:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -2,6/7/1996 9:00,0,0,0,1,1,0,1,0,0,0,0,0,0 -2,6/7/1996 12:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -2,6/8/1996 3:00,0,1,0,0,0,0,0,0,0,0,0,0,0 -3,3/8/1996 2:22,0,0,0,0,0,0,0,0,1,0,1,0,0 -3,3/8/1996 2:24,1,0,0,0,0,0,0,0,0,0,0,0,0 -3,3/8/1996 2:37,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 8:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -3,3/9/1996 11:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 19:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 22:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -3,3/11/1996 21:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -3,3/12/1996 0:00,0,1,0,0,0,0,0,0,0,0,0,0,0 +subject_id,timestamp,male,female,admission,death,discharge,lab,spo2,normal_spo2,abnormally_low_spo2,abnormally_high_spo2,procedure_start,procedure_end,ventilation,diagnosis_ICD9CM_41071,diagnosis_ICD10CM_I214 +1,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/1/1989 12:03,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/1/1989 13:14,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 15:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 16:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 20:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/2/1989 3:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/2/1989 9:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/2/1989 10:00,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +1,12/2/1989 14:22,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +1,12/2/1989 15:00,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +1,1/21/1991 11:59,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0 +1,1/27/1991 23:32,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,1/27/1991 23:46,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/28/1991 3:18,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/28/1991 3:28,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +1,1/28/1991 4:36,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/29/1991 23:32,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/30/1991 5:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/30/1991 8:00,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0 +1,1/30/1991 11:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/30/1991 14:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,1/30/1991 14:15,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/31/1991 1:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +1,1/31/1991 2:15,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +1,2/8/1991 8:15,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,3/3/1991 19:33,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,3/3/1991 20:33,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0 +1,3/3/1991 21:38,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 +2,,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 2:24,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 2:35,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,3/8/1996 4:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,3/8/1996 10:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 16:00,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +2,6/5/1996 0:32,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +2,6/5/1996 0:48,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 +2,6/5/1996 1:59,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +2,6/7/1996 6:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +2,6/7/1996 9:00,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0 +2,6/7/1996 12:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,6/8/1996 3:00,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 +3,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/8/1996 2:22,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +3,3/8/1996 2:24,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/8/1996 
2:37,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 8:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +3,3/9/1996 11:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 19:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 22:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/11/1996 21:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +3,3/12/1996 0:00,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 diff --git a/src/aces/config.py b/src/aces/config.py index f1af2f9..39f71f0 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -33,6 +33,7 @@ class PlainPredicateConfig: value_max: float | None = None value_min_inclusive: bool | None = None value_max_inclusive: bool | None = None + static: bool = False def MEDS_eval_expr(self) -> pl.Expr: """Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset. @@ -43,31 +44,35 @@ def MEDS_eval_expr(self) -> pl.Expr: Examples: >>> expr = PlainPredicateConfig("BP//systolic", 120, 140, True, False).MEDS_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE - [(col("code")) == (String(BP//systolic))].all_horizontal([[(col("value")) >= - (dyn int: 120)], [(col("value")) < (dyn int: 140)]]) + [(col("code")) == (String(BP//systolic))].all_horizontal([[(col("numerical_value")) >= + (dyn int: 120)], [(col("numerical_value")) < (dyn int: 140)]]) >>> cfg = PlainPredicateConfig("BP//systolic", value_min=120, value_min_inclusive=False) >>> expr = cfg.MEDS_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE - [(col("code")) == (String(BP//systolic))].all_horizontal([[(col("value")) > + [(col("code")) == (String(BP//systolic))].all_horizontal([[(col("numerical_value")) > (dyn int: 120)]]) + >>> cfg = PlainPredicateConfig("BP//systolic", value_max=140, value_max_inclusive=True) + >>> expr = cfg.MEDS_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("code")) == (String(BP//systolic))].all_horizontal([[(col("numerical_value")) <= + (dyn int: 140)]]) >>> cfg = PlainPredicateConfig("BP//diastolic") >>> expr = cfg.MEDS_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("code")) == (String(BP//diastolic))] """ - criteria = [pl.col("code") == self.code] if self.value_min is not None: if self.value_min_inclusive: - criteria.append(pl.col("value") >= self.value_min) + criteria.append(pl.col("numerical_value") >= self.value_min) else: - criteria.append(pl.col("value") > self.value_min) + criteria.append(pl.col("numerical_value") > self.value_min) if self.value_max is not None: if self.value_max_inclusive: - criteria.append(pl.col("value") <= self.value_max) + criteria.append(pl.col("numerical_value") <= self.value_max) else: - criteria.append(pl.col("value") < self.value_max) + criteria.append(pl.col("numerical_value") < self.value_max) if len(criteria) == 1: return criteria[0] @@ -95,6 +100,11 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("BP")) == (String(systolic))].all_horizontal([[(col("blood_pressure_value")) > (dyn int: 120)]]) + >>> cfg = PlainPredicateConfig("BP//systolic", value_max=140, value_max_inclusive=True) + >>> expr = cfg.ESGPT_eval_expr("blood_pressure_value") + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("BP")) == (String(systolic))].all_horizontal([[(col("blood_pressure_value")) <= + (dyn int: 140)]]) >>> expr = PlainPredicateConfig("BP//diastolic").ESGPT_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("BP")) == (String(diastolic))] @@ -104,6 +114,20 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: >>> expr = 
PlainPredicateConfig("BP//diastolic//atrial").ESGPT_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("BP")) == (String(diastolic//atrial))] + >>> expr = PlainPredicateConfig("BP//diastolic", None, None).ESGPT_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("BP")) == (String(diastolic))] + >>> expr = PlainPredicateConfig("BP").ESGPT_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + col("BP").is_not_null() + >>> expr = PlainPredicateConfig("BP//systolic", value_min=120).ESGPT_eval_expr() + Traceback (most recent call last): + ... + ValueError: Must specify a values column for ESGPT predicates with a value_min = 120 + >>> expr = PlainPredicateConfig("BP//systolic", value_max=140).ESGPT_eval_expr() + Traceback (most recent call last): + ... + ValueError: Must specify a values column for ESGPT predicates with a value_max = 140 """ code_is_in_parts = "//" in self.code @@ -119,7 +143,7 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: else: criteria = [pl.col(measurement_name) == code] elif (self.value_min is None) and (self.value_max is None): - return pl.col(code).is_not_null() + return pl.col(self.code).is_not_null() else: values_column = self.code criteria = [] @@ -176,9 +200,14 @@ class DerivedPredicateConfig: Traceback (most recent call last): ... ValueError: Derived predicate expression must start with 'and(' or 'or('. Got: 'PA + PB' + >>> pred = DerivedPredicateConfig("") + Traceback (most recent call last): + ... + ValueError: Derived predicates must have a non-empty expression field. """ expr: str + static: bool = False def __post_init__(self): if not self.expr: @@ -348,6 +377,23 @@ class WindowConfig: offset=datetime.timedelta(0)) >>> gap_window.root_node 'start' + >>> gap_window = WindowConfig( + ... start="input.end", + ... end="start + 0h", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={"discharge": "(None, 0)", "death": "(None, 0)"} + ... ) + >>> gap_window.referenced_event + ('input', 'end') + >>> sorted(gap_window.referenced_predicates) + ['death', 'discharge'] + >>> gap_window.start_endpoint_expr is None + True + >>> gap_window.end_endpoint_expr is None # doctest: +NORMALIZE_WHITESPACE + True + >>> gap_window.root_node + 'start' >>> target_window = WindowConfig( ... start="gap.end", ... end="start -> discharge_or_death", @@ -368,6 +414,26 @@ class WindowConfig: offset=datetime.timedelta(0)) >>> target_window.root_node 'start' + >>> target_window = WindowConfig( + ... start="end", + ... end="gap.end <- discharge_or_death", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={} + ... ) + >>> target_window.referenced_event + ('gap', 'end') + >>> sorted(target_window.referenced_predicates) + ['discharge_or_death'] + >>> target_window.start_endpoint_expr is None + True + >>> target_window.end_endpoint_expr # doctest: +NORMALIZE_WHITESPACE + ToEventWindowBounds(left_inclusive=False, + end_event='-discharge_or_death', + right_inclusive=False, + offset=datetime.timedelta(0)) + >>> target_window.root_node + 'end' >>> invalid_window = WindowConfig( ... start="gap.end gap.start", ... end="start -> discharge_or_death", @@ -382,6 +448,30 @@ class WindowConfig: '.start' or '.end'. Got: 'gap.end gap.start' >>> invalid_window = WindowConfig( + ... start="input", + ... end="start window -> discharge_or_death", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={"discharge": "(None, 0)", "death": "(None, 0)"} + ... 
) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Window boundary reference must be either a valid alphanumeric/'_' string or a reference + to another window's start or end event, formatted as a valid alphanumeric/'_' string, followed by + '.start' or '.end'. Got: 'start window' + >>> invalid_window = WindowConfig( + ... start="input", + ... end="window.foo -> discharge_or_death", + ... start_inclusive=False, + ... end_inclusive=True, + ... has={"discharge": "(None, 0)", "death": "(None, 0)"} + ... ) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Window boundary reference must be either a valid alphanumeric/'_' string or a reference + to another window's start or end event, formatted as a valid alphanumeric/'_' string, followed by + '.start' or '.end'. Got: 'window.foo' + >>> invalid_window = WindowConfig( ... start=None, end=None, start_inclusive=True, end_inclusive=True, has={} ... ) Traceback (most recent call last): @@ -479,7 +569,6 @@ class WindowConfig: @classmethod def _check_reference(cls, reference: str): """Checks to ensure referenced events are valid.""" - err_str = ( "Window boundary reference must be either a valid alphanumeric/'_' string " "or a reference to another window's start or end event, formatted as a valid " @@ -770,24 +859,27 @@ class TaskExtractorConfig: value_min=None, value_max=None, value_min_inclusive=None, - value_max_inclusive=None), + value_max_inclusive=None, + static=False), 'discharge': PlainPredicateConfig(code='discharge', value_min=None, value_max=None, value_min_inclusive=None, - value_max_inclusive=None), + value_max_inclusive=None, + static=False), 'death': PlainPredicateConfig(code='death', value_min=None, value_max=None, value_min_inclusive=None, - value_max_inclusive=None)} + value_max_inclusive=None, + static=False)} >>> print(config.label_window) # doctest: +NORMALIZE_WHITESPACE target >>> print(config.index_timestamp_window) # doctest: +NORMALIZE_WHITESPACE input >>> print(config.derived_predicates) # doctest: +NORMALIZE_WHITESPACE - {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)')} + {'death_or_discharge': DerivedPredicateConfig(expr='or(death, discharge)', static=False)} >>> print(nx.write_network_text(config.predicates_DAG)) ╟── death ╎ └─╼ death_or_discharge ╾ discharge @@ -799,16 +891,134 @@ class TaskExtractorConfig: ├── input.start └── gap.end └── target.end + + >>> config_path = "/foo/non_existent_file.yaml" + >>> cfg = TaskExtractorConfig.load(config_path) + Traceback (most recent call last): + ... + FileNotFoundError: Cannot load missing configuration file /foo/non_existent_file.yaml! + + >>> import tempfile + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f: + ... config_path = Path(f.name) + ... cfg = TaskExtractorConfig.load(config_path) + Traceback (most recent call last): + ... + ValueError: Only supports reading from '.yaml'. Got: '.txt' in ....txt'. + >>> predicates_path = "/foo/non_existent_predicates.yaml" + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... cfg = TaskExtractorConfig.load(config_path, predicates_path) + Traceback (most recent call last): + ... + FileNotFoundError: Cannot load missing predicates file /foo/non_existent_predicates.yaml! + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as f: + ... predicates_path = Path(f.name) + ... with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f2: + ... 
config_path = Path(f2.name) + ... cfg = TaskExtractorConfig.load(config_path, predicates_path) + Traceback (most recent call last): + ... + ValueError: Only supports reading from '.yaml'. Got: '.txt' in ....txt'. + >>> import yaml + >>> data = { + ... 'predicates': {}, + ... 'trigger': {}, + ... 'foo': {} + ... } + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... config_path = Path(f.name) + ... yaml.dump(data, f) + ... cfg = TaskExtractorConfig.load(config_path) + Traceback (most recent call last): + ... + ValueError: Unrecognized keys in configuration file: 'foo' + >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: + ... predicates_path = Path(f.name) + ... yaml.dump(data, f) + ... with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f2: + ... config_path = Path(f2.name) + ... cfg = TaskExtractorConfig.load(config_path, predicates_path) + Traceback (most recent call last): + ... + ValueError: Unrecognized keys in configuration file: 'foo, trigger' + + >>> predicates = {"foo bar": PlainPredicateConfig("foo")} + >>> trigger = EventConfig("foo") + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={}) + Traceback (most recent call last): + ... + ValueError: Predicate name 'foo bar' is invalid; must be composed of alphanumeric or '_' characters. + >>> predicates = {"foo": str("foo")} + >>> trigger = EventConfig("foo") + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={}) + ... # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Invalid predicate configuration for 'foo': foo. Must be either a PlainPredicateConfig or + DerivedPredicateConfig object. Got: + >>> predicates = { + ... "foo": PlainPredicateConfig("foo"), + ... "foobar": DerivedPredicateConfig("or(foo, bar)"), + ... } + >>> trigger = EventConfig("foo") + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={}) + Traceback (most recent call last): + ... + KeyError: "Missing 1 relationships:\\nDerived predicate 'foobar' references undefined predicate 'bar'" + + >>> predicates = {"foo": PlainPredicateConfig("foo")} + >>> trigger = EventConfig("foo") + >>> windows = {"foo bar": WindowConfig("gap.end", "start + 24h", True, True)} + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + Traceback (most recent call last): + ... + ValueError: Window name 'foo bar' is invalid; must be composed of alphanumeric or '_' characters. + >>> windows = {"foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "bar")} + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + Traceback (most recent call last): + ... + ValueError: Label must be one of the defined predicates. Got: bar for window 'foo' + >>> windows = {"foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo", "bar")} + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + Traceback (most recent call last): + ... + ValueError: Index timestamp must be either 'start' or 'end'. Got: bar for window 'foo' + >>> windows = { + ... "foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo"), + ... "bar": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo") + ... } + >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows) + Traceback (most recent call last): + ... 
+        ValueError: Only one window can be labeled, found 2 labeled windows: foo, bar
+        >>> windows = {
+        ...     "foo": WindowConfig("gap.end", "start + 24h", True, True, {}, "foo", "start"),
+        ...     "bar": WindowConfig("gap.end", "start + 24h", True, True, {}, index_timestamp="start")
+        ... }
+        >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows=windows)
+        ... # doctest: +NORMALIZE_WHITESPACE
+        Traceback (most recent call last):
+            ...
+        ValueError: Only the 'start'/'end' of one window can be used as the index timestamp, found
+            2 windows with index_timestamp: foo, bar
+
+        >>> predicates = {"foo": PlainPredicateConfig("foo")}
+        >>> trigger = EventConfig("bar")
+        >>> config = TaskExtractorConfig(predicates=predicates, trigger=trigger, windows={})
+        Traceback (most recent call last):
+            ...
+        KeyError: "Trigger event predicate 'bar' not found in predicates: foo"
     """

     predicates: dict[str, PlainPredicateConfig | DerivedPredicateConfig]
     trigger: EventConfig
-    windows: dict[str, WindowConfig]
+    windows: dict[str, WindowConfig] | None
     label_window: str | None = None
     index_timestamp_window: str | None = None

     @classmethod
-    def load(cls, config_path: str | Path) -> TaskExtractorConfig:
+    def load(cls, config_path: str | Path, predicates_path: str | Path | None = None) -> TaskExtractorConfig:
         """Load a configuration file from the given path and return it as a TaskExtractorConfig.

         Args:
             config_path: The path to which a configuration object will be read from in YAML form.
             predicates_path: An optional path to a separate YAML file from which predicates (and
                 patient demographics) will be read instead of the main configuration file.

@@ -829,12 +1039,43 @@ def load(cls, config_path: str | Path) -> TaskExtractorConfig:
         if config_path.suffix == ".yaml":
             yaml = ruamel.yaml.YAML(typ="safe", pure=True)
             loaded_dict = yaml.load(config_path.read_text())
         else:
             raise ValueError(
-                f"Only supports reading from '.yaml' files currently. Got: '{config_path.suffix}'"
+                f"Only supports reading from '.yaml'. Got: '{config_path.suffix}' in '{config_path.name}'."
             )

-        predicates = loaded_dict.pop("predicates")
+        if predicates_path:
+            if isinstance(predicates_path, str):
+                predicates_path = Path(predicates_path)
+
+            if not predicates_path.is_file():
+                raise FileNotFoundError(
+                    f"Cannot load missing predicates file {str(predicates_path.resolve())}!"
+                )
+
+            if predicates_path.suffix == ".yaml":
+                yaml = ruamel.yaml.YAML(typ="safe", pure=True)
+                predicates_dict = yaml.load(predicates_path.read_text())
+            else:
+                raise ValueError(
+                    f"Only supports reading from '.yaml'. Got: '{predicates_path.suffix}' in "
+                    f"'{predicates_path.name}'."
+                )
+
+            predicates = predicates_dict.pop("predicates")
+            patient_demographics = predicates_dict.pop("patient_demographics", None)
+
+            # Remove the description if it exists - currently unused except for readability in the YAML
+            _ = predicates_dict.pop("description", None)
+
+            if predicates_dict:
+                raise ValueError(
+                    f"Unrecognized keys in configuration file: '{', '.join(predicates_dict.keys())}'"
+                )
+        else:
+            predicates = loaded_dict.pop("predicates")
+            patient_demographics = loaded_dict.pop("patient_demographics", None)
+
         trigger = loaded_dict.pop("trigger")
-        windows = loaded_dict.pop("windows")
+        windows = loaded_dict.pop("windows", None)

         # Remove the description if it exists - currently unused except for readability in the YAML
         _ = loaded_dict.pop("description", None)
@@ -848,11 +1089,24 @@ def load(cls, config_path: str | Path) -> TaskExtractorConfig:
             for n, p in predicates.items()
         }

+        if patient_demographics:
+            logger.info("Parsing patient demographics...")
+            patient_demographics = {
+                n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items()
+            }
+            predicates.update(patient_demographics)
+
         logger.info("Parsing trigger event...")
         trigger = EventConfig(trigger)

         logger.info("Parsing windows...")
-        windows = {n: WindowConfig(**w) for n, w in windows.items()}
+        if windows is None:
+            windows = {}
+            logger.warning(
+                "No windows specified in configuration file. Extracting only matching trigger events."
+            )
+        else:
+            windows = {n: WindowConfig(**w) for n, w in windows.items()}

         return cls(predicates=predicates, trigger=trigger, windows=windows)
diff --git a/src/aces/constraints.py b/src/aces/constraints.py
index c1e7f69..0f09bb2 100644
--- a/src/aces/constraints.py
+++ b/src/aces/constraints.py
@@ -98,3 +98,70 @@ def check_constraints(
             should_drop = should_drop | drop_expr

     return summary_df.filter(~should_drop)
+
+
+def check_static_variables(patient_demographics: list[str], predicates_df: pl.DataFrame) -> pl.DataFrame:
+    """Filters the predicates dataframe to only subjects who satisfy the given static demographics.
+
+    Args:
+        patient_demographics: List of columns representing static patient demographics.
+        predicates_df: Dataframe containing a row for each event with patient demographics and timestamps.
+
+    Returns: A filtered dataframe containing only the rows that satisfy the patient demographics.
+
+    Raises:
+        ValueError: If a specified static predicate is not a column in the predicates dataframe.
+
+    Examples:
+        >>> from datetime import datetime
+        >>> predicates_df = pl.DataFrame({
+        ...     "subject_id": [1, 1, 1, 1, 1, 2, 2, 2],
+        ...     "timestamp": [
+        ...         # Subject 1
+        ...         None,
+        ...         datetime(year=1989, month=12, day=1, hour=12, minute=3),
+        ...         datetime(year=1989, month=12, day=2, hour=5, minute=17),
+        ...         datetime(year=1989, month=12, day=2, hour=12, minute=3),
+        ...         datetime(year=1989, month=12, day=6, hour=11, minute=0),
+        ...         # Subject 2
+        ...         None,
+        ...         datetime(year=1989, month=12, day=1, hour=13, minute=14),
+        ...         datetime(year=1989, month=12, day=3, hour=15, minute=17),
+        ...     ],
+        ...     "is_A": [0, 1, 4, 1, 0, 3, 3, 3],
+        ...     "is_B": [0, 0, 2, 0, 0, 2, 10, 2],
+        ...     "is_C": [0, 1, 1, 1, 0, 0, 1, 1],
+        ...     "male": [1, 0, 0, 0, 0, 0, 0, 0]
+        ...
}) + + >>> check_static_variables(['male'], predicates_df) + shape: (4, 5) + ┌────────────┬─────────────────────┬──────┬──────┬──────┐ + │ subject_id ┆ timestamp ┆ is_A ┆ is_B ┆ is_C │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════╪══════╪══════╡ + │ 1 ┆ 1989-12-01 12:03:00 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 1989-12-02 05:17:00 ┆ 4 ┆ 2 ┆ 1 │ + │ 1 ┆ 1989-12-02 12:03:00 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 1989-12-06 11:00:00 ┆ 0 ┆ 0 ┆ 0 │ + └────────────┴─────────────────────┴──────┴──────┴──────┘ + """ + for demographic in patient_demographics: + if demographic not in predicates_df.columns: + raise ValueError(f"Static predicate '{demographic}' not found in the predicates dataframe.") + + keep_expr = ((pl.col("timestamp").is_null()) & (pl.col(demographic) == 1)).alias("keep_expr") + + exclude_expr = ~keep_expr + exclude_count = predicates_df.filter(exclude_expr).shape[0] + + logger.info(f"Excluding {exclude_count:,} rows due to the '{demographic}' criteria.") + + predicates_df = predicates_df.filter( + pl.col("subject_id").is_in(predicates_df.filter(keep_expr).select("subject_id").unique()) + ) + + return predicates_df.drop_nulls(subset=["timestamp"]).drop( + *[x for x in patient_demographics if x in predicates_df.columns] + ) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index bac82d2..862b678 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -35,25 +35,31 @@ def direct_load_plain_predicates( Example: >>> import tempfile >>> CSV_data = pl.DataFrame({ - ... "subject_id": [1, 1, 2], - ... "timestamp": ["01/01/2021 00:00", "01/01/2021 12:00", "01/02/2021 00:00"], - ... "is_admission": [1, 0, 1], - ... "is_discharge": [0, 1, 0], + ... "subject_id": [1, 1, 1, 1, 2, 2], + ... "timestamp": [None, "01/01/2021 00:00", None, "01/01/2021 12:00", "01/02/2021 00:00", None], + ... "is_admission": [0, 1, 0, 0, 1, 0], + ... "is_discharge": [0, 0, 0, 1, 0, 0], + ... "is_male": [1, 0, 0, 0, 0, 0], + ... "is_female": [0, 0, 0, 0, 0, 1], + ... "brown_eyes": [0, 0, 1, 0, 0, 0], ... }) >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... CSV_data.write_parquet(data_path) - ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge"], "%m/%d/%Y %H:%M") - shape: (3, 4) - ┌────────────┬─────────────────────┬──────────────┬──────────────┐ - │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╪══════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┴──────────────┘ + ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge", "is_male", + ... 
"is_female", "brown_eyes"], "%m/%d/%Y %H:%M") + shape: (5, 7) + ┌────────────┬─────────────────────┬──────────────┬──────────────┬─────────┬───────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge ┆ is_male ┆ is_female ┆ brown_eyes │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════════════╪══════════════╪═════════╪═══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + └────────────┴─────────────────────┴──────────────┴──────────────┴─────────┴───────────┴────────────┘ If the timestamp column is already a timestamp, then the `ts_format` argument id not needed, but can be used without an error. @@ -64,17 +70,20 @@ def direct_load_plain_predicates( ... .with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")) ... .write_parquet(data_path) ... ) - ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge"], "%m/%d/%Y %H:%M") - shape: (3, 4) - ┌────────────┬─────────────────────┬──────────────┬──────────────┐ - │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╪══════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┴──────────────┘ + ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge", "is_male", + ... "is_female", "brown_eyes"], "%m/%d/%Y %H:%M") + shape: (5, 7) + ┌────────────┬─────────────────────┬──────────────┬──────────────┬─────────┬───────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge ┆ is_male ┆ is_female ┆ brown_eyes │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════════════╪══════════════╪═════════╪═══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + └────────────┴─────────────────────┴──────────────┴──────────────┴─────────┴───────────┴────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... ( @@ -82,45 +91,53 @@ def direct_load_plain_predicates( ... .with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")) ... .write_parquet(data_path) ... ) - ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge"], None) - shape: (3, 4) - ┌────────────┬─────────────────────┬──────────────┬──────────────┐ - │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╪══════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┴──────────────┘ + ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge", "is_male", + ... 
"is_female", "brown_eyes"], None) + shape: (5, 7) + ┌────────────┬─────────────────────┬──────────────┬──────────────┬─────────┬───────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge ┆ is_male ┆ is_female ┆ brown_eyes │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════════════╪══════════════╪═════════╪═══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + └────────────┴─────────────────────┴──────────────┴──────────────┴─────────┴───────────┴────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... CSV_data.write_csv(data_path) - ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge"], "%m/%d/%Y %H:%M") - shape: (3, 4) - ┌────────────┬─────────────────────┬──────────────┬──────────────┐ - │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge │ - │ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╪══════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┴──────────────┘ + ... direct_load_plain_predicates(data_path, ["is_admission", "is_discharge", "is_male", + ... "is_female", "brown_eyes"], "%m/%d/%Y %H:%M") + shape: (5, 7) + ┌────────────┬─────────────────────┬──────────────┬──────────────┬─────────┬───────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge ┆ is_male ┆ is_female ┆ brown_eyes │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════════════╪══════════════╪═════════╪═══════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ + └────────────┴─────────────────────┴──────────────┴──────────────┴─────────┴───────────┴────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... CSV_data.write_csv(data_path) - ... direct_load_plain_predicates(data_path, ["is_discharge"], "%m/%d/%Y %H:%M") - shape: (3, 3) - ┌────────────┬─────────────────────┬──────────────┐ - │ subject_id ┆ timestamp ┆ is_discharge │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┘ + ... 
direct_load_plain_predicates(data_path, ["is_discharge", "brown_eyes"], "%m/%d/%Y %H:%M") + shape: (5, 4) + ┌────────────┬─────────────────────┬──────────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_discharge ┆ brown_eyes │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪══════════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 1 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 0 ┆ 0 │ + │ 2 ┆ null ┆ 0 ┆ 0 │ + └────────────┴─────────────────────┴──────────────┴────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... CSV_data.write_csv(data_path) @@ -183,8 +200,7 @@ def direct_load_plain_predicates( if missing_columns: raise pl.ColumnNotFoundError(missing_columns) - data = data.select(columns).drop_nulls(subset=["subject_id", "timestamp"]) - + data = data.select(columns) ts_type = data.schema["timestamp"] if ts_type == pl.Utf8: if ts_format is None: @@ -276,33 +292,31 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl >>> parquet_data = pl.DataFrame({ ... "patient_id": [1, 1, 1, 2, 3], ... "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None], - ... "code": ['admission', 'discharge', 'discharge', 'admission', "gender"], + ... "code": ['admission', 'discharge', 'discharge', 'admission', "gender//male"], ... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M")) >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... parquet_data.write_parquet(data_path) ... generate_plain_predicates_from_meds( ... data_path, - ... {"discharge": PlainPredicateConfig("discharge")} + ... {"discharge": PlainPredicateConfig("discharge"), + ... "male": PlainPredicateConfig("gender//male", static=True)} ... ) - shape: (3, 3) - ┌────────────┬─────────────────────┬───────────┐ - │ subject_id ┆ timestamp ┆ discharge │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 │ - ╞════════════╪═════════════════════╪═══════════╡ - │ 1 ┆ 1989-01-01 00:00:00 ┆ 0 │ - │ 1 ┆ 1989-01-01 01:00:00 ┆ 2 │ - │ 2 ┆ 1989-01-01 02:00:00 ┆ 0 │ - └────────────┴─────────────────────┴───────────┘ + shape: (4, 4) + ┌────────────┬─────────────────────┬───────────┬──────┐ + │ subject_id ┆ timestamp ┆ discharge ┆ male │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═══════════╪══════╡ + │ 1 ┆ 1989-01-01 00:00:00 ┆ 0 ┆ 0 │ + │ 1 ┆ 1989-01-01 01:00:00 ┆ 2 ┆ 0 │ + │ 2 ┆ 1989-01-01 02:00:00 ┆ 0 ┆ 0 │ + │ 3 ┆ null ┆ 0 ┆ 1 │ + └────────────┴─────────────────────┴───────────┴──────┘ """ logger.info("Loading MEDS data...") - data = ( - pl.read_parquet(data_path) - .rename({"patient_id": "subject_id"}) - .drop_nulls(subset=["subject_id", "timestamp"]) - ) + data = pl.read_parquet(data_path).rename({"patient_id": "subject_id"}) if data.columns == ["subject_id", "events"]: data = unnest_meds(data) @@ -324,6 +338,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl def process_esgpt_data( + subjects_df: pl.DataFrame, events_df: pl.DataFrame, dynamic_measurements_df: pl.DataFrame, value_columns: dict[str, str], @@ -342,6 +357,12 @@ def process_esgpt_data( Examples: >>> from datetime import datetime >>> from .config import PlainPredicateConfig + >>> subjects_df = pl.DataFrame({ + ... "subject_id": [1, 2], + ... "MRN": ["A123", "B456"], + ... 
"eye_colour": ["brown", "blue"], + ... "dob": [datetime(1980, 1, 1), datetime(1990, 1, 1)], + ... }) >>> events_df = pl.DataFrame({ ... "event_id": [1, 2, 3, 4], ... "subject_id": [1, 1, 2, 2], @@ -369,23 +390,26 @@ def process_esgpt_data( ... "high_Potassium": "lab_val", ... } >>> predicates = { - ... "is_admission": PlainPredicateConfig(code="event_type//adm"), - ... "is_discharge": PlainPredicateConfig(code="event_type//dis"), + ... "is_adm": PlainPredicateConfig(code="event_type//adm"), + ... "is_dis": PlainPredicateConfig(code="event_type//dis"), ... "high_HR": PlainPredicateConfig(code="HR", value_min=140), ... "high_Potassium": PlainPredicateConfig(code="lab//K", value_min=5.0), + ... "eye_colour": PlainPredicateConfig(code="eye_colour//brown", static=True), ... } - >>> process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates) - shape: (4, 6) - ┌────────────┬─────────────────────┬──────────────┬──────────────┬─────────┬────────────────┐ - │ subject_id ┆ timestamp ┆ is_admission ┆ is_discharge ┆ high_HR ┆ high_Potassium │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪══════════════╪══════════════╪═════════╪════════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 │ - └────────────┴─────────────────────┴──────────────┴──────────────┴─────────┴────────────────┘ + >>> process_esgpt_data(subjects_df, events_df, dynamic_measurements_df, value_columns, predicates) + shape: (6, 7) + ┌────────────┬─────────────────────┬────────┬────────┬─────────┬────────────────┬────────────┐ + │ subject_id ┆ timestamp ┆ is_adm ┆ is_dis ┆ high_HR ┆ high_Potassium ┆ eye_colour │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪════════╪════════╪═════════╪════════════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 0 │ + └────────────┴─────────────────────┴────────┴────────┴─────────┴────────────────┴────────────┘ """ logger.info("Generating plain predicate columns...") @@ -394,6 +418,10 @@ def process_esgpt_data( events_df = events_df.with_columns( plain_predicate.ESGPT_eval_expr().cast(PRED_CNT_TYPE).alias(name) ) + elif plain_predicate.static: + subjects_df = subjects_df.with_columns( + plain_predicate.ESGPT_eval_expr().cast(PRED_CNT_TYPE).alias(name), + ) else: values_column = value_columns[name] dynamic_measurements_df = dynamic_measurements_df.with_columns( @@ -401,6 +429,8 @@ def process_esgpt_data( ) logger.info(f"Added predicate column '{name}'.") + # clean up predicates_df + logger.info("Cleaning up predicates dataframe...") predicate_cols = list(predicates.keys()) # aggregate dynamic_measurements_df by summing predicates (counts) @@ -419,9 +449,20 @@ def process_esgpt_data( # join events_df and dynamic_measurements_df for the final predicates_df data = events_df.join(dynamic_measurements_df, on="event_id", how="left") - # clean up predicates_df - logger.info("Cleaning up predicates dataframe...") - return data.select(["subject_id", "timestamp"] + predicate_cols) + # return concatenated subjects_df and data + static_rows = 
subjects_df.select( + "subject_id", + pl.lit(None).alias("timestamp").cast(pl.Datetime), + *[pl.lit(0).alias(c).cast(PRED_CNT_TYPE) for c in predicate_cols if not predicates[c].static], + *[pl.col(c) for c in predicate_cols if predicates[c].static], + ) + data = data.select( + "subject_id", + "timestamp", + *[pl.col(c) for c in predicate_cols if not predicates[c].static], + *[pl.lit(0).alias(c).cast(PRED_CNT_TYPE) for c in predicate_cols if predicates[c].static], + ) + return pl.concat([static_rows, data]) def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> pl.DataFrame: @@ -457,6 +498,7 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p "If you mean to use a MEDS dataset, please specify the 'MEDS' standard." ) from e + subjects_df = ESD.subjects_df events_df = ESD.events_df dynamic_measurements_df = ESD.dynamic_measurements_df config = ESD.config @@ -469,7 +511,7 @@ def generate_plain_predicates_from_esgpt(data_path: Path, predicates: dict) -> p measurement_name = plain_predicate.code.split("//")[0] value_columns[name] = config.measurement_configs[measurement_name].values_column - return process_esgpt_data(events_df, dynamic_measurements_df, value_columns, predicates) + return process_esgpt_data(subjects_df, events_df, dynamic_measurements_df, value_columns, predicates) def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.DataFrame: @@ -491,16 +533,25 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D >>> import tempfile >>> from .config import PlainPredicateConfig, DerivedPredicateConfig, EventConfig, WindowConfig >>> data = pl.DataFrame({ - ... "subject_id": [1, 1, 2, 2], - ... "timestamp": ["01/01/2021 00:00", "01/01/2021 12:00", "01/02/2021 00:00", "01/02/2021 12:00"], - ... "adm": [1, 0, 1, 0], - ... "dis": [0, 1, 0, 0], - ... "death": [0, 0, 0, 1], + ... "subject_id": [1, 1, 1, 2, 2, 2], + ... "timestamp": [ + ... None, + ... "01/01/2021 00:00", + ... "01/01/2021 12:00", + ... None, + ... "01/02/2021 00:00", + ... "01/02/2021 12:00"], + ... "adm": [0, 1, 0, 0, 1, 0], + ... "dis": [0, 0, 1, 0, 0, 0], + ... "death": [0, 0, 0, 0, 0, 1], + ... "male": [1, 0, 0, 0, 0, 0], + ... "female": [0, 0, 0, 1, 0, 0], ... }) >>> predicates = { ... "adm": PlainPredicateConfig("adm"), ... "dis": PlainPredicateConfig("dis"), ... "death": PlainPredicateConfig("death"), + ... "male": PlainPredicateConfig("male", static=True), # predicate match based on name for direct ... "death_or_dis": DerivedPredicateConfig("or(death, dis)"), ... } >>> trigger = EventConfig("adm") @@ -538,17 +589,19 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D ... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M" ... }) ... 
get_predicates_df(config, data_config) - shape: (4, 7) - ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐ - │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ - └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘ + shape: (6, 8) + ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────┬──────────────┬────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ male ┆ death_or_dis ┆ _ANY_EVENT │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════╪══════════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ null │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 ┆ 1 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ null │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ + └────────────┴─────────────────────┴─────┴─────┴───────┴──────┴──────────────┴────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: ... data_path = Path(f.name) ... ( @@ -558,19 +611,21 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D ... ) ... data_config = DictConfig({"path": str(data_path), "standard": "direct", "ts_format": None}) ... get_predicates_df(config, data_config) - shape: (4, 7) - ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────────────┬────────────┐ - │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ death_or_dis ┆ _ANY_EVENT │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════════════╪════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │ - └────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘ + shape: (6, 8) + ┌────────────┬─────────────────────┬─────┬─────┬───────┬──────┬──────────────┬────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ dis ┆ death ┆ male ┆ death_or_dis ┆ _ANY_EVENT │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪═════╪═══════╪══════╪══════════════╪════════════╡ + │ 1 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ null │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 0 ┆ 1 ┆ 1 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ null │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 ┆ 1 │ + └────────────┴─────────────────────┴─────┴─────┴───────┴──────┴──────────────┴────────────┘ >>> any_event_trigger = EventConfig("_ANY_EVENT") - >>> adm_only_predicates = {"adm": PlainPredicateConfig("adm")} + >>> adm_only_predicates = {"adm": PlainPredicateConfig("adm"), "male": PlainPredicateConfig("male")} >>> st_end_windows = { ... 
"input": WindowConfig( ... start="end - 365d", @@ -593,17 +648,19 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D ... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M" ... }) ... get_predicates_df(st_end_config, data_config) - shape: (4, 6) - ┌────────────┬─────────────────────┬─────┬────────────┬───────────────┬─────────────┐ - │ subject_id ┆ timestamp ┆ adm ┆ _ANY_EVENT ┆ _RECORD_START ┆ _RECORD_END │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ - ╞════════════╪═════════════════════╪═════╪════════════╪═══════════════╪═════════════╡ - │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │ - │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ - │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │ - │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ - └────────────┴─────────────────────┴─────┴────────────┴───────────────┴─────────────┘ + shape: (6, 7) + ┌────────────┬─────────────────────┬─────┬──────┬────────────┬───────────────┬─────────────┐ + │ subject_id ┆ timestamp ┆ adm ┆ male ┆ _ANY_EVENT ┆ _RECORD_START ┆ _RECORD_END │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞════════════╪═════════════════════╪═════╪══════╪════════════╪═══════════════╪═════════════╡ + │ 1 ┆ null ┆ 0 ┆ 1 ┆ null ┆ null ┆ null │ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │ + │ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + │ 2 ┆ null ┆ 0 ┆ 0 ┆ null ┆ null ┆ null │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 1 ┆ 1 ┆ 0 │ + │ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │ + └────────────┴─────────────────────┴─────┴──────┴────────────┴───────────────┴─────────────┘ >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f: ... data_path = Path(f.name) ... 
data.write_csv(data_path) @@ -640,7 +697,7 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D logger.info(f"Added predicate column '{name}'.") predicate_cols.append(name) - data = data.sort(by=["subject_id", "timestamp"]) + data = data.sort(by=["subject_id", "timestamp"], nulls_last=False) # add special predicates: # a column of 1s representing any predicate @@ -666,7 +723,13 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D special_predicates.append(cfg.trigger.predicate) if ANY_EVENT_COLUMN in special_predicates: - data = data.with_columns(pl.lit(1).alias(ANY_EVENT_COLUMN).cast(PRED_CNT_TYPE)) + data = data.with_columns( + pl.when(pl.col("timestamp").is_not_null()) + .then(pl.lit(1)) + .otherwise(pl.lit(None)) + .alias(ANY_EVENT_COLUMN) + .cast(PRED_CNT_TYPE) + ) logger.info(f"Added predicate column '{ANY_EVENT_COLUMN}'.") predicate_cols.append(ANY_EVENT_COLUMN) if START_OF_RECORD_KEY in special_predicates: diff --git a/src/aces/query.py b/src/aces/query.py index d9c5152..b7605d6 100644 --- a/src/aces/query.py +++ b/src/aces/query.py @@ -9,7 +9,7 @@ from loguru import logger from .config import TaskExtractorConfig -from .constraints import check_constraints +from .constraints import check_constraints, check_static_variables from .extract_subtree import extract_subtree from .utils import log_tree @@ -56,6 +56,15 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame log_tree(cfg.window_tree) logger.info("Beginning query...") + + static_variables = [pred for pred in cfg.predicates if cfg.predicates[pred].static] + if static_variables: + logger.info("Static variable criteria specified, filtering patient demographics...") + predicates_df = check_static_variables(static_variables, predicates_df) + else: + logger.info("No static variable criteria specified, removing all rows with null timestamps...") + predicates_df = predicates_df.drop_nulls(subset=["subject_id", "timestamp"]) + logger.info("Identifying possible trigger nodes based on the specified trigger event...") prospective_root_anchors = check_constraints({cfg.trigger.predicate: (1, None)}, predicates_df).select( "subject_id", pl.col("timestamp").alias("subtree_anchor_timestamp") diff --git a/tests/test_check_static_variables.py b/tests/test_check_static_variables.py new file mode 100644 index 0000000..fbedc19 --- /dev/null +++ b/tests/test_check_static_variables.py @@ -0,0 +1,55 @@ +from datetime import datetime + +import polars as pl +import pytest + +from src.aces.constraints import check_static_variables + + +def test_check_static_variables(): + # Create a sample DataFrame + predicates_df = pl.DataFrame( + { + "subject_id": [1, 1, 1, 1, 1, 2, 2, 2], + "timestamp": [ + None, + datetime(year=1989, month=12, day=1, hour=12, minute=3), + datetime(year=1989, month=12, day=2, hour=5, minute=17), + datetime(year=1989, month=12, day=2, hour=12, minute=3), + datetime(year=1989, month=12, day=6, hour=11, minute=0), + None, + datetime(year=1989, month=12, day=1, hour=13, minute=14), + datetime(year=1989, month=12, day=3, hour=15, minute=17), + ], + "is_A": [0, 1, 4, 1, 0, 3, 3, 3], + "is_B": [0, 0, 2, 0, 0, 2, 10, 2], + "is_C": [0, 1, 1, 1, 0, 0, 1, 1], + "male": [1, 0, 0, 0, 0, 0, 0, 0], + } + ) + + # Test filtering based on 'male' demographic + filtered_df = check_static_variables(["male"], predicates_df) + expected_df = pl.DataFrame( + { + "subject_id": [1, 1, 1, 1], + "timestamp": [ + datetime(year=1989, month=12, day=1, hour=12, minute=3), + 
datetime(year=1989, month=12, day=2, hour=5, minute=17), + datetime(year=1989, month=12, day=2, hour=12, minute=3), + datetime(year=1989, month=12, day=6, hour=11, minute=0), + ], + "is_A": [1, 4, 1, 0], + "is_B": [0, 2, 0, 0], + "is_C": [1, 1, 1, 0], + } + ) + assert filtered_df.frame_equal(expected_df) + + # Test ValueError when demographic column is missing + with pytest.raises(ValueError, match="Static predicate 'female' not found in the predicates dataframe."): + check_static_variables(["female"], predicates_df) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 41ce398..a301193 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -16,62 +16,68 @@ TS_FORMAT = "%m/%d/%Y %H:%M" PRED_CNT_TYPE = pl.Int64 +EVENT_INDEX_TYPE = pl.UInt64 ANY_EVENT_COLUMN = "_ANY_EVENT" +LAST_EVENT_INDEX_COLUMN = "_LAST_EVENT_INDEX" + # Data (input) PREDICATES_CSV = """ -subject_id,timestamp,admission,death,discharge,lab,spo2,normal_spo2,abnormally_low_spo2,abnormally_high_spo2,procedure_start,procedure_end,ventilation,diagnosis_ICD9CM/41071,diagnosis_ICD10CM/I214 -1,12/1/1989 12:03,1,0,0,0,0,0,0,0,0,0,0,0,0 -1,12/1/1989 13:14,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 15:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 16:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/1/1989 20:17,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/2/1989 3:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,12/2/1989 9:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -1,12/2/1989 10:00,0,0,0,0,0,0,0,0,1,0,1,0,0 -1,12/2/1989 14:22,0,0,0,0,0,0,0,0,0,1,1,0,0 -1,12/2/1989 15:00,0,0,1,0,0,0,0,0,0,0,0,0,0 -1,1/21/1991 11:59,0,0,0,0,0,0,0,0,0,0,0,1,0 -1,1/27/1991 23:32,1,0,0,0,0,0,0,0,0,0,0,0,0 -1,1/27/1991 23:46,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/28/1991 3:18,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/28/1991 3:28,0,0,0,0,0,0,0,0,1,0,1,0,0 -1,1/28/1991 4:36,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/29/1991 23:32,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/30/1991 5:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/30/1991 8:00,0,0,0,1,1,0,0,1,0,0,0,0,0 -1,1/30/1991 11:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,1/30/1991 14:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -1,1/30/1991 14:15,0,0,0,1,0,0,0,0,0,0,0,0,0 -1,1/31/1991 1:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -1,1/31/1991 2:15,0,0,1,0,0,0,0,0,0,0,0,0,0 -1,2/8/1991 8:15,0,0,0,1,1,1,0,0,0,0,0,0,0 -1,3/3/1991 19:33,1,0,0,0,0,0,0,0,0,0,0,0,0 -1,3/3/1991 20:33,0,0,0,1,1,0,1,0,0,0,0,0,0 -1,3/3/1991 21:38,0,1,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 2:24,1,0,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 2:35,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,3/8/1996 4:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,3/8/1996 10:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -2,3/8/1996 16:00,0,0,1,0,0,0,0,0,0,0,0,0,0 -2,6/5/1996 0:32,1,0,0,0,0,0,0,0,0,0,0,0,0 -2,6/5/1996 0:48,0,0,0,0,0,0,0,0,0,0,0,0,1 -2,6/5/1996 1:59,0,0,0,0,0,0,0,0,1,0,1,0,0 -2,6/7/1996 6:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -2,6/7/1996 9:00,0,0,0,1,1,0,1,0,0,0,0,0,0 -2,6/7/1996 12:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -2,6/8/1996 3:00,0,1,0,0,0,0,0,0,0,0,0,0,0 -3,3/8/1996 2:22,0,0,0,0,0,0,0,0,1,0,1,0,0 -3,3/8/1996 2:24,1,0,0,0,0,0,0,0,0,0,0,0,0 -3,3/8/1996 2:37,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 8:00,0,0,0,1,0,0,0,0,0,0,0,0,0 -3,3/9/1996 11:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 19:00,0,0,0,1,1,1,0,0,0,0,0,0,0 -3,3/9/1996 22:00,0,0,0,0,0,0,0,0,0,0,0,0,0 -3,3/11/1996 21:00,0,0,0,0,0,0,0,0,0,1,1,0,0 -3,3/12/1996 0:00,0,1,0,0,0,0,0,0,0,0,0,0,0 
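The rewritten fixture below adds `male`/`female` predicate columns and gives each subject one null-timestamp row to carry them, mirroring the static-predicate convention these patches introduce. As a minimal sketch of that convention in isolation — using a hypothetical `filter_by_static` helper in place of the real `check_static_variables`, whose exact behavior is defined by the patch above, not by this code:

    import polars as pl

    def filter_by_static(static_cols: list[str], df: pl.DataFrame) -> pl.DataFrame:
        """Keep subjects whose null-timestamp row satisfies every static predicate."""
        for col in static_cols:
            if col not in df.columns:
                raise ValueError(f"Static predicate '{col}' not found in the predicates dataframe.")
            # Subjects whose static (null-timestamp) row has this predicate set.
            keep = (
                df.filter(pl.col("timestamp").is_null() & (pl.col(col) == 1))
                .get_column("subject_id")
                .unique()
            )
            df = df.filter(pl.col("subject_id").is_in(keep))
        # Once subjects are filtered, the static rows and columns are no longer needed.
        return df.filter(pl.col("timestamp").is_not_null()).drop(static_cols)

    demo = pl.DataFrame({
        "subject_id": [1, 1, 2, 2],
        "timestamp": [None, "12/1/1989 12:03", None, "3/8/1996 2:24"],
        "admission": [0, 1, 0, 1],
        "male": [1, 0, 0, 0],
    })
    print(filter_by_static(["male"], demo))  # only subject 1's timestamped rows remain

Storing demographics as a single timestamp-less row per subject lets one predicates frame serve both row-level windowing and subject-level filtering, which is why the new CSV below interleaves the static rows with the events.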
+subject_id,timestamp,male,female,admission,death,discharge,lab,spo2,normal_spo2,abnormally_low_spo2,abnormally_high_spo2,procedure_start,procedure_end,ventilation,diagnosis_ICD9CM_41071,diagnosis_ICD10CM_I214 +1,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/1/1989 12:03,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/1/1989 13:14,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 15:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 16:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/1/1989 20:17,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/2/1989 3:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,12/2/1989 9:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,12/2/1989 10:00,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +1,12/2/1989 14:22,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +1,12/2/1989 15:00,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +1,1/21/1991 11:59,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0 +1,1/27/1991 23:32,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,1/27/1991 23:46,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/28/1991 3:18,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/28/1991 3:28,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +1,1/28/1991 4:36,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/29/1991 23:32,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/30/1991 5:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/30/1991 8:00,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0 +1,1/30/1991 11:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,1/30/1991 14:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +1,1/30/1991 14:15,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +1,1/31/1991 1:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +1,1/31/1991 2:15,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +1,2/8/1991 8:15,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +1,3/3/1991 19:33,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +1,3/3/1991 20:33,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0 +1,3/3/1991 21:38,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 +2,,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 2:24,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 2:35,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,3/8/1996 4:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,3/8/1996 10:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,3/8/1996 16:00,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0 +2,6/5/1996 0:32,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +2,6/5/1996 0:48,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 +2,6/5/1996 1:59,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +2,6/7/1996 6:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +2,6/7/1996 9:00,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0 +2,6/7/1996 12:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +2,6/7/1996 15:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +2,6/8/1996 3:00,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 +3,,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/8/1996 2:22,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0 +3,3/8/1996 2:24,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/8/1996 2:37,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 8:00,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0 +3,3/9/1996 11:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 19:00,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0 +3,3/9/1996 22:00,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +3,3/11/1996 21:00,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0 +3,3/12/1996 0:00,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0 """ # Tasks (input) @@ -88,6 +94,10 @@ discharge_or_death: expr: or(discharge, death) + patient_demographics: + male: + code: SEX//male + trigger: admission windows: @@ -120,10 +130,10 @@ # Expected output EXPECTED_OUTPUT = { "inhospital_mortality": { - "subject_id": [1, 2], - "index_timestamp": ["01/28/1991 23:32", "06/06/1996 00:32"], - "label": [0, 1], - "trigger": ["01/27/1991 23:32", "06/05/1996 00:32"], + "subject_id": [1], + "index_timestamp": ["01/28/1991 23:32"], + "label": [0], + "trigger": ["01/27/1991 23:32"], "input.end_summary": [ { "window_name": "input.end", @@ -135,16 +145,6 @@ "discharge_or_death": 0, "_ANY_EVENT": 4, }, - { - "window_name": "input.end", - "timestamp_at_start": "06/05/1996 00:32", - "timestamp_at_end": "06/06/1996 00:32", - "admission": 0, - 
"discharge": 0, - "death": 0, - "discharge_or_death": 0, - "_ANY_EVENT": 2, - }, ], "input.start_summary": [ { @@ -157,16 +157,6 @@ "discharge_or_death": 1, "_ANY_EVENT": 16, }, - { - "window_name": "input.start", - "timestamp_at_start": "03/08/1996 02:24", - "timestamp_at_end": "06/06/1996 00:32", - "admission": 2, - "discharge": 1, - "death": 0, - "discharge_or_death": 1, - "_ANY_EVENT": 8, - }, ], "gap.end_summary": [ { @@ -179,16 +169,6 @@ "discharge_or_death": 0, "_ANY_EVENT": 5, }, - { - "window_name": "gap.end", - "timestamp_at_start": "06/05/1996 00:32", - "timestamp_at_end": "06/07/1996 00:32", - "admission": 0, - "discharge": 0, - "death": 0, - "discharge_or_death": 0, - "_ANY_EVENT": 2, - }, ], "target.end_summary": [ { @@ -201,16 +181,6 @@ "discharge_or_death": 1, "_ANY_EVENT": 7, }, - { - "window_name": "target.end", - "timestamp_at_start": "06/07/1996 00:32", - "timestamp_at_end": "06/08/1996 03:00", - "admission": 0, - "discharge": 0, - "death": 1, - "discharge_or_death": 1, - "_ANY_EVENT": 5, - }, ], } } @@ -326,7 +296,11 @@ def test_e2e(): ) else: want = pl.DataFrame({col_name: expected_data}).with_columns( - pl.col(col_name).cast(PRED_CNT_TYPE) + *[ + pl.col(col_name).cast(PRED_CNT_TYPE) + if col_name != LAST_EVENT_INDEX_COLUMN + else pl.col(col_name).cast(EVENT_INDEX_TYPE) + ] ) got = got_df.select(col_name) assert_df_equal(want, got, f"Data mismatch for task '{task_name}', column '{col_name}'") From be15d3d5b8ee775bb47b2b405544eba4dd2dd011 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Mon, 5 Aug 2024 11:50:54 +0100 Subject: [PATCH 04/11] TODO: need to add support for these fields (#87, #90) --- sample_configs/inhospital_mortality.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index a8f7eae..064e350 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -1,7 +1,10 @@ # Task: 24-hour In-hospital Mortality Prediction predicates: admission: - code: event_type//ADMISSION + code: + regex: + any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE] + other_col: foo discharge: code: event_type//DISCHARGE death: From 9c6cb3da320c444c0d66a2c6c4be4888f3c72845 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Thu, 8 Aug 2024 18:50:26 +0100 Subject: [PATCH 05/11] Supports #87 --- sample_configs/inhospital_mortality.yaml | 2 +- src/aces/config.py | 30 ++++++++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 064e350..e08674f 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -4,7 +4,7 @@ predicates: code: regex: any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE] - other_col: foo + random_col: foo discharge: code: event_type//DISCHARGE death: diff --git a/src/aces/config.py b/src/aces/config.py index 39f71f0..dfcaa5f 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -34,6 +34,7 @@ class PlainPredicateConfig: value_min_inclusive: bool | None = None value_max_inclusive: bool | None = None static: bool = False + other_cols: dict[str, str] = field(default_factory=dict) def MEDS_eval_expr(self) -> pl.Expr: """Returns a Polars expression that evaluates this predicate for a MEDS formatted dataset. 
@@ -60,6 +61,11 @@ def MEDS_eval_expr(self) -> pl.Expr: >>> expr = cfg.MEDS_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("code")) == (String(BP//diastolic))] + >>> cfg = PlainPredicateConfig("BP//diastolic", other_cols={"chamber": "atrial"}) + >>> expr = cfg.MEDS_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("code")) == (String(BP//diastolic))].all_horizontal([[(col("chamber")) == + (String(atrial))]]) """ criteria = [pl.col("code") == self.code] @@ -74,6 +80,9 @@ def MEDS_eval_expr(self) -> pl.Expr: else: criteria.append(pl.col("numerical_value") < self.value_max) + if self.other_cols: + criteria.extend([pl.col(col) == value for col, value in self.other_cols.items()]) + if len(criteria) == 1: return criteria[0] else: @@ -860,19 +869,22 @@ class TaskExtractorConfig: value_max=None, value_min_inclusive=None, value_max_inclusive=None, - static=False), + static=False, + other_cols={}), 'discharge': PlainPredicateConfig(code='discharge', value_min=None, value_max=None, value_min_inclusive=None, value_max_inclusive=None, - static=False), + static=False, + other_cols={}), 'death': PlainPredicateConfig(code='death', value_min=None, value_max=None, value_min_inclusive=None, value_max_inclusive=None, - static=False)} + static=False, + other_cols={})} >>> print(config.label_window) # doctest: +NORMALIZE_WHITESPACE target @@ -1084,10 +1096,14 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'") logger.info("Parsing predicates...") - predicates = { - n: DerivedPredicateConfig(**p) if "expr" in p else PlainPredicateConfig(**p) - for n, p in predicates.items() - } + predicates = {} + for n, p in predicates.items(): + if "expr" in p: + predicates[n] = DerivedPredicateConfig(**p) + else: + config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} + other_cols = {k: v for k, v in p.items() if k not in config_data.keys()} + predicates[n] = PlainPredicateConfig(**p, other_cols=other_cols) if patient_demographics: logger.info("Parsing patient demographics...") From e209857e094443ad0bf3840d3923d3bad024b5b0 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Thu, 8 Aug 2024 20:21:57 +0100 Subject: [PATCH 06/11] Fix tests and config parsing --- sample_configs/inhospital_mortality.yaml | 4 +--- src/aces/config.py | 12 +++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index e08674f..572a5b8 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -1,9 +1,7 @@ # Task: 24-hour In-hospital Mortality Prediction predicates: admission: - code: - regex: - any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE] + code: event_type//ADMISSION random_col: foo discharge: code: event_type//DISCHARGE diff --git a/src/aces/config.py b/src/aces/config.py index dfcaa5f..8751abd 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -1096,21 +1096,21 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta raise ValueError(f"Unrecognized keys in configuration file: '{', '.join(loaded_dict.keys())}'") logger.info("Parsing predicates...") - predicates = {} + predicate_objs = {} for n, p in predicates.items(): if "expr" in p: - predicates[n] = DerivedPredicateConfig(**p) + predicate_objs[n] = DerivedPredicateConfig(**p) else: 
config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__} other_cols = {k: v for k, v in p.items() if k not in config_data.keys()} - predicates[n] = PlainPredicateConfig(**p, other_cols=other_cols) + predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols) if patient_demographics: logger.info("Parsing patient demographics...") patient_demographics = { n: PlainPredicateConfig(**p, static=True) for n, p in patient_demographics.items() } - predicates.update(patient_demographics) + predicate_objs.update(patient_demographics) logger.info("Parsing trigger event...") trigger = EventConfig(trigger) @@ -1124,7 +1124,9 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta else: windows = {n: WindowConfig(**w) for n, w in windows.items()} - return cls(predicates=predicates, trigger=trigger, windows=windows) + print(predicate_objs) + + return cls(predicates=predicate_objs, trigger=trigger, windows=windows) def save(self, config_path: str | Path, do_overwrite: bool = False): """Load a configuration file from the given path and return it as a dict. From e9af0477e6f74666f57ec25618746f564ae369fa Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 9 Aug 2024 06:15:10 +0100 Subject: [PATCH 07/11] ESGPT support --- sample_configs/inhospital_mortality.yaml | 6 ++++-- src/aces/config.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 572a5b8..500646f 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -1,8 +1,10 @@ # Task: 24-hour In-hospital Mortality Prediction predicates: admission: - code: event_type//ADMISSION - random_col: foo + code: + regex: + any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE] + all: discharge: code: event_type//DISCHARGE death: diff --git a/src/aces/config.py b/src/aces/config.py index 8751abd..fd45957 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -129,6 +129,9 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: >>> expr = PlainPredicateConfig("BP").ESGPT_eval_expr() >>> print(expr) # doctest: +NORMALIZE_WHITESPACE col("BP").is_not_null() + >>> expr = PlainPredicateConfig("BP//systole", other_cols={"chamber": "atrial"}).ESGPT_eval_expr() + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("BP")) == (String(systole))].all_horizontal([[(col("chamber")) == (String(atrial))]]) >>> expr = PlainPredicateConfig("BP//systolic", value_min=120).ESGPT_eval_expr() Traceback (most recent call last): ... 
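The `load` fix in the previous patch reduces to a reusable parsing pattern: partition the raw YAML mapping on the dataclass's declared fields, so unrecognized keys flow into `other_cols` instead of raising a `TypeError`. A self-contained sketch, with `Predicate` standing in for the real `PlainPredicateConfig`:

    import dataclasses

    @dataclasses.dataclass
    class Predicate:  # simplified stand-in for PlainPredicateConfig
        code: str
        static: bool = False
        other_cols: dict = dataclasses.field(default_factory=dict)

    def parse_predicate(raw: dict) -> Predicate:
        # Declared fields go straight to the constructor; everything else is kept
        # as an extra column constraint. 'other_cols' itself is excluded so the
        # keyword is never passed twice.
        known = {
            k: v for k, v in raw.items()
            if k in Predicate.__dataclass_fields__ and k != "other_cols"
        }
        extra = {k: v for k, v in raw.items() if k not in known}
        return Predicate(**known, other_cols=extra)

    print(parse_predicate({"code": "event_type//ADMISSION", "random_col": "foo"}))
    # Predicate(code='event_type//ADMISSION', static=False, other_cols={'random_col': 'foo'})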
@@ -176,6 +179,9 @@ def ESGPT_eval_expr(self, values_column: str | None = None) -> pl.Expr: else: criteria.append(pl.col(values_column) < self.value_max) + if self.other_cols: + criteria.extend([pl.col(col) == value for col, value in self.other_cols.items()]) + if len(criteria) == 1: return criteria[0] else: From 020a3a7b86b06a22ec9d007761acdb553e67eaa7 Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Fri, 9 Aug 2024 07:56:18 +0100 Subject: [PATCH 08/11] Supports regex predicates and any specification for code field (#90) --- sample_configs/inhospital_mortality.yaml | 4 +- src/aces/config.py | 96 +++++++++++++++++++++--- src/aces/predicates.py | 1 + 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 500646f..6eea0c6 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -2,9 +2,7 @@ predicates: admission: code: - regex: - any: [event_type//ADMISSION, event_type//EMERGENCY, event_type//ELECTIVE] - all: + regex: ^event_type.* discharge: code: event_type//DISCHARGE death: diff --git a/src/aces/config.py b/src/aces/config.py index fd45957..633f223 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -28,7 +28,7 @@ @dataclasses.dataclass class PlainPredicateConfig: - code: str + code: str | dict value_min: float | None = None value_max: float | None = None value_min_inclusive: bool | None = None @@ -66,19 +66,93 @@ def MEDS_eval_expr(self) -> pl.Expr: >>> print(expr) # doctest: +NORMALIZE_WHITESPACE [(col("code")) == (String(BP//diastolic))].all_horizontal([[(col("chamber")) == (String(atrial))]]) + + >>> cfg = PlainPredicateConfig(code={'regex': None, 'any': None, 'all': None}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Only one of 'regex', 'any', or 'all' can be specified in the code field! + Got: ['regex', 'any', 'all']. + >>> cfg = PlainPredicateConfig(code={'foo': None}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Invalid specification in the code field! Got: {'foo': None}. + Expected one of 'regex', 'any', or 'all'. + >>> cfg = PlainPredicateConfig(code={'regex': ''}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Invalid specification in the code field! Got: {'regex': ''}. + Expected a non-empty string for 'regex'. + >>> cfg = PlainPredicateConfig(code={'all': []}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Invalid specification in the code field! Got: {'all': []}. + Expected a list of strings for 'all'. 
+ + >>> cfg = PlainPredicateConfig(code={'regex': '^foo.*'}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + col("code").str.contains([String(^foo.*)]) + >>> cfg = PlainPredicateConfig(code={'all': ['foo', 'bar']}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("code")) == (String(foo))].all_horizontal([[(col("code")) == (String(bar))]]) + >>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']}) + >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE + >>> print(expr) # doctest: +NORMALIZE_WHITESPACE + [(col("code")) == (String(foo))].any_horizontal([[(col("code")) == (String(bar))]]) """ - criteria = [pl.col("code") == self.code] + criteria = [] + if isinstance(self.code, dict): + if len(self.code) > 1: + raise ValueError( + "Only one of 'regex', 'any', or 'all' can be specified in the code field! " + f"Got: {list(self.code.keys())}." + ) - if self.value_min is not None: - if self.value_min_inclusive: - criteria.append(pl.col("numerical_value") >= self.value_min) - else: - criteria.append(pl.col("numerical_value") > self.value_min) - if self.value_max is not None: - if self.value_max_inclusive: - criteria.append(pl.col("numerical_value") <= self.value_max) + if "regex" in self.code: + if not self.code["regex"] or not isinstance(self.code["regex"], str): + raise ValueError( + "Invalid specification in the code field! " + f"Got: {self.code}. " + "Expected a non-empty string for 'regex'." + ) + criteria.append(pl.col("code").str.contains(self.code["regex"])) + elif "any" in self.code or "all" in self.code: # 'all' is redundant? it shouldn't be possible...? + logic = list(self.code.keys())[0] + if not self.code[logic] or not isinstance(self.code[logic], list): + raise ValueError( + "Invalid specification in the code field! " + f"Got: {self.code}. " + f"Expected a list of strings for '{logic}'." + ) + criteria.append( + pl.all_horizontal([pl.col("code") == code for code in self.code[logic]]) + if logic == "all" + else pl.any_horizontal([pl.col("code") == code for code in self.code[logic]]) + ) else: - criteria.append(pl.col("numerical_value") < self.value_max) + raise ValueError( + "Invalid specification in the code field! " + f"Got: {self.code}. " + "Expected one of 'regex', 'any', or 'all'." 
+ ) + else: + criteria.append(pl.col("code") == self.code) + + if self.value_min is not None: + if self.value_min_inclusive: + criteria.append(pl.col("numerical_value") >= self.value_min) + else: + criteria.append(pl.col("numerical_value") > self.value_min) + if self.value_max is not None: + if self.value_max_inclusive: + criteria.append(pl.col("numerical_value") <= self.value_max) + else: + criteria.append(pl.col("numerical_value") < self.value_max) if self.other_cols: criteria.extend([pl.col(col) == value for col, value in self.other_cols.items()]) diff --git a/src/aces/predicates.py b/src/aces/predicates.py index 862b678..7e26e2a 100644 --- a/src/aces/predicates.py +++ b/src/aces/predicates.py @@ -324,6 +324,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl # generate plain predicate columns logger.info("Generating plain predicate columns...") for name, plain_predicate in predicates.items(): + data = data.with_columns(data["code"].cast(pl.Utf8).alias("code")) # may remove after MEDS v0.3 data = data.with_columns(plain_predicate.MEDS_eval_expr().cast(PRED_CNT_TYPE).alias(name)) logger.info(f"Added predicate column '{name}'.") From 9d5e9dca90e8ac48c6b5b59e6b1967de657fbb6c Mon Sep 17 00:00:00 2001 From: Justin Xu Date: Sat, 10 Aug 2024 05:20:43 +0100 Subject: [PATCH 09/11] Removed 'all', used is_in, static variable filtering prints warning if no valid rows --- sample_configs/inhospital_mortality.yaml | 2 +- src/aces/config.py | 35 ++++++++++-------------- src/aces/query.py | 4 +++ 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/sample_configs/inhospital_mortality.yaml b/sample_configs/inhospital_mortality.yaml index 6eea0c6..384688f 100644 --- a/sample_configs/inhospital_mortality.yaml +++ b/sample_configs/inhospital_mortality.yaml @@ -2,7 +2,7 @@ predicates: admission: code: - regex: ^event_type.* + any: ["event_type//ADMISSION", "event_type//DISCHARGE"] discharge: code: event_type//DISCHARGE death: diff --git a/src/aces/config.py b/src/aces/config.py index 633f223..0121a7c 100644 --- a/src/aces/config.py +++ b/src/aces/config.py @@ -67,49 +67,49 @@ def MEDS_eval_expr(self) -> pl.Expr: [(col("code")) == (String(BP//diastolic))].all_horizontal([[(col("chamber")) == (String(atrial))]]) - >>> cfg = PlainPredicateConfig(code={'regex': None, 'any': None, 'all': None}) + >>> cfg = PlainPredicateConfig(code={'regex': None, 'any': None}) >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: Only one of 'regex', 'any', or 'all' can be specified in the code field! - Got: ['regex', 'any', 'all']. + ValueError: Only one of 'regex' or 'any' can be specified in the code field! + Got: ['regex', 'any']. >>> cfg = PlainPredicateConfig(code={'foo': None}) >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... ValueError: Invalid specification in the code field! Got: {'foo': None}. - Expected one of 'regex', 'any', or 'all'. + Expected one of 'regex', 'any'. >>> cfg = PlainPredicateConfig(code={'regex': ''}) >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... ValueError: Invalid specification in the code field! Got: {'regex': ''}. Expected a non-empty string for 'regex'. - >>> cfg = PlainPredicateConfig(code={'all': []}) + >>> cfg = PlainPredicateConfig(code={'any': []}) >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... 
-        ValueError: Invalid specification in the code field! Got: {'all': []}.
-        Expected a list of strings for 'all'.
+        ValueError: Invalid specification in the code field! Got: {'any': []}.
+        Expected a list of strings for 'any'.

         >>> cfg = PlainPredicateConfig(code={'regex': '^foo.*'})
         >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
         >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
         col("code").str.contains([String(^foo.*)])
-        >>> cfg = PlainPredicateConfig(code={'all': ['foo', 'bar']})
+        >>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']})
         >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
         >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
         [(col("code")) == (String(foo))].all_horizontal([[(col("code")) == (String(bar))]])
         >>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']})
         >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
         >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
-        [(col("code")) == (String(foo))].any_horizontal([[(col("code")) == (String(bar))]])
+        col("code").is_in([Series])
         """
         criteria = []
         if isinstance(self.code, dict):
             if len(self.code) > 1:
                 raise ValueError(
-                    "Only one of 'regex', 'any', or 'all' can be specified in the code field! "
+                    "Only one of 'regex' or 'any' can be specified in the code field! "
                     f"Got: {list(self.code.keys())}."
                 )
@@ -121,24 +121,19 @@ def MEDS_eval_expr(self) -> pl.Expr:
                     "Expected a non-empty string for 'regex'."
                 )
                 criteria.append(pl.col("code").str.contains(self.code["regex"]))
-            elif "any" in self.code or "all" in self.code:  # 'all' is redundant? it shouldn't be possible...?
-                logic = list(self.code.keys())[0]
-                if not self.code[logic] or not isinstance(self.code[logic], list):
+            elif "any" in self.code:
+                if not self.code["any"] or not isinstance(self.code["any"], list):
                     raise ValueError(
                         "Invalid specification in the code field! "
                         f"Got: {self.code}. "
-                        f"Expected a list of strings for '{logic}'."
+                        f"Expected a list of strings for 'any'."
                     )
-                criteria.append(
-                    pl.all_horizontal([pl.col("code") == code for code in self.code[logic]])
-                    if logic == "all"
-                    else pl.any_horizontal([pl.col("code") == code for code in self.code[logic]])
-                )
+                criteria.append(pl.Expr.is_in(pl.col("code"), self.code["any"]))
             else:
                 raise ValueError(
                     "Invalid specification in the code field! "
                     f"Got: {self.code}. "
-                    "Expected one of 'regex', 'any', or 'all'."
+                    "Expected one of 'regex', 'any'."
                 )
         else:
             criteria.append(pl.col("code") == self.code)
diff --git a/src/aces/query.py b/src/aces/query.py
index b7605d6..25805e0 100644
--- a/src/aces/query.py
+++ b/src/aces/query.py
@@ -65,6 +65,10 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
         logger.info("No static variable criteria specified, removing all rows with null timestamps...")
         predicates_df = predicates_df.drop_nulls(subset=["subject_id", "timestamp"])

+    if predicates_df.is_empty():
+        logger.warning("No valid rows found after filtering patient demographics. Exiting.")
+        return pl.DataFrame()
+
     logger.info("Identifying possible trigger nodes based on the specified trigger event...")
     prospective_root_anchors = check_constraints({cfg.trigger.predicate: (1, None)}, predicates_df).select(
         "subject_id", pl.col("timestamp").alias("subtree_anchor_timestamp")

From 98c9471dc395503f5b0a47ca36520a377cfe9e60 Mon Sep 17 00:00:00 2001
From: Justin Xu
Date: Sat, 10 Aug 2024 05:21:49 +0100
Subject: [PATCH 10/11] Removed duplicate test

---
 src/aces/config.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/aces/config.py b/src/aces/config.py
index 0121a7c..48b8238 100644
--- a/src/aces/config.py
+++ b/src/aces/config.py
@@ -99,10 +99,6 @@ def MEDS_eval_expr(self) -> pl.Expr:
         >>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']})
         >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
         >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
-        [(col("code")) == (String(foo))].all_horizontal([[(col("code")) == (String(bar))]])
-        >>> cfg = PlainPredicateConfig(code={'any': ['foo', 'bar']})
-        >>> expr = cfg.MEDS_eval_expr() # doctest: +NORMALIZE_WHITESPACE
-        >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
         col("code").is_in([Series])

From 95f518c366cc6fdb71cc72403543dfffec24e97d Mon Sep 17 00:00:00 2001
From: Justin Xu
Date: Sat, 10 Aug 2024 05:35:38 +0100
Subject: [PATCH 11/11] Use simplified dictionary membership check

---
 src/aces/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/aces/config.py b/src/aces/config.py
index 48b8238..703876b 100644
--- a/src/aces/config.py
+++ b/src/aces/config.py
@@ -1173,7 +1173,7 @@ def load(cls, config_path: str | Path, predicates_path: str | Path = None) -> Ta
             predicate_objs[n] = DerivedPredicateConfig(**p)
         else:
             config_data = {k: v for k, v in p.items() if k in PlainPredicateConfig.__dataclass_fields__}
-            other_cols = {k: v for k, v in p.items() if k not in config_data.keys()}
+            other_cols = {k: v for k, v in p.items() if k not in config_data}
             predicate_objs[n] = PlainPredicateConfig(**config_data, other_cols=other_cols)

     if patient_demographics: