From b52a3ae762065c67dc30a8c57b34798d801fbf53 Mon Sep 17 00:00:00 2001
From: Justin Xu
Date: Fri, 26 Jul 2024 10:47:18 +0100
Subject: [PATCH 1/4] Rename index_timestamp as part of #72

---
 src/aces/__main__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 80f94ab..9c51efc 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -47,6 +47,8 @@ def main(cfg: DictConfig):
 
     if cfg.data.standard.lower() == "meds":
         result = result.rename({"subject_id": "patient_id"})
+        if "index_timestamp" in result.columns:
+            result = result.rename({"index_timestamp": "prediction_time"})
 
     # save results to parquet
     os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)

From e62d1bc32cebbae3d6a5c782d6403918d6ef0965 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Sun, 11 Aug 2024 15:12:25 -0400
Subject: [PATCH 2/4] Updated to overtly define in MEDS format. Not yet
 validated.

---
 pyproject.toml       |  1 +
 src/aces/__main__.py | 97 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 94 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0612da6..ba14c91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
   "pytimeparse == 1.1.*",
   "networkx == 3.3.*",
   "pyarrow == 16.1.*",
+  "meds == 0.3",
 ]
 
 [project.scripts]

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 9c51efc..15c6012 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -4,7 +4,11 @@
 from importlib.resources import files
 
 import hydra
+import polars as pl
+import pyarrow as pa
+import pyarrow.parquet as pq
 from loguru import logger
+from meds import label_schema
 from omegaconf import DictConfig
 
 config_yaml = files("aces").joinpath("configs/aces.yaml")
@@ -15,6 +19,88 @@
     print("For more information, visit: https://eventstreamaces.readthedocs.io/en/latest/usage.html")
     sys.exit(1)
 
+MEDS_LABEL_MANDATORY_TYPES = {
+    "patient_id": pl.Int64,
+    "prediction_time": pl.Datetime("us"),
+}
+
+MEDS_LABEL_OPTIONAL_TYPES = {
+    "boolean_value": pl.Boolean,
+    "integer_value": pl.Int64,
+    "float_value": pl.Float64,
+    "categorical_value": pl.String,
+}
+
+
+def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
+    """Validates the schema of a MEDS label DataFrame.
+
+    This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
+    and that the columns are of the correct type. This function will:
+    1. Re-type any of the mandatory MEDS label columns to the appropriate type.
+    2. Attempt to add any missing optional label columns (e.g., ``boolean_value``) and set them to
+       `None`. It will not attempt to add missing mandatory columns, as those cannot meaningfully be
+       set to `None`.
+
+    Args:
+        df: The MEDS label DataFrame to validate.
+
+    Returns:
+        pa.Table: The validated MEDS label data, with columns re-typed as needed.
+
+    Raises:
+        ValueError: if the MEDS label DataFrame is missing a mandatory column.
+
+    Examples:
+        >>> df = pl.DataFrame({})
+        >>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
+        Traceback (most recent call last):
+            ...
+        ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
+        MEDS Data DataFrame must have a 'prediction_time' column of type
+        Datetime(time_unit='us', time_zone=None).
+        >>> from datetime import datetime
+        >>> df = pl.DataFrame({
+        ...     "patient_id": pl.Series([1, 3, 2], dtype=pl.UInt32),
"time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], + ... "boolean_value": [1, 0, 100], + ... }) + >>> get_and_validate_label_schema(df.lazy()) + pyarrow.Table + patient_id: int64 + time: timestamp[us] + boolean_value: bool + integer_value: int64 + float_value: float + categorical_value: string + ---- + patient_id: [[1,3,2]] + time: [[2021-01-01 00:00:00.000000,2021-01-02 00:00:00.000000,2021-01-03 00:00:00.000000]] + boolean_value: [[true,false,true]] + integer_value: [[null,null,null]] + float_value: [[null,null,null]] + categorical_value: [[null,null,null]] + """ + + schema = df.collect_schema() + errors = [] + for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items(): + if col in schema and schema[col] != dtype: + df = df.with_columns(pl.col(col).cast(dtype, strict=False)) + elif col not in schema: + errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.") + + if errors: + raise ValueError("\n".join(errors)) + + for col, dtype in MEDS_LABEL_OPTIONAL_TYPES.items(): + if col in schema and schema[col] != dtype: + df = df.with_columns(pl.col(col).cast(dtype, strict=False)) + elif col not in schema: + df = df.with_columns(pl.lit(None, dtype=dtype).alias(col)) + + return df.collect().to_arrow().cast(label_schema) + @hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) def main(cfg: DictConfig): @@ -45,14 +131,17 @@ def main(cfg: DictConfig): # query results result = query.query(task_cfg, predicates_df) + # save results to parquet + os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True) + if cfg.data.standard.lower() == "meds": result = result.rename({"subject_id": "patient_id"}) if "index_timestamp" in result.columns: result = result.rename({"index_timestamp": "prediction_time"}) - - # save results to parquet - os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True) - result.write_parquet(cfg.output_filepath, use_pyarrow=True) + result = get_and_validate_label_schema(result.lazy()) + pq.write_table(result, cfg.output_filepath) + else: + result.write_parquet(cfg.output_filepath, use_pyarrow=True) logger.info(f"Completed in {datetime.now() - st}. Results saved to '{cfg.output_filepath}'.") From 0cde8f6c9b26882b67e045a4d8bdae15b81fae67 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 15:21:17 -0400 Subject: [PATCH 3/4] Fixed expand_shards to work with arbitrary MEDS sharding strategies. --- src/aces/expand_shards.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/aces/expand_shards.py b/src/aces/expand_shards.py index 9638092..6ff5a1a 100755 --- a/src/aces/expand_shards.py +++ b/src/aces/expand_shards.py @@ -1,9 +1,9 @@ #!/usr/bin/env python -import glob import os import re import sys +from pathlib import Path def expand_shards(*shards: str) -> str: @@ -23,7 +23,6 @@ def expand_shards(*shards: str) -> str: Examples: >>> import polars as pl >>> import tempfile - >>> from pathlib import Path >>> expand_shards("train/4", "val/IID/1", "val/prospective/1") 'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0' @@ -38,20 +37,24 @@ def expand_shards(*shards: str) -> str: >>> with tempfile.TemporaryDirectory() as tmpdirname: ... for i in range(4): - ... with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f: - ... data_path = Path(tmpdirname + f"/file_{i}") - ... parquet_data.write_parquet(data_path) + ... if i in (0, 2): + ... data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet" + ... 
+        ...             data_path.parent.mkdir(parents=True, exist_ok=True)
+        ...         else:
+        ...             data_path = Path(tmpdirname) / f"{i}.parquet"
+        ...         parquet_data.write_parquet(data_path)
+        ...     json_fp = Path(tmpdirname) / ".shards.json"
+        ...     _ = json_fp.write_text('["foo"]')
         ...     result = expand_shards(tmpdirname)
-        ...     ','.join(sorted(os.path.basename(f) for f in result.split(',')))
-        'file_0,file_1,file_2,file_3'
+        ...     sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
+        ['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
     """
     result = []
     for arg in shards:
         if os.path.isdir(arg):
-            # If the argument is a directory, list all files in the directory
-            files = glob.glob(os.path.join(arg, "*"))
-            result.extend(files)
+            # If the argument is a directory, recursively take all parquet files under the directory
+            result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
         else:
             # Otherwise, treat it as a shard prefix and number of shards
             match = re.match(r"(.+)([/_])(\d+)$", arg)

From e9655390f25bf79167370a802176bcf671cefa44 Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Sun, 11 Aug 2024 15:56:38 -0400
Subject: [PATCH 4/4] Fixed a few other errors for MEDS v0.3 compatibility. Now
 tested with config
 https://gist.github.com/mmcdermott/80a9086d8fdf36f2fd04b8e4912348ac on a tiny
 MIMIC cohort.

---
 src/aces/__main__.py      | 39 ++++++++++++++++++++++++++++++-------
 src/aces/expand_shards.py | 14 ++++++++------
 src/aces/predicates.py    |  6 +++---
 3 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/src/aces/__main__.py b/src/aces/__main__.py
index 15c6012..e28747b 100644
--- a/src/aces/__main__.py
+++ b/src/aces/__main__.py
@@ -21,7 +21,6 @@
 
 MEDS_LABEL_MANDATORY_TYPES = {
     "patient_id": pl.Int64,
-    "prediction_time": pl.Datetime("us"),
 }
 
 MEDS_LABEL_OPTIONAL_TYPES = {
@@ -29,10 +28,11 @@
     "integer_value": pl.Int64,
     "float_value": pl.Float64,
     "categorical_value": pl.String,
+    "prediction_time": pl.Datetime("us"),
 }
 
 
-def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
+def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
     """Validates the schema of a MEDS label DataFrame.
 
     This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
@@ -53,7 +53,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
 
     Examples:
         >>> df = pl.DataFrame({})
-        >>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
+        >>> get_and_validate_label_schema(df) # doctest: +NORMALIZE_WHITESPACE
         Traceback (most recent call last):
             ...
         ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
@@ -65,7 +65,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
         ...     "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
         ...     "boolean_value": [1, 0, 100],
         ... })
-        >>> get_and_validate_label_schema(df.lazy())
+        >>> get_and_validate_label_schema(df)
         pyarrow.Table
         patient_id: int64
         time: timestamp[us]
@@ -82,7 +82,14 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
         categorical_value: [[null,null,null]]
     """
 
-    schema = df.collect_schema()
+    schema = df.schema
+    if "prediction_time" not in schema:
+        logger.warning(
+            "Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add an "
+            "'index_timestamp' (yes, it should be different) key to the task configuration identifying "
+            "which window's start or end time to use as the prediction time."
+        )
+
     errors = []
     for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items():
         if col in schema and schema[col] != dtype:
@@ -99,7 +106,23 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
         elif col not in schema:
             df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
 
-    return df.collect().to_arrow().cast(label_schema)
+    extra_cols = [
+        c for c in schema if c not in MEDS_LABEL_MANDATORY_TYPES and c not in MEDS_LABEL_OPTIONAL_TYPES
+    ]
+    if extra_cols:
+        err_cols_str = "\n".join(f"  - {c}" for c in extra_cols)
+        logger.warning(
+            "Output contains columns that are not valid MEDS label columns. For now, we are dropping them.\n"
+            "If you need these columns, please comment on https://github.com/justin13601/ACES/issues/97\n"
+            f"Columns:\n{err_cols_str}"
+        )
+        df = df.drop(extra_cols)
+
+    df = df.select(
+        "patient_id", "prediction_time", "boolean_value", "integer_value", "float_value", "categorical_value"
+    )
+
+    return df.to_arrow().cast(label_schema)
 
 
 @hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
 def main(cfg: DictConfig):
@@ -138,7 +161,9 @@ def main(cfg: DictConfig):
         result = result.rename({"subject_id": "patient_id"})
         if "index_timestamp" in result.columns:
             result = result.rename({"index_timestamp": "prediction_time"})
-        result = get_and_validate_label_schema(result.lazy())
+        if "label" in result.columns:
+            result = result.rename({"label": "boolean_value"})
+        result = get_and_validate_label_schema(result)
         pq.write_table(result, cfg.output_filepath)
     else:
         result.write_parquet(cfg.output_filepath, use_pyarrow=True)

diff --git a/src/aces/expand_shards.py b/src/aces/expand_shards.py
index 6ff5a1a..ca64a9a 100755
--- a/src/aces/expand_shards.py
+++ b/src/aces/expand_shards.py
@@ -31,9 +31,9 @@ def expand_shards(*shards: str) -> str:
 
         >>> parquet_data = pl.DataFrame({
         ...     "patient_id": [1, 1, 1, 2, 3],
-        ...     "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
+        ...     "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
         ...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
-        ... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
+        ... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
 
         >>> with tempfile.TemporaryDirectory() as tmpdirname:
         ...     for i in range(4):
@@ -43,18 +43,20 @@ def expand_shards(*shards: str) -> str:
         ...         else:
         ...             data_path = Path(tmpdirname) / f"{i}.parquet"
         ...         parquet_data.write_parquet(data_path)
-        ...     json_fp = Path(tmpdirname) / ".shards.json"
+        ...     json_fp = Path(tmpdirname) / "4.json"
         ...     _ = json_fp.write_text('["foo"]')
         ...     result = expand_shards(tmpdirname)
-        ...     sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
-        ['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
+        ...     sorted(result.split(","))
+        ['1', '3', 'evens/0/file_0', 'evens/0/file_2']
     """
     result = []
     for arg in shards:
         if os.path.isdir(arg):
             # If the argument is a directory, recursively take all parquet files under the directory
-            result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
+            result.extend(
+                str(x.relative_to(Path(arg)).with_suffix("")) for x in Path(arg).glob("**/*.parquet")
+            )
         else:
             # Otherwise, treat it as a shard prefix and number of shards
             match = re.match(r"(.+)([/_])(\d+)$", arg)

diff --git a/src/aces/predicates.py b/src/aces/predicates.py
index 7e26e2a..752fe6d 100644
--- a/src/aces/predicates.py
+++ b/src/aces/predicates.py
@@ -291,9 +291,9 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
         >>> from .config import PlainPredicateConfig
         >>> parquet_data = pl.DataFrame({
         ...     "patient_id": [1, 1, 1, 2, 3],
-        ...     "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
+        ...     "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
         ...     "code": ['admission', 'discharge', 'discharge', 'admission', "gender//male"],
-        ... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
+        ... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
         >>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
         ...     data_path = Path(f.name)
         ...     parquet_data.write_parquet(data_path)
@@ -316,7 +316,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
     """
 
     logger.info("Loading MEDS data...")
-    data = pl.read_parquet(data_path).rename({"patient_id": "subject_id"})
+    data = pl.read_parquet(data_path).rename({"patient_id": "subject_id", "time": "timestamp"})
 
     if data.columns == ["subject_id", "events"]:
         data = unnest_meds(data)
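For reference, below is a minimal sketch of the MEDS label output path that patches 2 and 4 converge on. It assumes `get_and_validate_label_schema` is importable from `aces.__main__` (where these patches define it); the toy `result` frame and the `labels.parquet` output path are illustrative stand-ins for a real ACES query result and `cfg.output_filepath`, not part of the patches.

    from datetime import datetime

    import polars as pl
    import pyarrow.parquet as pq

    from aces.__main__ import get_and_validate_label_schema  # assumed import path

    # Toy stand-in for an ACES query result under the MEDS standard.
    result = pl.DataFrame(
        {
            "subject_id": [1, 2],
            "index_timestamp": [datetime(2021, 1, 1), datetime(2021, 1, 2)],
            "label": [True, False],
        }
    )

    # Mirror the renames main() applies before validation (patches 1, 2, and 4).
    result = result.rename(
        {
            "subject_id": "patient_id",
            "index_timestamp": "prediction_time",
            "label": "boolean_value",
        }
    )

    # As of patch 4, validation takes an eager DataFrame and returns a pyarrow
    # Table already cast to the MEDS label_schema; missing optional columns
    # (integer_value, float_value, categorical_value) are filled with nulls.
    table = get_and_validate_label_schema(result)
    pq.write_table(table, "labels.parquet")  # illustrative output path

Note that the validator casts with `strict=False`, so non-boolean label encodings (like the `[1, 0, 100]` column in the doctest above) are coerced to booleans rather than rejected.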