Match MEDS label schema as per #72 (#80)

Merged (5 commits, Aug 12, 2024)
Changes from 4 commits
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
     "pytimeparse == 1.1.*",
     "networkx == 3.3.*",
     "pyarrow == 16.1.*",
+    "meds == 0.3",
 ]

 [project.scripts]
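As a quick sanity check on the new pin (a minimal sketch, not part of the PR, assuming meds 0.3 is installed), the pinned release exposes the `label_schema` object that the validation code below imports and casts its output to:

from meds import label_schema  # provided by the meds == 0.3 pin above

# label_schema is a pyarrow.Schema describing MEDS label files; the function
# added in the diff below casts its output table to exactly this schema.
print(label_schema)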
99 changes: 95 additions & 4 deletions src/aces/__main__.py
@@ -4,7 +4,11 @@
 from importlib.resources import files

 import hydra
+import polars as pl
+import pyarrow as pa
+import pyarrow.parquet as pq
 from loguru import logger
+from meds import label_schema
 from omegaconf import DictConfig

config_yaml = files("aces").joinpath("configs/aces.yaml")
@@ -15,6 +19,88 @@
print("For more information, visit: https://eventstreamaces.readthedocs.io/en/latest/usage.html")
sys.exit(1)

MEDS_LABEL_MANDATORY_TYPES = {
"patient_id": pl.Int64,
"prediction_time": pl.Datetime("us"),
}

MEDS_LABEL_OPTIONAL_TYPES = {
"boolean_value": pl.Boolean,
"integer_value": pl.Int64,
"float_value": pl.Float64,
"categorical_value": pl.String,
}


def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
    """Validates the schema of a MEDS label DataFrame.

    This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
    and that the columns are of the correct type. This function will:
      1. Re-type any of the mandatory MEDS label columns to the appropriate type.
      2. Add any of the optional MEDS label columns that are missing, setting them to `None`. Missing
         mandatory columns are never added, as they cannot meaningfully be set to `None`.

    Args:
        df: The MEDS label DataFrame to validate.

    Returns:
        pa.Table: The validated MEDS label DataFrame, with columns re-typed as needed.

    Raises:
        ValueError: if the MEDS label DataFrame is missing a mandatory column.

    Examples:
        >>> df = pl.DataFrame({})
        >>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
        Traceback (most recent call last):
            ...
        ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
        MEDS Data DataFrame must have a 'prediction_time' column of type
        Datetime(time_unit='us', time_zone=None).
        >>> from datetime import datetime
        >>> df = pl.DataFrame({
        ...     "patient_id": pl.Series([1, 3, 2], dtype=pl.UInt32),
        ...     "prediction_time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
        ...     "boolean_value": [1, 0, 100],
        ... })
        >>> get_and_validate_label_schema(df.lazy())
        pyarrow.Table
        patient_id: int64
        prediction_time: timestamp[us]
        boolean_value: bool
        integer_value: int64
        float_value: float
        categorical_value: string
        ----
        patient_id: [[1,3,2]]
        prediction_time: [[2021-01-01 00:00:00.000000,2021-01-02 00:00:00.000000,2021-01-03 00:00:00.000000]]
        boolean_value: [[true,false,true]]
        integer_value: [[null,null,null]]
        float_value: [[null,null,null]]
        categorical_value: [[null,null,null]]
    """

    schema = df.collect_schema()
    errors = []
    for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items():
        if col in schema and schema[col] != dtype:
            df = df.with_columns(pl.col(col).cast(dtype, strict=False))
        elif col not in schema:
            errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")

    if errors:
        raise ValueError("\n".join(errors))

    for col, dtype in MEDS_LABEL_OPTIONAL_TYPES.items():
        if col in schema and schema[col] != dtype:
            df = df.with_columns(pl.col(col).cast(dtype, strict=False))
        elif col not in schema:
            df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))

    return df.collect().to_arrow().cast(label_schema)


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
def main(cfg: DictConfig):
@@ -45,12 +131,17 @@
    # query results
    result = query.query(task_cfg, predicates_df)

-    if cfg.data.standard.lower() == "meds":
-        result = result.rename({"subject_id": "patient_id"})

    # save results to parquet
    os.makedirs(os.path.dirname(cfg.output_filepath), exist_ok=True)
-    result.write_parquet(cfg.output_filepath, use_pyarrow=True)
+    if cfg.data.standard.lower() == "meds":
+        result = result.rename({"subject_id": "patient_id"})
+        if "index_timestamp" in result.columns:
+            result = result.rename({"index_timestamp": "prediction_time"})
+        result = get_and_validate_label_schema(result.lazy())
+        pq.write_table(result, cfg.output_filepath)
+    else:
+        result.write_parquet(cfg.output_filepath, use_pyarrow=True)
    logger.info(f"Completed in {datetime.now() - st}. Results saved to '{cfg.output_filepath}'.")


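A note on the design of `get_and_validate_label_schema` above: both the mandatory and optional columns are cast with `strict=False`, so values that cannot be represented in the target dtype become null rather than raising; a `ValueError` is reserved for mandatory columns that are missing outright. A minimal illustration of that casting behavior (illustrative only, not code from the PR):

import polars as pl

# "oops" cannot be parsed as an integer; with strict=False, polars nulls it
# out instead of raising an error.
df = pl.DataFrame({"patient_id": ["1", "oops", "3"]})
print(df.with_columns(pl.col("patient_id").cast(pl.Int64, strict=False)))
# patient_id becomes Int64 with values [1, null, 3]

The final `.cast(label_schema)` on the collected table is then what guarantees the file written by `pq.write_table` carries the exact MEDS label schema.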
23 changes: 13 additions & 10 deletions src/aces/expand_shards.py
@@ -1,9 +1,9 @@
#!/usr/bin/env python

-import glob
 import os
 import re
 import sys
+from pathlib import Path


def expand_shards(*shards: str) -> str:
@@ -23,7 +23,6 @@ def expand_shards(*shards: str) -> str:
    Examples:
        >>> import polars as pl
        >>> import tempfile
-       >>> from pathlib import Path

        >>> expand_shards("train/4", "val/IID/1", "val/prospective/1")
        'train/0,train/1,train/2,train/3,val/IID/0,val/prospective/0'
@@ -38,20 +37,24 @@ def expand_shards(*shards: str) -> str:

        >>> with tempfile.TemporaryDirectory() as tmpdirname:
        ...     for i in range(4):
-       ...         with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
-       ...             data_path = Path(tmpdirname + f"/file_{i}")
-       ...             parquet_data.write_parquet(data_path)
+       ...         if i in (0, 2):
+       ...             data_path = Path(tmpdirname) / f"evens/0/file_{i}.parquet"
+       ...             data_path.parent.mkdir(parents=True, exist_ok=True)
+       ...         else:
+       ...             data_path = Path(tmpdirname) / f"{i}.parquet"
+       ...         parquet_data.write_parquet(data_path)
+       ...     json_fp = Path(tmpdirname) / ".shards.json"
+       ...     _ = json_fp.write_text('["foo"]')
        ...     result = expand_shards(tmpdirname)
-       ...     ','.join(sorted(os.path.basename(f) for f in result.split(',')))
-       'file_0,file_1,file_2,file_3'
+       ...     sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
+       ['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
    """

    result = []
    for arg in shards:
        if os.path.isdir(arg):
-           # If the argument is a directory, list all files in the directory
-           files = glob.glob(os.path.join(arg, "*"))
-           result.extend(files)
+           # If the argument is a directory, take all parquet files in any subdirs of the directory
+           result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
        else:
            # Otherwise, treat it as a shard prefix and number of shards
            match = re.match(r"(.+)([/_])(\d+)$", arg)
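The behavioral change in the directory branch above is worth spelling out: the old `glob.glob(os.path.join(arg, "*"))` listed only the directory's immediate children, of any file type, while the new `Path(arg).glob("**/*.parquet")` recurses into subdirectories and keeps only parquet files, which is why the updated doctest expects `evens/0/file_*.parquet` to be found and `.shards.json` to be ignored. A small sketch of the difference, assuming a hypothetical `data/` layout:

import glob
import os
from pathlib import Path

# Hypothetical layout:
#   data/1.parquet
#   data/notes.txt
#   data/evens/0/file_0.parquet
old_style = glob.glob(os.path.join("data", "*"))
# -> ['data/1.parquet', 'data/notes.txt', 'data/evens']  (immediate children only)

new_style = [str(p) for p in Path("data").glob("**/*.parquet")]
# -> ['data/1.parquet', 'data/evens/0/file_0.parquet']  (recursive, parquet only)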