Skip to content

Commit

Permalink
Fixed a few other errors for MEDS v0.3 compatibility. Now tested with…
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 11, 2024
1 parent 0cde8f6 commit e965539
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
39 changes: 32 additions & 7 deletions src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@

MEDS_LABEL_MANDATORY_TYPES = {
"patient_id": pl.Int64,
"prediction_time": pl.Datetime("us"),
}

MEDS_LABEL_OPTIONAL_TYPES = {
"boolean_value": pl.Boolean,
"integer_value": pl.Int64,
"float_value": pl.Float64,
"categorical_value": pl.String,
"prediction_time": pl.Datetime("us"),
}


def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
"""Validates the schema of a MEDS data DataFrame.
This function validates the schema of a MEDS label DataFrame, ensuring that it has the correct columns
Expand All @@ -53,7 +53,7 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
Examples:
>>> df = pl.DataFrame({})
>>> get_and_validate_label_schema(df.lazy()) # doctest: +NORMALIZE_WHITESPACE
>>> get_and_validate_label_schema(df) # doctest: +NORMALIZE_WHITESPACE
Traceback (most recent call last):
...
ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64.
Expand All @@ -65,7 +65,7 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
... "boolean_value": [1, 0, 100],
... })
>>> get_and_validate_label_schema(df.lazy())
>>> get_and_validate_label_schema(df)
pyarrow.Table
patient_id: int64
time: timestamp[us]
Expand All @@ -82,7 +82,14 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
categorical_value: [[null,null,null]]
"""

schema = df.collect_schema()
schema = df.schema
if "prediction_time" not in schema:
logger.warning(

Check warning on line 87 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L85-L87

Added lines #L85 - L87 were not covered by tests
"Output DataFrame is missing a 'prediction_time' column. If this not intentional, add a "
"'index_timestamp' (yes, it should be different) key to the task configuration identifying "
"which window's start or end time to use as the prediction time."
)

errors = []
for col, dtype in MEDS_LABEL_MANDATORY_TYPES.items():
if col in schema and schema[col] != dtype:
Expand All @@ -99,7 +106,23 @@ def get_and_validate_label_schema(df: pl.LazyFrame) -> pa.Table:
elif col not in schema:
df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))

Check warning on line 107 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L103-L107

Added lines #L103 - L107 were not covered by tests

return df.collect().to_arrow().cast(label_schema)
extra_cols = [

Check warning on line 109 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L109

Added line #L109 was not covered by tests
c for c in schema if c not in MEDS_LABEL_MANDATORY_TYPES and c not in MEDS_LABEL_OPTIONAL_TYPES
]
if extra_cols:
err_cols_str = "\n".join(f" - {c}" for c in extra_cols)
logger.warning(

Check warning on line 114 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L112-L114

Added lines #L112 - L114 were not covered by tests
"Output contains columns that are not valid MEDS label columns. For now, we are dropping them.\n"
"If you need these columns, please comment on https://github.com/justin13601/ACES/issues/97\n"
f"Columns:\n{err_cols_str}"
)
df = df.drop(extra_cols)

Check warning on line 119 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L119

Added line #L119 was not covered by tests

df = df.select(

Check warning on line 121 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L121

Added line #L121 was not covered by tests
"patient_id", "prediction_time", "boolean_value", "integer_value", "float_value", "categorical_value"
)

return df.to_arrow().cast(label_schema)

Check warning on line 125 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L125

Added line #L125 was not covered by tests


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
Expand Down Expand Up @@ -138,7 +161,9 @@ def main(cfg: DictConfig):
result = result.rename({"subject_id": "patient_id"})
if "index_timestamp" in result.columns:
result = result.rename({"index_timestamp": "prediction_time"})
result = get_and_validate_label_schema(result.lazy())
if "label" in result.columns:
result = result.rename({"label": "boolean_value"})
result = get_and_validate_label_schema(result)
pq.write_table(result, cfg.output_filepath)

Check warning on line 167 in src/aces/__main__.py

View check run for this annotation

Codecov / codecov/patch

src/aces/__main__.py#L161-L167

Added lines #L161 - L167 were not covered by tests
else:
result.write_parquet(cfg.output_filepath, use_pyarrow=True)
Expand Down
14 changes: 8 additions & 6 deletions src/aces/expand_shards.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def expand_shards(*shards: str) -> str:
>>> parquet_data = pl.DataFrame({
... "patient_id": [1, 1, 1, 2, 3],
... "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
... "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
... "code": ['admission', 'discharge', 'discharge', 'admission', "gender"],
... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
>>> with tempfile.TemporaryDirectory() as tmpdirname:
... for i in range(4):
Expand All @@ -43,18 +43,20 @@ def expand_shards(*shards: str) -> str:
... else:
... data_path = Path(tmpdirname) / f"{i}.parquet"
... parquet_data.write_parquet(data_path)
... json_fp = Path(tmpdirname) / ".shards.json"
... json_fp = Path(tmpdirname) / "4.json"
... _ = json_fp.write_text('["foo"]')
... result = expand_shards(tmpdirname)
... sorted(str(Path(x).relative_to(Path(tmpdirname))) for x in result.split(","))
['1.parquet', '3.parquet', 'evens/0/file_0.parquet', 'evens/0/file_2.parquet']
... sorted(result.split(","))
['1', '3', 'evens/0/file_0', 'evens/0/file_2']
"""

result = []
for arg in shards:
if os.path.isdir(arg):
# If the argument is a directory, take all parquet files in any subdirs of the directory
result.extend(str(x.resolve()) for x in Path(arg).glob("**/*.parquet"))
result.extend(
str(x.relative_to(Path(arg)).with_suffix("")) for x in Path(arg).glob("**/*.parquet")
)
else:
# Otherwise, treat it as a shard prefix and number of shards
match = re.match(r"(.+)([/_])(\d+)$", arg)
Expand Down
6 changes: 3 additions & 3 deletions src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
>>> from .config import PlainPredicateConfig
>>> parquet_data = pl.DataFrame({
... "patient_id": [1, 1, 1, 2, 3],
... "timestamp": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
... "time": ["1/1/1989 00:00", "1/1/1989 01:00", "1/1/1989 01:00", "1/1/1989 02:00", None],
... "code": ['admission', 'discharge', 'discharge', 'admission', "gender//male"],
... }).with_columns(pl.col("timestamp").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
... }).with_columns(pl.col("time").str.strptime(pl.Datetime, format="%m/%d/%Y %H:%M"))
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".parquet") as f:
... data_path = Path(f.name)
... parquet_data.write_parquet(data_path)
Expand All @@ -316,7 +316,7 @@ def generate_plain_predicates_from_meds(data_path: Path, predicates: dict) -> pl
"""

logger.info("Loading MEDS data...")
data = pl.read_parquet(data_path).rename({"patient_id": "subject_id"})
data = pl.read_parquet(data_path).rename({"patient_id": "subject_id", "time": "timestamp"})

if data.columns == ["subject_id", "events"]:
data = unnest_meds(data)
Expand Down

0 comments on commit e965539

Please sign in to comment.