Skip to content

Commit

Permalink
Updated to overtly define in MEDS format. Not yet validated.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 11, 2024
1 parent 774d820 commit e62d1bc
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pytimeparse == 1.1.*",
"networkx == 3.3.*",
"pyarrow == 16.1.*",
"meds == 0.3",
]

[project.scripts]
Expand Down
97 changes: 93 additions & 4 deletions src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from importlib.resources import files

import hydra
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq
from loguru import logger
from meds import label_schema
from omegaconf import DictConfig

config_yaml = files("aces").joinpath("configs/aces.yaml")
Expand All @@ -15,6 +19,88 @@
print("For more information, visit: https://eventstreamaces.readthedocs.io/en/latest/usage.html")
sys.exit(1)

# Columns a label frame must already contain, mapped to the Polars dtype each is coerced to.
# `get_and_validate_label_schema` raises a ValueError when any of these is absent.
MEDS_LABEL_MANDATORY_TYPES = {
    "patient_id": pl.Int64,
    "prediction_time": pl.Datetime("us"),  # microsecond time unit
}

# Label-value columns that may be absent from a label frame, mapped to their Polars dtype.
# `get_and_validate_label_schema` adds any missing one as an all-null column of that dtype.
MEDS_LABEL_OPTIONAL_TYPES = {
    "boolean_value": pl.Boolean,
    "integer_value": pl.Int64,
    "float_value": pl.Float64,
    "categorical_value": pl.String,
}


def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
    """Validates a MEDS label DataFrame and returns it as a schema-compliant Arrow table.

    This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
    and that the columns are of the correct type. This function will:
      1. Re-type any mandatory MEDS label column (``patient_id``, ``prediction_time``) that is present with
         a different type. Mandatory columns are never created; a missing one is an error.
      2. Re-type any optional value column (``boolean_value``, ``integer_value``, ``float_value``,
         ``categorical_value``) that is present with a different type, and add any that is missing as an
         all-null column.
      3. Keep only the label columns, in schema order, dropping every other column of ``df``.

    All re-typing uses a non-strict Polars cast (``strict=False``).

    Args:
        df: The MEDS label DataFrame to validate.

    Returns:
        pa.Table: The label data cast to ``meds.label_schema``, with columns re-typed as needed.

    Raises:
        ValueError: if one or more mandatory columns are missing. The message lists every missing column.

    Examples:
        >>> df = pl.DataFrame({})
        >>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
        Traceback (most recent call last):
            ...
        ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
                    MEDS Data DataFrame must have a 'prediction_time' column of type
                        Datetime(time_unit='us', time_zone=None).
        >>> from datetime import datetime
        >>> df = pl.DataFrame({
        ...     "trigger": ["a", "b", "c"],
        ...     "patient_id": pl.Series([1, 3, 2], dtype=pl.UInt32),
        ...     "prediction_time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
        ...     "boolean_value": [1, 0, 100],
        ... })
        >>> table = get_and_validate_label_schema(df.lazy())
        >>> table.schema.equals(label_schema)
        True
        >>> "trigger" in table.column_names
        False
        >>> table.column("patient_id").to_pylist()
        [1, 3, 2]
        >>> table.column("boolean_value").to_pylist()
        [True, False, True]
        >>> table.column("integer_value").to_pylist()
        [None, None, None]
    """

    schema = df.collect_schema()

    # Mandatory columns: coerce when present with another dtype, and collect every missing column so
    # that a single exception reports all of them at once.
    errors = []
    for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items():
        if col not in schema:
            errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")
        elif schema[col] != dtype:
            df = df.with_columns(pl.col(col).cast(dtype, strict=False))

    if errors:
        raise ValueError("\n".join(errors))

    # Optional columns: coerce when present with another dtype, otherwise add as typed nulls.
    for col, dtype in MEDS_LABEL_OPTIONAL_TYPES.items():
        if col not in schema:
            df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
        elif schema[col] != dtype:
            df = df.with_columns(pl.col(col).cast(dtype, strict=False))

    # `pa.Table.cast` requires the table's column names to match the target schema's names in order,
    # so restrict the frame to exactly the label columns, in schema order, before converting.
    label_columns = [*MEDS_LABEL_MANDATORY_TYPES, *MEDS_LABEL_OPTIONAL_TYPES]
    return df.select(label_columns).collect().to_arrow().cast(label_schema)

Check warning on line 102 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L102

Added line #L102 was not covered by tests


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig):
Expand Down Expand Up @@ -45,14 +131,17 @@ def main(cfg: DictConfig):
# query results
result = query.query(task_cfg, predicates_df)

# save results to parquet
os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)

if cfg.data.standard.lower() == "meds":
result = result.rename({"subject_id": "patient_id"})
if "index_timestamp" in result.columns:
result = result.rename({"index_timestamp": "prediction_time"})

# save results to parquet
os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)
result.write_parquet(cfg.output_filepath, use_pyarrow=True)
result = get_and_validate_label_schema(result.lazy())
pq.write_table(result, cfg.output_filepath)

Check warning on line 142 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L138-L142

Added lines #L138 - L142 were not covered by tests
else:
result.write_parquet(cfg.output_filepath, use_pyarrow=True)
logger.info(f"Completed in {datetime.now() - st}. Results saved to '{cfg.output_filepath}'.")


Expand Down

0 comments on commit e62d1bc

Please sign in to comment.