diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml
index 2da8471..38a900b 100644
--- a/MIMIC-IV_Example/configs/event_configs.yaml
+++ b/MIMIC-IV_Example/configs/event_configs.yaml
@@ -45,6 +45,10 @@ hosp/diagnoses_icd:
     hadm_id: hadm_id
     timestamp: col(hadm_discharge_time)
     timestamp_format: "%Y-%m-%d %H:%M:%S"
+    _metadata:
+      d_icd_diagnoses:
+        description: "long_title"
+        parent_code: "ICD{icd_version}CM/{icd_code}" # Single strings are templates of columns.

 hosp/drgcodes:
   drg:
@@ -79,6 +83,11 @@ hosp/hcpcsevents:
     hadm_id: hadm_id
     timestamp: col(chartdate)
     timestamp_format: "%Y-%m-%d"
+    _metadata:
+      # These are not all CPT codes, unfortunately
+      d_hcpcs:
+        description: "long_description"
+        possibly_cpt_code: "code"

 hosp/labevents:
   lab:
@@ -92,6 +101,11 @@ hosp/labevents:
     numerical_value: valuenum
     text_value: value
     priority: priority
+    _metadata:
+      d_labitems_to_loinc:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"

 hosp/omr:
   omr:
@@ -152,6 +166,12 @@ hosp/procedures_icd:
     hadm_id: hadm_id
     timestamp: col(chartdate)
     timestamp_format: "%Y-%m-%d"
+    _metadata:
+      d_icd_procedures:
+        description: "long_title"
+        parent_code: # List of objects are string labels mapping to filters to be evaluated.
+          - "ICD{icd_version}Proc/{icd_code}": { icd_version: 9 }
+          - "ICD{icd_version}PCS/{icd_code}": { icd_version: 10 }

 hosp/transfers:
   transfer:
@@ -193,6 +213,16 @@ icu/chartevents:
     text_value: value
     hadm_id: hadm_id
     icustay_id: stay_id
+    _metadata:
+      meas_chartevents_main:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+      # TODO: I don't know if this is necessary...
+      d_labitems_to_loinc:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"

 icu/procedureevents:
   start:
@@ -204,6 +234,15 @@ icu/procedureevents:
     timestamp_format: "%Y-%m-%d %H:%M:%S"
     hadm_id: hadm_id
     icustay_id: stay_id
+    _metadata:
+      proc_datetimeevents:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+      proc_itemid:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
   end:
     code:
       - PROCEDURE
@@ -213,3 +252,77 @@ icu/procedureevents:
     timestamp_format: "%Y-%m-%d %H:%M:%S"
     hadm_id: hadm_id
     icustay_id: stay_id
+    _metadata:
+      proc_datetimeevents:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+      proc_itemid:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+
+icu/inputevents:
+  input_start:
+    code:
+      - INFUSION_START
+      - col(ordercategorydescription)
+      - col(itemid)
+      - col(rateuom)
+    timestamp: col(starttime)
+    timestamp_format: "%Y-%m-%d %H:%M:%S"
+    hadm_id: hadm_id
+    icustay_id: stay_id
+    order_id: orderid
+    link_order_id: linkorderid
+    numerical_value: rate
+    _metadata:
+      inputevents_to_rxnorm:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+        rateuom: null # A null column means this column is needed in pulling from the metadata.
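+    # Illustrative note: the extraction code joins the `code` pieces above with "//", so `input_start`
+    # events produce codes like "INFUSION_START//<ordercategorydescription>//<itemid>//<rateuom>";
+    # metadata rows are linked by rebuilding the same code string from the metadata table's columns.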
+  input_end:
+    code:
+      - INFUSION_END
+      - col(ordercategorydescription)
+      - col(itemid)
+      - col(amountuom)
+      - col(statusdescription)
+    timestamp: col(endtime)
+    timestamp_format: "%Y-%m-%d %H:%M:%S"
+    hadm_id: hadm_id
+    icustay_id: stay_id
+    order_id: orderid
+    link_order_id: linkorderid
+    numerical_value: amount
+    _metadata:
+      inputevents_to_rxnorm:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+  patient_weight:
+    code:
+      - PATIENT_WEIGHT_AT_INFUSION
+      - KG
+    timestamp: col(starttime)
+    timestamp_format: "%Y-%m-%d %H:%M:%S"
+    numerical_value: patientweight
+
+icu/outputevents:
+  output:
+    code:
+      - PATIENT_FLUID_OUTPUT
+      - col(itemid)
+      - col(valueuom)
+    timestamp: col(charttime)
+    timestamp_format: "%Y-%m-%d %H:%M:%S"
+    hadm_id: hadm_id
+    icustay_id: stay_id
+    numerical_value: value
+    _metadata:
+      outputevents_to_rxnorm:
+        description: ["omop_concept_name", "label"] # List of strings are columns to be collated
+        itemid: "itemid (omop_source_code)"
+        parent_code: "{omop_vocabulary_id}/{omop_concept_code}"
+        valueuom: unitname
diff --git a/pyproject.toml b/pyproject.toml
index a3b6368..040b45b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,6 +33,7 @@ MEDS_extract-shard_events = "MEDS_polars_functions.extract.shard_events:main"
 MEDS_extract-convert_to_sharded_events = "MEDS_polars_functions.extract.convert_to_sharded_events:main"
 MEDS_extract-merge_to_MEDS_cohort = "MEDS_polars_functions.extract.merge_to_MEDS_cohort:main"
 MEDS_transform-aggregate_code_metadata = "MEDS_polars_functions.aggregate_code_metadata:main"
+MEDS_extract-extract_code_metadata = "MEDS_polars_functions.extract.extract_code_metadata:main"

 [project.urls]
 Homepage = "https://github.com/mmcdermott/MEDS_polars_functions"
diff --git a/src/MEDS_polars_functions/configs/extract.yaml b/src/MEDS_polars_functions/configs/extract.yaml
index afabfbd..734d3da 100644
--- a/src/MEDS_polars_functions/configs/extract.yaml
+++ b/src/MEDS_polars_functions/configs/extract.yaml
@@ -35,6 +35,7 @@ stages:
   - convert_to_sharded_events
   - merge_to_MEDS_cohort
   - aggregate_code_metadata
+  - extract_code_metadata

 stage_configs:
   aggregate_code_metadata:
diff --git a/src/MEDS_polars_functions/extract/convert_to_sharded_events.py b/src/MEDS_polars_functions/extract/convert_to_sharded_events.py
index 010104c..9d29a6a 100755
--- a/src/MEDS_polars_functions/extract/convert_to_sharded_events.py
+++ b/src/MEDS_polars_functions/extract/convert_to_sharded_events.py
@@ -15,6 +15,7 @@ from omegaconf.listconfig import ListConfig

 from MEDS_polars_functions.extract import CONFIG_YAML
+from MEDS_polars_functions.extract.shard_events import META_KEYS
 from MEDS_polars_functions.mapreduce.mapper import rwlock_wrap
 from MEDS_polars_functions.utils import (
     is_col_field,
@@ -29,6 +30,78 @@ def in_format(fmt: str, ts_name: str) -> pl.Expr:
     return pl.col(ts_name).str.strptime(pl.Datetime, fmt, strict=False)


+def get_code_expr(code_field: str | list | ListConfig) -> tuple[pl.Expr, pl.Expr | None, set[str]]:
+    """Converts the code field in an event config file to a polars expression, null filter, and column set.
+
+    Args:
+        code_field: The string or list representation of the code field in the event configuration file.
+
+    Returns:
+        pl.Expr: The polars expression representing the code field.
+        pl.Expr | None: The null filter expression for the code field.
+        set[str]: The set of columns needed to construct the code field.
+
+    Raises:
+        ValueError: If the code field is not a valid type.
+
+    Examples:
+        >>> print(*get_code_expr("A")) # doctest: +NORMALIZE_WHITESPACE
+        String(A).strict_cast(String).strict_cast(Categorical(None, Physical))
+        None
+        set()
+        >>> print(*get_code_expr("col(B)")) # doctest: +NORMALIZE_WHITESPACE
+        col("B").strict_cast(String).fill_null([String(UNK)]).strict_cast(Categorical(None, Physical))
+        col("B").is_not_null()
+        {'B'}
+        >>> print(*get_code_expr(["col(A)", "B"])) # doctest: +NORMALIZE_WHITESPACE
+        [([(col("A").strict_cast(String).fill_null([String(UNK)])) + (String(//))]) +
+          (String(B).strict_cast(String))].strict_cast(Categorical(None, Physical))
+        col("A").is_not_null()
+        {'A'}
+        >>> get_code_expr(34)
+        Traceback (most recent call last):
+            ...
+        ValueError: Invalid code field: 34
+        >>> get_code_expr(["a", 34, "b"])
+        Traceback (most recent call last):
+            ...
+        ValueError: Invalid code literal: 34
+
+        Note that it only takes the first column field for the null filter, not all of them.
+        >>> expr, null_filter, cols = get_code_expr(["col(A)", "col(c)"])
+        >>> print(expr) # doctest: +NORMALIZE_WHITESPACE
+        [([(col("A").strict_cast(String).fill_null([String(UNK)])) + (String(//))]) +
+          (col("c").strict_cast(String).fill_null([String(UNK)]))].strict_cast(Categorical(None, Physical))
+        >>> print(null_filter)
+        col("A").is_not_null()
+        >>> print(sorted(cols))
+        ['A', 'c']
+    """
+    if isinstance(code_field, str):
+        code_field = [code_field]
+    elif not isinstance(code_field, (list, ListConfig)):
+        raise ValueError(f"Invalid code field: {code_field}")
+
+    code_exprs = []
+    code_null_filter_expr = None
+    needed_cols = set()
+    for i, code in enumerate(code_field):
+        match code:
+            case str() if is_col_field(code):
+                code_col = parse_col_field(code)
+                needed_cols.add(code_col)
+                code_exprs.append(pl.col(code_col).cast(pl.Utf8).fill_null("UNK"))
+                if i == 0:
+                    code_null_filter_expr = pl.col(code_col).is_not_null()
+            case str():
+                code_exprs.append(pl.lit(code, dtype=pl.Utf8))
+            case _:
+                raise ValueError(f"Invalid code literal: {code}")
+    code_expr = reduce(lambda a, b: a + pl.lit("//") + b, code_exprs).cast(pl.Categorical)
+
+    return code_expr, code_null_filter_expr, needed_cols
+
+
 def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.LazyFrame:
     """Extracts a single event dataframe from the raw data.

@@ -183,7 +256,8 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         >>> valid_discharge_event_cfg = {
         ...     "code": ["DISCHARGE", "col(discharge_location)"],
         ...     "timestamp": "col(discharge_time)",
-        ...     "discharge_status": "discharge_status",
+        ...     "categorical_value": "discharge_status",  # Note the raw dtype of this col is str
+        ...     "text_value": "discharge_location",  # Note the raw dtype of this col is categorical
         ...     }
         >>> valid_death_event_cfg = {
         ...     "code": "DEATH",
@@ -223,20 +297,54 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         │ 2          ┆ ADMISSION//E ┆ 2021-01-05 00:00:00 ┆ 5.0             │
         │ 3          ┆ ADMISSION//F ┆ 2021-01-06 00:00:00 ┆ 6.0             │
         └────────────┴──────────────┴─────────────────────┴─────────────────┘
-        >>> extract_event(complex_raw_data, valid_discharge_event_cfg)
+        >>> extract_event(
+        ...     complex_raw_data.with_columns(pl.col("severity_score").cast(pl.Utf8)),
+        ...     valid_admission_event_cfg
+        ... )
+        shape: (6, 4)
+        ┌────────────┬──────────────┬─────────────────────┬─────────────────┐
+        │ patient_id ┆ code         ┆ timestamp           ┆ numerical_value │
+        │ ---        ┆ ---          ┆ ---                 ┆ ---             │
+        │ u8         ┆ cat          ┆ datetime[μs]        ┆ f64             │
+        ╞════════════╪══════════════╪═════════════════════╪═════════════════╡
+        │ 1          ┆ ADMISSION//A ┆ 2021-01-01 00:00:00 ┆ 1.0             │
+        │ 1          ┆ ADMISSION//B ┆ 2021-01-02 00:00:00 ┆ 2.0             │
+        │ 2          ┆ ADMISSION//C ┆ 2021-01-03 00:00:00 ┆ 3.0             │
+        │ 2          ┆ ADMISSION//D ┆ 2021-01-04 00:00:00 ┆ 4.0             │
+        │ 2          ┆ ADMISSION//E ┆ 2021-01-05 00:00:00 ┆ 5.0             │
+        │ 3          ┆ ADMISSION//F ┆ 2021-01-06 00:00:00 ┆ 6.0             │
+        └────────────┴──────────────┴─────────────────────┴─────────────────┘
+        >>> extract_event(
+        ...     complex_raw_data.with_columns(pl.col("severity_score").cast(pl.Utf8).cast(pl.Categorical)),
+        ...     valid_admission_event_cfg
+        ... )
         shape: (6, 4)
-        ┌────────────┬─────────────────┬─────────────────────┬──────────────────┐
-        │ patient_id ┆ code            ┆ timestamp           ┆ discharge_status │
-        │ ---        ┆ ---             ┆ ---                 ┆ ---              │
-        │ u8         ┆ cat             ┆ datetime[μs]        ┆ cat              │
-        ╞════════════╪═════════════════╪═════════════════════╪══════════════════╡
-        │ 1          ┆ DISCHARGE//Home ┆ 2021-01-01 11:23:45 ┆ AOx4             │
-        │ 1          ┆ DISCHARGE//SNF  ┆ 2021-01-02 12:34:56 ┆ AO               │
-        │ 2          ┆ DISCHARGE//Home ┆ 2021-01-03 13:45:56 ┆ AAO              │
-        │ 2          ┆ DISCHARGE//SNF  ┆ 2021-01-04 14:56:45 ┆ AOx3             │
-        │ 2          ┆ DISCHARGE//Home ┆ 2021-01-05 15:23:45 ┆ AOx4             │
-        │ 3          ┆ DISCHARGE//SNF  ┆ 2021-01-06 16:34:56 ┆ AOx4             │
-        └────────────┴─────────────────┴─────────────────────┴──────────────────┘
+        ┌────────────┬──────────────┬─────────────────────┬─────────────────┐
+        │ patient_id ┆ code         ┆ timestamp           ┆ numerical_value │
+        │ ---        ┆ ---          ┆ ---                 ┆ ---             │
+        │ u8         ┆ cat          ┆ datetime[μs]        ┆ f64             │
+        ╞════════════╪══════════════╪═════════════════════╪═════════════════╡
+        │ 1          ┆ ADMISSION//A ┆ 2021-01-01 00:00:00 ┆ 1.0             │
+        │ 1          ┆ ADMISSION//B ┆ 2021-01-02 00:00:00 ┆ 2.0             │
+        │ 2          ┆ ADMISSION//C ┆ 2021-01-03 00:00:00 ┆ 3.0             │
+        │ 2          ┆ ADMISSION//D ┆ 2021-01-04 00:00:00 ┆ 4.0             │
+        │ 2          ┆ ADMISSION//E ┆ 2021-01-05 00:00:00 ┆ 5.0             │
+        │ 3          ┆ ADMISSION//F ┆ 2021-01-06 00:00:00 ┆ 6.0             │
+        └────────────┴──────────────┴─────────────────────┴─────────────────┘
+        >>> extract_event(complex_raw_data, valid_discharge_event_cfg)
+        shape: (6, 5)
+        ┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐
+        │ patient_id ┆ code            ┆ timestamp           ┆ categorical_value ┆ text_value │
+        │ ---        ┆ ---             ┆ ---                 ┆ ---               ┆ ---        │
+        │ u8         ┆ cat             ┆ datetime[μs]        ┆ cat               ┆ str        │
+        ╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡
+        │ 1          ┆ DISCHARGE//Home ┆ 2021-01-01 11:23:45 ┆ AOx4              ┆ Home       │
+        │ 1          ┆ DISCHARGE//SNF  ┆ 2021-01-02 12:34:56 ┆ AO                ┆ SNF        │
+        │ 2          ┆ DISCHARGE//Home ┆ 2021-01-03 13:45:56 ┆ AAO               ┆ Home       │
+        │ 2          ┆ DISCHARGE//SNF  ┆ 2021-01-04 14:56:45 ┆ AOx3              ┆ SNF        │
+        │ 2          ┆ DISCHARGE//Home ┆ 2021-01-05 15:23:45 ┆ AOx4              ┆ Home       │
+        │ 3          ┆ DISCHARGE//SNF  ┆ 2021-01-06 16:34:56 ┆ AOx4              ┆ SNF        │
+        └────────────┴─────────────────┴─────────────────────┴───────────────────┴────────────┘
         >>> extract_event(complex_raw_data, valid_death_event_cfg)
         shape: (3, 3)
         ┌────────────┬───────┬─────────────────────┐
@@ -268,10 +376,6 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         Traceback (most recent call last):
             ...
         KeyError: "Event configuration dictionary must contain 'timestamp' key. Got: [code, value]."
-        >>> extract_event(complex_raw_data, {"code": 34, "timestamp": "col(admission_time)"})
-        Traceback (most recent call last):
-            ...
-        ValueError: Invalid code literal: 34
         >>> extract_event(complex_raw_data, {"code": "test", "timestamp": "12-01-23"})
         Traceback (most recent call last):
             ...
@@ -291,9 +395,9 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         >>> extract_event(complex_raw_data, {"code": "test", "timestamp": None, "foobar": "discharge_time"})
         Traceback (most recent call last):
             ...
-        ValueError: Source column 'discharge_time' for event column foobar is not numeric or categorical! Cannot be used as an event col.
+        ValueError: Source column 'discharge_time' for event column foobar is not numeric, string, or categorical! Cannot be used as an event col.
     """  # noqa: E501
-    df = df
+    event_cfg = copy.deepcopy(event_cfg)
     event_exprs = {"patient_id": pl.col("patient_id")}

     if "code" not in event_cfg:
@@ -309,29 +413,14 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
     if "patient_id" in event_cfg:
         raise KeyError("Event column name 'patient_id' cannot be overridden.")

-    codes = event_cfg.pop("code")
-    if not isinstance(codes, (list, ListConfig)):
-        logger.debug(
-            f"Event code '{codes}' is a {type(codes)}, not a list. Automatically converting to a list."
-        )
-        codes = [codes]
+    code_expr, code_null_filter_expr, needed_cols = get_code_expr(event_cfg.pop("code"))

-    code_exprs = []
-    code_null_filter_expr = None
-    for i, code in enumerate(codes):
-        match code:
-            case str() if is_col_field(code) and parse_col_field(code) in df.schema:
-                code_col = parse_col_field(code)
-                logger.info(f"Extracting code column {code_col}")
-                code_exprs.append(pl.col(code_col).cast(pl.Utf8).fill_null("UNK"))
-                if i == 0:
-                    code_null_filter_expr = pl.col(code_col).is_not_null()
-            case str():
-                logger.info(f"Adding code literate {code}")
-                code_exprs.append(pl.lit(code, dtype=pl.Utf8))
-            case _:
-                raise ValueError(f"Invalid code literal: {code}")
-    event_exprs["code"] = reduce(lambda a, b: a + pl.lit("//") + b, code_exprs).cast(pl.Categorical)
+    for col in needed_cols:
+        if col not in df.schema:
+            raise KeyError(f"Source column '{col}' for event column code not found in DataFrame schema.")
+        logger.info(f"Extracting column {col}")
+
+    event_exprs["code"] = code_expr

     ts = event_cfg.pop("timestamp")
     ts_format = event_cfg.pop("timestamp_format", None)
@@ -358,6 +447,9 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         raise ValueError(f"Invalid timestamp literal: {ts}")

     for k, v in event_cfg.items():
+        if k in META_KEYS:
+            continue
+
         if not isinstance(v, str):
             raise ValueError(
                 f"For event column {k}, source column {v} must be a string column name. Got {type(v)}."
@@ -373,15 +465,33 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
             raise KeyError(f"Source column '{v}' for event column {k} not found in DataFrame schema.")

         col = pl.col(v)
-        if df.schema[v] == pl.Utf8:
-            col = col.cast(pl.Categorical)
-        elif isinstance(df.schema[v], pl.Categorical):
-            pass
-        elif not df.schema[v].is_numeric():
-            raise ValueError(
-                f"Source column '{v}' for event column {k} is not numeric or categorical! "
-                "Cannot be used as an event col."
-            )
+        is_numeric = df.schema[v].is_numeric()
+        is_str = df.schema[v] == pl.Utf8
+        is_cat = isinstance(df.schema[v], pl.Categorical)
+        match k:
+            case "numerical_value" if is_numeric:
+                pass
+            case "numerical_value" if is_str:
+                logger.warning(f"Converting numerical_value to float from string for {code_expr}")
+                col = col.cast(pl.Float64, strict=False)
+            case "numerical_value" if is_cat:
+                logger.warning(f"Converting numerical_value to float from categorical for {code_expr}")
+                col = col.cast(pl.Utf8).cast(pl.Float64, strict=False)
+            case "text_value" if not df.schema[v] == pl.Utf8:
+                logger.warning(f"Converting text_value to string for {code_expr}")
+                col = col.cast(pl.Utf8, strict=False)
+            case "categorical_value" if not isinstance(df.schema[v], pl.Categorical):
+                logger.warning(f"Converting categorical_value to categorical for {code_expr}")
+                col = col.cast(pl.Utf8).cast(pl.Categorical)
+            case _ if is_str:
+                # TODO(mmd): Is this right? Is this always a good idea? It probably usually is, but maybe not
+                # always. Maybe a check on unique values first?
+                col = col.cast(pl.Categorical)
+            case _ if not (is_numeric or is_str or is_cat):
+                raise ValueError(
+                    f"Source column '{v}' for event column {k} is not numeric, string, or categorical! "
+                    "Cannot be used as an event col."
+                )

         event_exprs[k] = col

@@ -394,11 +504,6 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy

     df = df.select(**event_exprs).unique(maintain_order=True)

-    # if numerical_value column is not numeric, convert it to float
-    if "numerical_value" in df.columns and not df.schema["numerical_value"].is_numeric():
-        logger.warning(f"Converting numerical_value to float for codes {codes}")
-        df = df.with_columns(pl.col("numerical_value").cast(pl.Float64, strict=False))
-
     return df
diff --git a/src/MEDS_polars_functions/extract/extract_code_metadata.py b/src/MEDS_polars_functions/extract/extract_code_metadata.py
new file mode 100644
index 0000000..5676fd1
--- /dev/null
+++ b/src/MEDS_polars_functions/extract/extract_code_metadata.py
@@ -0,0 +1,396 @@
+#!/usr/bin/env python
+"""Utilities for extracting code metadata about the codes produced for the MEDS events."""
+
+import copy
+import random
+import time
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+
+import hydra
+import polars as pl
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+
+from MEDS_polars_functions.extract import CONFIG_YAML
+from MEDS_polars_functions.extract.convert_to_sharded_events import get_code_expr
+from MEDS_polars_functions.extract.parser import cfg_to_expr
+from MEDS_polars_functions.extract.utils import get_supported_fp
+from MEDS_polars_functions.mapreduce.mapper import rwlock_wrap
+from MEDS_polars_functions.utils import stage_init, write_lazyframe
+
+
+def extract_metadata(
+    metadata_df: pl.LazyFrame, event_cfg: dict[str, str | None], allowed_codes: list | None = None
+) -> pl.LazyFrame:
+    """Extracts a single metadata dataframe block for an event configuration from the raw metadata.
+
+    Args:
+        metadata_df: The raw metadata DataFrame. Mandatory columns are determined by the `event_cfg`
+            configuration dictionary.
+        event_cfg: A dictionary containing the configuration for the event. This must contain the critical
+            `"code"` key alongside a mandatory `_metadata` block, which must contain some columns that should
+            be extracted from the metadata to link to the code.
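+            For example, a minimal sketch of such a configuration (mirroring the style of the YAML event
+            configs above; the names here are illustrative) might be
+            `{"code": ["LAB", "col(itemid)"], "_metadata": {"description": "long_title"}}`.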
+ The `"code"` key must contain either (1) a string literal representing the code for the event or + (2) the name of a column in the raw data from which the code should be extracted. In the latter + case, the column name should be enclosed in `col()` function call syntax--e.g., + `col(my_code_column)`. Note there are no quotes used inside the `col()` function syntax. + + Returns: + A DataFrame containing the metadata extracted and linked to appropriately constructed code strings for + the event configuration. The output DataFrame will contain at least two columns: `"code"` and whatever + metadata column is specified for extraction in the metadata block. The output dataframe will not + necessarily be unique by code if the input metadata is not unique by code. + + Raises: + KeyError: If the event configuration dictionary is missing the `"code"` or `"_metadata"` keys or if + the `"_metadata_"` key is empty or if columns referenced by the event configuration dictionary are + not found in the raw metadata. + + Examples: + >>> extract_metadata(pl.DataFrame(), {}) + Traceback (most recent call last): + ... + KeyError: "Event configuration dictionary must contain 'code' key. Got: []." + >>> extract_metadata(pl.DataFrame(), {"code": "test"}) + Traceback (most recent call last): + ... + KeyError: "Event configuration dictionary must contain a non-empty '_metadata' key. Got: [code]." + >>> raw_metadata = pl.DataFrame({ + ... "code": ["A", "B", "C", "D", "E"], + ... "code_modifier": ["1", "2", "3", "4", "5"], + ... "name": ["Code A-1", "B-2", "C with 3", "D, but 4", None], + ... "priority": [1, 2, 3, 4, 5], + ... }) + >>> event_cfg = { + ... "code": ["FOO", "col(code)", "col(code_modifier)"], + ... "_metadata": {"desc": "name"}, + ... } + >>> extract_metadata(raw_metadata, event_cfg) + shape: (4, 2) + ┌───────────┬──────────┐ + │ code ┆ desc │ + │ --- ┆ --- │ + │ cat ┆ str │ + ╞═══════════╪══════════╡ + │ FOO//A//1 ┆ Code A-1 │ + │ FOO//B//2 ┆ B-2 │ + │ FOO//C//3 ┆ C with 3 │ + │ FOO//D//4 ┆ D, but 4 │ + └───────────┴──────────┘ + >>> extract_metadata(raw_metadata, event_cfg, allowed_codes=["FOO//A//1", "FOO//C//3"]) + shape: (2, 2) + ┌───────────┬──────────┐ + │ code ┆ desc │ + │ --- ┆ --- │ + │ cat ┆ str │ + ╞═══════════╪══════════╡ + │ FOO//A//1 ┆ Code A-1 │ + │ FOO//C//3 ┆ C with 3 │ + └───────────┴──────────┘ + >>> extract_metadata(raw_metadata.drop("code_modifier"), event_cfg) + Traceback (most recent call last): + ... + KeyError: "Columns {'code_modifier'} not found in metadata columns: ['code', 'name', 'priority']" + >>> extract_metadata(raw_metadata, ['foo']) + Traceback (most recent call last): + ... + TypeError: Event configuration must be a dictionary. Got: ['foo']. + + You can also manipulate the columns in more complex ways when assigning metadata from the input source. + >>> raw_metadata = pl.DataFrame({ + ... "code": ["A", "A", "C", "D"], + ... "code_modifier": ["1", "1", "2", "3"], + ... "code_modifier_2": ["1", "2", "3", "4"], + ... "title": ["A-1-1", "A-1-2", "C-2-3", None], + ... "special_title": ["used", None, None, None], + ... }) + >>> event_cfg = { + ... "code": ["FOO", "col(code)", "col(code_modifier)"], + ... "_metadata": { + ... "desc": ["special_title", "title"], + ... "parent_code": [ + ... {"OUT_VAL/{code_modifier}/2": {"code_modifier_2": "2"}}, + ... {"OUT_VAL_for_3/{code_modifier}": {"code_modifier_2": "3"}}, + ... { + ... "matcher": {"code_modifier_2": "4"}, + ... "output": {"literal": "expanded form"}, + ... }, + ... ], + ... }, + ... 
} + >>> extract_metadata(raw_metadata, event_cfg) + shape: (4, 3) + ┌───────────┬───────┬─────────────────┐ + │ code ┆ desc ┆ parent_code │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ str ┆ str │ + ╞═══════════╪═══════╪═════════════════╡ + │ FOO//A//1 ┆ used ┆ null │ + │ FOO//A//1 ┆ A-1-2 ┆ OUT_VAL/1/2 │ + │ FOO//C//2 ┆ C-2-3 ┆ OUT_VAL_for_3/2 │ + │ FOO//D//3 ┆ null ┆ expanded form │ + └───────────┴───────┴─────────────────┘ + """ + event_cfg = copy.deepcopy(event_cfg) + + if not isinstance(event_cfg, (dict, DictConfig)): + raise TypeError(f"Event configuration must be a dictionary. Got: {type(event_cfg)} {event_cfg}.") + + if "code" not in event_cfg: + raise KeyError( + "Event configuration dictionary must contain 'code' key. " + f"Got: [{', '.join(event_cfg.keys())}]." + ) + if "_metadata" not in event_cfg or not event_cfg["_metadata"]: + raise KeyError( + "Event configuration dictionary must contain a non-empty '_metadata' key. " + f"Got: [{', '.join(event_cfg.keys())}]." + ) + + df_select_exprs = {} + final_cols = [] + needed_cols = set() + for out_col, in_cfg in event_cfg["_metadata"].items(): + in_expr, needed = cfg_to_expr(in_cfg) + df_select_exprs[out_col] = in_expr + final_cols.append(out_col) + needed_cols.update(needed) + + code_expr, _, needed_code_cols = get_code_expr(event_cfg.pop("code")) + + columns = metadata_df.collect_schema().names() + missing_cols = (needed_cols | needed_code_cols) - set(columns) + if missing_cols: + raise KeyError(f"Columns {missing_cols} not found in metadata columns: {columns}") + + for col in needed_code_cols: + if col not in df_select_exprs: + df_select_exprs[col] = pl.col(col) + + metadata_df = metadata_df.select(**df_select_exprs).with_columns(code=code_expr) + + if allowed_codes: + metadata_df = metadata_df.filter(pl.col("code").is_in(allowed_codes)) + + metadata_df = metadata_df.filter(~pl.all_horizontal(*[pl.col(c).is_null() for c in final_cols])) + + return metadata_df.unique(maintain_order=True).select("code", *final_cols) + + +def extract_all_metadata( + metadata_df: pl.LazyFrame, event_cfgs: list[dict], allowed_codes: list | None = None +) -> pl.LazyFrame: + """Extracts all metadata for a list of event configurations. + + Args: + metadata_df: The raw metadata DataFrame. Mandatory columns are determined by the `event_cfg` + configurations. + event_cfgs: A list of event configuration dictionaries. Each dictionary must contain the code + and metadata elements. + allowed_codes: A list of codes to allow in the output metadata. If None, all codes are allowed. + + Returns: + A unified DF containing all metadata for all event configurations. + + Examples: + >>> raw_metadata = pl.DataFrame({ + ... "code": ["A", "B", "C", "D"], + ... "code_modifier": ["1", "2", "3", "4"], + ... "name": ["Code A-1", "B-2", "C with 3", "D, but 4"], + ... "priority": [1, 2, 3, 4], + ... }) + >>> event_cfg_1 = { + ... "code": ["FOO", "col(code)", "col(code_modifier)"], + ... "_metadata": {"desc": "name"}, + ... } + >>> event_cfg_2 = { + ... "code": ["BAR", "col(code)", "col(code_modifier)"], + ... "_metadata": {"desc2": "name"}, + ... 
} + >>> event_cfgs = [event_cfg_1, event_cfg_2] + >>> extract_all_metadata(raw_metadata, event_cfgs, allowed_codes=["FOO//A//1", "BAR//B//2"]) + shape: (2, 3) + ┌───────────┬──────────┬───────┐ + │ code ┆ desc ┆ desc2 │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ str ┆ str │ + ╞═══════════╪══════════╪═══════╡ + │ FOO//A//1 ┆ Code A-1 ┆ null │ + │ BAR//B//2 ┆ null ┆ B-2 │ + └───────────┴──────────┴───────┘ + """ + + all_metadata = [] + for event_cfg in event_cfgs: + all_metadata.append(extract_metadata(metadata_df, event_cfg, allowed_codes=allowed_codes)) + + return pl.concat(all_metadata, how="diagonal_relaxed").unique(maintain_order=True) + + +def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> dict[str, dict[str, dict]]: + """Reformats the event conversion config to map metadata file input prefixes to linked event configs. + + Args: + event_configs: The event conversion configuration dictionary. + + Returns: + A dictionary keyed by metadata input file prefix mapping to a dictionary of event configurations that + link to that metadata prefix. + + Examples: + >>> event_configs = { + ... "icu/procedureevents": { + ... "patient_id_col": "subject_id", + ... "start": { + ... "code": ["PROCEDURE", "START", "col(itemid)"], + ... "_metadata": { + ... "proc_datetimeevents": {"desc": ["omop_concept_name", "label"]}, + ... "proc_itemid": {"desc": ["omop_concept_name", "label"]}, + ... }, + ... }, + ... "end": { + ... "code": ["PROCEDURE", "END", "col(itemid)"], + ... "_metadata": { + ... "proc_datetimeevents": {"desc": ["omop_concept_name", "label"]}, + ... "proc_itemid": {"desc": ["omop_concept_name", "label"]}, + ... }, + ... }, + ... }, + ... "icu/inputevents": { + ... "event": { + ... "code": ["INFUSION", "col(itemid)"], + ... "_metadata": { + ... "inputevents_to_rxnorm": {"desc": "{label}", "itemid": "{foo}"} + ... }, + ... }, + ... }, + ... } + >>> get_events_and_metadata_by_metadata_fp(event_configs) # doctest: +NORMALIZE_WHITESPACE + {'proc_datetimeevents': [{'code': ['PROCEDURE', 'START', 'col(itemid)'], + '_metadata': {'desc': ['omop_concept_name', 'label']}}, + {'code': ['PROCEDURE', 'END', 'col(itemid)'], + '_metadata': {'desc': ['omop_concept_name', 'label']}}], + 'proc_itemid': [{'code': ['PROCEDURE', 'START', 'col(itemid)'], + '_metadata': {'desc': ['omop_concept_name', 'label']}}, + {'code': ['PROCEDURE', 'END', 'col(itemid)'], + '_metadata': {'desc': ['omop_concept_name', 'label']}}], + 'inputevents_to_rxnorm': [{'code': ['INFUSION', 'col(itemid)'], + '_metadata': {'desc': '{label}', 'itemid': '{foo}'}}]} + >>> no_metadata_event_configs = { + ... "icu/procedureevents": { + ... "start": {"code": ["PROCEDURE", "START", "col(itemid)"]}, + ... "end": {"code": ["PROCEDURE", "END", "col(itemid)"]}, + ... }, + ... "icu/inputevents": { + ... "event": {"code": ["INFUSION", "col(itemid)"]}, + ... }, + ... 
} + >>> get_events_and_metadata_by_metadata_fp(no_metadata_event_configs) + {} + """ + + out = {} + + for event_cfgs_for_pfx in event_configs.values(): + for event_key, event_cfg in event_cfgs_for_pfx.items(): + if event_key == "patient_id_col": + continue + + for metadata_pfx, metadata_cfg in event_cfg.get("_metadata", {}).items(): + if metadata_pfx not in out: + out[metadata_pfx] = [] + out[metadata_pfx].append({"code": event_cfg["code"], "_metadata": metadata_cfg}) + + return out + + +@hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) +def main(cfg: DictConfig): + """TODO.""" + + stage_input_dir, partial_metadata_dir, _, _ = stage_init(cfg) + raw_input_dir = Path(cfg.input_dir) + + event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) + if not event_conversion_cfg_fp.exists(): + raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") + + logger.info(f"Reading event conversion config from {event_conversion_cfg_fp}") + event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) + logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") + + partial_metadata_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(event_conversion_cfg, partial_metadata_dir / "event_conversion_config.yaml") + + events_and_metadata_by_metadata_fp = get_events_and_metadata_by_metadata_fp(event_conversion_cfg) + event_metadata_configs = list(events_and_metadata_by_metadata_fp.items()) + random.shuffle(event_metadata_configs) + + # Load all codes + all_codes = ( + pl.scan_parquet(stage_input_dir / "**/*.parquet") + .select(pl.col("code").unique()) + .collect() + .get_column("code") + .to_list() + ) + + all_out_fps = [] + for input_prefix, event_metadata_cfgs in event_metadata_configs: + event_metadata_cfgs = copy.deepcopy(event_metadata_cfgs) + + metadata_fp, read_fn = get_supported_fp(raw_input_dir, input_prefix) + out_fp = partial_metadata_dir / f"{input_prefix}.parquet" + logger.info(f"Extracting metadata from {metadata_fp} and saving to {out_fp}") + + compute_fn = partial(extract_all_metadata, event_cfgs=event_metadata_cfgs, allowed_codes=all_codes) + + rwlock_wrap(metadata_fp, out_fp, read_fn, write_lazyframe, compute_fn, do_overwrite=cfg.do_overwrite) + all_out_fps.append(out_fp) + + logger.info("Extracted metadata for all events. Merging.") + + if cfg.worker != 0: + logger.info("Code metadata extraction completed. Exiting") + return + + logger.info("Starting reduction process") + + while not all(fp.is_file() for fp in all_out_fps): + logger.info("Waiting to begin reduction for all files to be written...") + time.sleep(cfg.polling_time) + + start = datetime.now() + logger.info("All map shards complete! 
Starting code metadata reduction computation.") + + def reducer_fn(*dfs): + return pl.concat(dfs, how="diagonal_relaxed").unique(maintain_order=True) + + reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) + join_cols = ["code", *cfg.get("code_modifier_cols", [])] + metadata_cols = [c for c in reduced.columns if c not in join_cols] + + n_unique_obs = reduced.select(pl.n_unique(*join_cols)).collect().item() + n_rows = reduced.select(pl.count()).collect().item() + logger.info(f"Collected metadata for {n_unique_obs} unique codes among {n_rows} total observations.") + + if n_unique_obs != n_rows: + reduced = reduced.group_by(join_cols).agg(*(pl.col(c) for c in metadata_cols)).collect() + else: + reduced = reduced.collect() + + reducer_fp = Path(cfg.cohort_dir) / "code_metadata.parquet" + + if reducer_fp.exists(): + logger.info(f"Joining to existing code metadata at {str(reducer_fp.resolve())}") + existing = pl.read_parquet(reducer_fp, use_pyarrow=True) + reduced = existing.join(reduced, on=join_cols, how="full", coalesce=True) + + reduced.write_parquet(reducer_fp, use_pyarrow=True) + logger.info(f"Finished reduction in {datetime.now() - start}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/extract/parser.py b/src/MEDS_polars_functions/extract/parser.py new file mode 100644 index 0000000..9396f2f --- /dev/null +++ b/src/MEDS_polars_functions/extract/parser.py @@ -0,0 +1,491 @@ +"""Module for code that provides a structured DSL to specify columns of dataframes or operations on said. + +This module contains two key concepts: column expressions and matchers. + +Matchers are used to specify conditionals over dataframes. They are expressed simply as dictionaries mapping +column names to values. Exact equality is used to match the column values. + +Column expressions currently support the following types: + - COL (`'col'`): A column expression that extracts a specified column. + - STR (`'str'`): A column expression that is a string, with interpolation allowed to other column names via + python's f-string syntax. + - LITERAL (`'literal'`): A column expression that is a literal value regardless of type. No interpolation is + allowed here. + +Column expressions can be expressed either dictionary or via a shorthand string. If a structured dictionary, +the dictionary has length 1 and the key is one of the column expression types and the value is the expression +target (e.g., the column to load for `COL`, the string to interpolate with `{...}` escaped interpolation +targets for `STR`, or the literal value for `LITERAL`). If a string, the string is interpreted as a `COL` if +it has no interpolation targets, and as a `STR` otherwise. + +These types can be combined or filtered via two modes: + - Coalescing: Multiple column expressions can be combined into a single expression where the first non-null + is used by specifying them in an ordered list. + - Conditional: A column expression can be conditionally applied to a dataframe based on a matcher, by + specifying the column expression and matcher in a dictionary, in one of two possible forms: + - A single key-value pair, where the key is a string realization of either a `COL` or `STR` type expressed + as a string and the value is the matcher dictionary. + - Two key-value pairs, where the first key is `"output"` and the value is the column expression and the + second key is `"matcher"` and the value is the matcher dictionary. 
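+
+For example (an illustrative sketch; the column names here are hypothetical), all of the following are
+valid column expressions:
+
+  - `"icd_code"`: a `COL` expression that loads the column `icd_code`.
+  - `"ICD{icd_version}CM/{icd_code}"`: a `STR` expression interpolating two columns into a template.
+  - `["special_title", "title"]`: coalescing; uses `special_title` where it is non-null, else `title`.
+  - `{"ICD9CM/{icd_code}": {"icd_version": 9}}`: a simple-form conditional; yields the interpolated string
+    on rows where `icd_version` equals 9 and null elsewhere.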
+""" +from __future__ import annotations + +import re +from enum import StrEnum +from typing import Any + +import polars as pl +from omegaconf import DictConfig, ListConfig, OmegaConf + + +def is_matcher(matcher_cfg: dict[str, Any]) -> bool: + """Checks if a dictionary is a valid matcher configuration. + + Args: + matcher_cfg: A dictionary of key-value pairs to match against. + + Returns: + bool: True if the input is a valid matcher configuration, False otherwise. + + Examples: + >>> is_matcher({"foo": "bar"}) + True + >>> is_matcher(DictConfig({"foo": "bar"})) + True + >>> is_matcher({"foo": "bar", 32: "baz"}) + False + >>> is_matcher(["foo", "bar"]) + False + >>> is_matcher({}) + True + """ + return isinstance(matcher_cfg, (dict, DictConfig)) and all(isinstance(k, str) for k in matcher_cfg.keys()) + + +def matcher_to_expr(matcher_cfg: DictConfig | dict) -> tuple[pl.Expr, set[str]]: + """Returns an expression and the necessary columns to match a collection of key-value pairs. + + Currently, this only supports checking for equality between column names and values. + TODO: Expand (as needed only) to support other types of matchers. + + Args: + matcher_cfg: A dictionary of key-value pairs to match against. + + Raises: + ValueError: If the matcher configuration is not a dictionary. + + Returns: + pl.Expr: A Polars expression that matches the key-value pairs in the input dictionary. + set[str]: The set of input columns needed to form the returned expression. + + Examples: + >>> expr, cols = matcher_to_expr({"foo": "bar", "buzz": "baz"}) + >>> print(expr) + [(col("foo")) == (String(bar))].all_horizontal([[(col("buzz")) == (String(baz))]]) + >>> sorted(cols) + ['buzz', 'foo'] + >>> expr, cols = matcher_to_expr(DictConfig({"foo": "bar", "buzz": "baz"})) + >>> print(expr) + [(col("foo")) == (String(bar))].all_horizontal([[(col("buzz")) == (String(baz))]]) + >>> sorted(cols) + ['buzz', 'foo'] + >>> matcher_to_expr(["foo", "bar"]) + Traceback (most recent call last): + ... + ValueError: Matcher configuration must be a dictionary with string keys. Got: ['foo', 'bar'] + """ + if not is_matcher(matcher_cfg): + raise ValueError(f"Matcher configuration must be a dictionary with string keys. Got: {matcher_cfg}") + + return pl.all_horizontal((pl.col(k) == v) for k, v in matcher_cfg.items()), set(matcher_cfg.keys()) + + +STR_INTERPOLATION_REGEX = r"\{([^}]+)\}" + + +class ColExprType(StrEnum): + """Enumeration of the different types of column expressions that can be parsed. + + Members: + COL: A column expression that extracts a specified column. + STR: A column expression that is a string, with interpolation allowed to other column names + via python's f-string syntax. + LITERAL: A column expression that is a literal value regardless of type. No interpolation is allowed + here. + """ + + COL = "col" + STR = "str" + LITERAL = "literal" + + @classmethod + def is_valid(cls, expr_dict: dict[ColExprType, Any]) -> tuple[bool, str | None]: + """Checks if a dictionary of expression key to value is a valid column expression. + + Args: + expr_dict: A dictionary of column expression type to value. + + Returns: + bool: True if the input is a valid column expression, False otherwise. + str | None: The reason the input is invalid, if it is invalid. + + Examples: + >>> ColExprType.is_valid({"col": "foo"}) + (True, None) + >>> ColExprType.is_valid({"col": 32}) + (False, 'Column expressions must have a string value. 
Got 32') + >>> ColExprType.is_valid({ColExprType.STR: "bar//{foo}"}) + (True, None) + >>> ColExprType.is_valid({ColExprType.STR: ["bar//{foo}"]}) + (False, "String interpolation expressions must have a string value. Got ['bar//{foo}']") + >>> ColExprType.is_valid({"literal": ["baz", 32]}) + (True, None) + >>> ColExprType.is_valid({"col": "foo", "str": "bar"}) # doctest: +NORMALIZE_WHITESPACE + (False, "Column expressions can only contain a single key-value pair. + Got {'col': 'foo', 'str': 'bar'}") + >>> ColExprType.is_valid({"foo": "bar"}) + (False, "Column expressions must have a key in ColExprType: ['col', 'str', 'literal']. Got foo") + >>> ColExprType.is_valid([("col", "foo")]) + (False, "Column expressions must be a dictionary. Got [('col', 'foo')]") + """ + + if not isinstance(expr_dict, dict): + return False, f"Column expressions must be a dictionary. Got {expr_dict}" + if len(expr_dict) != 1: + return False, f"Column expressions can only contain a single key-value pair. Got {expr_dict}" + + expr_type, expr_val = next(iter(expr_dict.items())) + match expr_type: + case cls.COL if isinstance(expr_val, str): + return True, None + case cls.COL: + return False, f"Column expressions must have a string value. Got {expr_val}" + case cls.STR if isinstance(expr_val, str): + return True, None + case cls.STR: + return False, f"String interpolation expressions must have a string value. Got {expr_val}" + case cls.LITERAL: + return True, None + case _: + return ( + False, + f"Column expressions must have a key in ColExprType: {[x.value for x in cls]}. Got " + f"{expr_type}", + ) + + @classmethod + def to_pl_expr(cls, expr_type: ColExprType, expr_val: Any) -> tuple[pl.Expr, set[str]]: + """Converts a column expression type and value to a Polars expression. + + Args: + expr_type: The type of column expression. + expr_val: The value of the column expression. + + Returns: + pl.Expr: A Polars expression that extracts the column from the metadata DataFrame. + set[str]: The set of input columns needed to form the returned expression. + + Raises: + ValueError: If the column expression type is invalid. + + Examples: + >>> print(*ColExprType.to_pl_expr(ColExprType.COL, "foo")) + col("foo") {'foo'} + >>> expr, cols = ColExprType.to_pl_expr(ColExprType.STR, "bar//{foo}//{baz}") + >>> print(expr) + String(bar//).str.concat_horizontal([col("foo"), String(//), col("baz")]) + >>> sorted(cols) + ['baz', 'foo'] + >>> expr, cols = ColExprType.to_pl_expr(ColExprType.LITERAL, ListConfig(["foo", "bar"])) + >>> print(expr) + Series[literal] + >>> pl.select(expr).item().to_list() + ['foo', 'bar'] + >>> cols + set() + >>> ColExprType.to_pl_expr(ColExprType.COL, 32) + Traceback (most recent call last): + ... + ValueError: ... + """ + is_valid, err_msg = cls.is_valid({expr_type: expr_val}) + if not is_valid: + raise ValueError(err_msg) + + match expr_type: + case cls.COL: + return pl.col(expr_val), {expr_val} + case cls.STR: + cols = list(re.findall(STR_INTERPOLATION_REGEX, expr_val)) + expr_val = re.sub(STR_INTERPOLATION_REGEX, "{}", expr_val) + return pl.format(expr_val, *cols), set(cols) + case cls.LITERAL: + if isinstance(expr_val, ListConfig): + expr_val = OmegaConf.to_object(expr_val) + return pl.lit(expr_val), set() + + +def parse_col_expr(cfg: str | list | dict[str, str] | ListConfig | DictConfig) -> dict: + """Parses a column expression configuration object into a dictionary expressing the desired expression. + + Args: + col_expr: A configuration object that specifies how to extract a column from the metadata. 
See the + module docstring for formatting details. + + Returns: + A dictionary specifying, in a structured form, the desired column expression. + + Examples: + >>> parse_col_expr("foo") + {'col': 'foo'} + >>> parse_col_expr("bar//{foo}") + {'str': 'bar//{foo}'} + >>> parse_col_expr({'col': 'bar//{foo}'}) + {'col': 'bar//{foo}'} + >>> parse_col_expr({"literal": ["foo", "bar"]}) + {'literal': ['foo', 'bar']} + >>> parse_col_expr({"output": "foo", "matcher": {"bar": "baz"}}) + {'output': {'col': 'foo'}, 'matcher': {'bar': 'baz'}} + >>> parse_col_expr({"output": "foo", "matcher": {32: "baz"}}) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: A pre-specified output/matcher configuration must have a valid matcher dictionary, + which is a dictionary with string-type keys. Got cfg['matcher']={32: 'baz'} + >>> parse_col_expr({"foo": {"bar": "baz"}}) + {'output': {'col': 'foo'}, 'matcher': {'bar': 'baz'}} + >>> parse_col_expr({"foo": {32: "baz"}}) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: A simple-form conditional expression is expressed with a single key-value pair dict, + where the key is not a column expression type and the value is a valid matcher dict, + which is a dictionary with string-type keys. This config has a single key-value pair + with key foo but an invalid matcher: {32: 'baz'} + >>> parse_col_expr(["bar//{foo}", {"str": "bar//UNK"}]) + [{'str': 'bar//{foo}'}, {'str': 'bar//UNK'}] + >>> parse_col_expr({"foo": "bar", "buzz": "baz", "fuzz": "fizz"}) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Dictionary column expression must either be explicit output/matcher configs, with two + keys, 'output' and 'matcher' with a valid matcher dictionary, or a simple column + expression with a single key-value pair where the key is a column expression type, or a + simple-form conditional expression with a single key-value pair where the key is the + conditional value and the value is a valid matcher dict. Got a dictionary with 3 elements: + {'foo': 'bar', 'buzz': 'baz', 'fuzz': 'fizz'} + >>> parse_col_expr(('foo', 'bar')) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: A simple column expression must be a string, list, or dictionary. + Got : ('foo', 'bar') + >>> parse_col_expr({"col": "foo", "str": "bar"}) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Dictionary column expression must either be explicit output/matcher configs, with two + keys, 'output' and 'matcher' with a valid matcher dictionary, or a simple column + expression with a single key-value pair where the key is a column expression type, or a + simple-form conditional expression with a single key-value pair where the key is the + conditional value and the value is a valid matcher dict. Got a dictionary with 2 elements: + {'col': 'foo', 'str': 'bar'} + >>> parse_col_expr(["foo", 32]) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: If a list (which coalesces columns), all elements must be strings or dictionaries. 
+            Got: ['foo', 32]
+    """
+    match cfg:
+        case str() if re.search(STR_INTERPOLATION_REGEX, cfg):
+            return {"str": cfg}
+        case str():
+            return {"col": cfg}
+        case list() | ListConfig() if all(isinstance(x, (str, dict)) for x in cfg):
+            return [parse_col_expr(x) for x in cfg]
+        case list() | ListConfig():
+            raise ValueError(
+                "If a list (which coalesces columns), all elements must be strings or dictionaries. "
+                f"Got: {cfg}"
+            )
+        case dict() | DictConfig() if set(cfg.keys()) == {"output", "matcher"} and is_matcher(cfg["matcher"]):
+            return {"output": parse_col_expr(cfg["output"]), "matcher": cfg["matcher"]}
+        case dict() | DictConfig() if set(cfg.keys()) == {"output", "matcher"}:
+            raise ValueError(
+                "A pre-specified output/matcher configuration must have a valid matcher dictionary, which is "
+                f"a dictionary with string-type keys. Got cfg['matcher']={cfg['matcher']}"
+            )
+        case dict() | DictConfig() if len(cfg) == 1 and ColExprType.is_valid(cfg)[0]:
+            return cfg
+        case dict() | DictConfig() if len(cfg) == 1:
+            out_cfg, matcher_cfg = next(iter(cfg.items()))
+            if is_matcher(matcher_cfg):
+                return {"output": parse_col_expr(out_cfg), "matcher": matcher_cfg}
+            else:
+                raise ValueError(
+                    "A simple-form conditional expression is expressed with a single key-value pair dict, "
+                    "where the key is not a column expression type and the value is a valid matcher dict, "
+                    "which is a dictionary with string-type keys. This config has a single key-value pair "
+                    f"with key {out_cfg} but an invalid matcher: {matcher_cfg}"
+                )
+        case dict() | DictConfig():
+            raise ValueError(
                "Dictionary column expression must either be explicit output/matcher configs, with two keys, "
+                "'output' and 'matcher' with a valid matcher dictionary, or a simple column expression with "
+                "a single key-value pair where the key is a column expression type, or a simple-form "
+                "conditional expression with a single key-value pair where the key is the conditional value "
+                f"and the value is a valid matcher dict. Got a dictionary with {len(cfg)} elements: {cfg}"
+            )
+        case _:
+            raise ValueError(
+                f"A simple column expression must be a string, list, or dictionary. Got {type(cfg)}: {cfg}"
+            )
+
+
+def structured_expr_to_pl(cfg: dict | list[dict] | ListConfig | DictConfig) -> tuple[pl.Expr, set[str]]:
+    """Converts a structured column expression configuration object to a Polars expression.
+
+    Args:
+        cfg: A structured column expression configuration object. See the module docstring for DSL
+            details.
+
+    Returns:
+        pl.Expr: A Polars expression that extracts the column from the metadata DataFrame.
+        set[str]: The set of input columns needed to form the returned expression.
+
+    Raises:
+        ValueError: If the configuration object is invalid.
+
+    Examples:
+        >>> expr, cols = structured_expr_to_pl([{"col": "foo"}, {"str": "bar//{baz}"}, {"literal": "fizz"}])
+        >>> print(expr)
+        col("foo").coalesce([String(bar//).str.concat_horizontal([col("baz")]), String(fizz)])
+        >>> sorted(cols)
+        ['baz', 'foo']
+        >>> expr, cols = structured_expr_to_pl({"output": {"literal": "foo"}, "matcher": {"bar": "baz"}})
+        >>> print(expr)
+        .when([(col("bar")) == (String(baz))].all_horizontal()).then(String(foo)).otherwise(null)
+        >>> sorted(cols)
+        ['bar']
+        >>> expr, cols = structured_expr_to_pl({"col": "bar"})
+        >>> print(expr)
+        col("bar")
+        >>> sorted(cols)
+        ['bar']
+        >>> structured_expr_to_pl(["foo", 32])
+        Traceback (most recent call last):
+            ...
+        ValueError: Error processing list config on field 1 for ['foo', 32]
+        >>> structured_expr_to_pl({"output": 32, "matcher": {"bar": "baz"}}) # doctest: +NORMALIZE_WHITESPACE
+        Traceback (most recent call last):
+            ...
+        ValueError: Error processing output/matcher config output expression for
+                    {'output': 32, 'matcher': {'bar': 'baz'}}
+        >>> structured_expr_to_pl({"output": "foo", "matcher": {32: "baz"}}) # doctest: +NORMALIZE_WHITESPACE
+        Traceback (most recent call last):
+            ...
+        ValueError: A pre-specified output/matcher configuration must have a valid matcher dictionary, which
+                    is a dictionary with string-type keys. Got cfg['matcher']={32: 'baz'}
+        >>> structured_expr_to_pl({"col": 32})
+        Traceback (most recent call last):
+            ...
+        ValueError: Column expressions must have a string value. Got 32
+        >>> structured_expr_to_pl("foo")
+        Traceback (most recent call last):
+            ...
+        ValueError: A structured column expression must be a list or dictionary. Got <class 'str'>: foo
+    """
+
+    match cfg:
+        case list() | ListConfig() as cfg_fields:
+            component_exprs = []
+            needed_cols = set()
+            for i, field in enumerate(cfg_fields):
+                try:
+                    expr, cols = cfg_to_expr(field)
+                except ValueError as e:
+                    raise ValueError(f"Error processing list config on field {i} for {cfg}") from e
+                component_exprs.append(expr)
+                needed_cols.update(cols)
+            return pl.coalesce(*component_exprs), needed_cols
+        case dict() | DictConfig() if set(cfg.keys()) == {"output", "matcher"} and is_matcher(cfg["matcher"]):
+            matcher_expr, matcher_cols = matcher_to_expr(cfg["matcher"])
+            try:
+                out_expr, out_cols = cfg_to_expr(cfg["output"])
+            except ValueError as e:
+                raise ValueError(f"Error processing output/matcher config output expression for {cfg}") from e
+            return pl.when(matcher_expr).then(out_expr), out_cols | matcher_cols
+        case dict() | DictConfig() if set(cfg.keys()) == {"output", "matcher"}:
+            # TODO(mmd): DRY out this and other error messages.
+            raise ValueError(
+                "A pre-specified output/matcher configuration must have a valid matcher dictionary, which is "
+                f"a dictionary with string-type keys. Got cfg['matcher']={cfg['matcher']}"
+            )
+        case dict() | DictConfig() if ColExprType.is_valid(cfg)[0]:
+            expr_type, expr_val = next(iter(cfg.items()))
+            return ColExprType.to_pl_expr(expr_type, expr_val)
+        case dict() | DictConfig():
+            _, err_msg = ColExprType.is_valid(cfg)
+            raise ValueError(err_msg)
+        case _:
+            raise ValueError(
+                f"A structured column expression must be a list or dictionary. Got {type(cfg)}: {cfg}"
+            )
+
+
+def cfg_to_expr(cfg: str | ListConfig | DictConfig) -> tuple[pl.Expr, set[str]]:
+    """Converts a metadata column configuration object to a Polars expression.
+
+    Args:
+        cfg: A configuration object that specifies how to extract a column from the metadata. See the module
+            docstring for formatting details.
+
+    Returns:
+        pl.Expr: A Polars expression that extracts the column from the metadata DataFrame.
+        set[str]: The set of input columns needed to form the returned expression.
+
+    Examples:
+        >>> data = pl.DataFrame({
+        ...     "foo": ["a", "b", "c"],
+        ...     "bar": ["d", "e", "f"],
+        ...     "baz": [1, 2, 3]
+        ... })
+        >>> expr, cols = cfg_to_expr("foo")
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        ['a', 'b', 'c']
+        >>> sorted(cols)
+        ['foo']
+        >>> expr, cols = cfg_to_expr("bar//{foo}//{baz}")
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        ['bar//a//1', 'bar//b//2', 'bar//c//3']
+        >>> sorted(cols)
+        ['baz', 'foo']
+        >>> expr, cols = cfg_to_expr({"literal": 34.2})
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        [34.2]
+        >>> cols
+        set()
+        >>> expr, cols = cfg_to_expr({"{baz}//{bar}": {"foo": "a"}})
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        ['1//d', None, None]
+        >>> sorted(cols)
+        ['bar', 'baz', 'foo']
+        >>> cfg = [
+        ...     {"matcher": {"baz": 2}, "output": {"str": "bar//{baz}"}},
+        ...     {"literal": "34.2"},
+        ... ]
+        >>> expr, cols = cfg_to_expr(cfg)
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        ['34.2', 'bar//2', '34.2']
+        >>> sorted(cols)
+        ['baz']
+
+        Note that sometimes coalescing can lead to unexpected results. For example, if the first expression
+        is of a different type than the second, the second expression may have its type coerced to match the
+        first, potentially in an unexpected manner. This is also related to some polars bugs, such as
+        https://github.com/pola-rs/polars/issues/17773
+        >>> cfg = [
+        ...     {"matcher": {"baz": 2}, "output": {"str": "bar//{baz}"}},
+        ...     {"literal": 34.8218},
+        ... ]
+        >>> expr, cols = cfg_to_expr(cfg)
+        >>> data.select(expr.alias("out"))["out"].to_list()
+        ['34', 'bar//2', '34']
+    """
+    structured_expr = parse_col_expr(cfg)
+    return structured_expr_to_pl(structured_expr)
diff --git a/src/MEDS_polars_functions/extract/shard_events.py b/src/MEDS_polars_functions/extract/shard_events.py
index 9247305..d1dc4e9 100755
--- a/src/MEDS_polars_functions/extract/shard_events.py
+++ b/src/MEDS_polars_functions/extract/shard_events.py
@@ -24,7 +24,7 @@
 )

 ROW_IDX_NAME = "__row_idx__"
-META_KEYS = {"timestamp_format"}
+META_KEYS = {"timestamp_format", "_metadata"}


 def kwargs_strs(kwargs: dict) -> str:
diff --git a/src/MEDS_polars_functions/extract/utils.py b/src/MEDS_polars_functions/extract/utils.py
new file mode 100644
index 0000000..7227512
--- /dev/null
+++ b/src/MEDS_polars_functions/extract/utils.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python
+import gzip
+import warnings
+from collections.abc import Callable
+from enum import StrEnum
+from pathlib import Path
+from typing import TypeVar
+
+import polars as pl
+from loguru import logger
+
+
+class SupportedFileFormats(StrEnum):
+    """The supported file formats for dataframes we can read in, in priority order.
+
+    The values of the enum are the allowed file suffix for the format.
+    """
+
+    PARQUET = ".parquet"
+    CSV_GZ = ".csv.gz"
+    CSV = ".csv"
+
+
+def scan_csv_gz(fp: Path, **kwargs) -> pl.LazyFrame:
+    with gzip.open(fp, mode="rb") as f:
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            return pl.read_csv(f, **kwargs).lazy()
+
+
+READERS = {
+    SupportedFileFormats.PARQUET: pl.scan_parquet,
+    SupportedFileFormats.CSV_GZ: scan_csv_gz,
+    SupportedFileFormats.CSV: pl.scan_csv,
+}
+
+
+DF_T = TypeVar("DF_T")
+
+
+def get_supported_fp(root_dir: Path, file_prefix: str | Path) -> tuple[Path, Callable[[Path], DF_T]]:
+    """This function finds the best file path to read for a given root_dir and prefix.
+
+    Args:
+        root_dir: The root directory to search for files.
+        file_prefix: The file prefix to search for.
+
+    Raises:
+        FileNotFoundError: If no files are found with the given prefix and an allowed suffix.
+ + Returns: + The filepath with the matching prefix and the most preferred allowed suffix and an appropriate reader + function for that file type. + + Examples: + >>> from tempfile import TemporaryDirectory + >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, schema={"a": pl.UInt8, "b": pl.Int64}) + >>> with TemporaryDirectory() as tmpdir: + ... tmpdir = Path(tmpdir) + ... fp = tmpdir / "test.csv" + ... df.write_csv(fp) + ... fp, reader = get_supported_fp(tmpdir, "test") + ... print(str(fp.relative_to(tmpdir)), reader(fp).collect()) + test.csv shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> with TemporaryDirectory() as tmpdir: + ... tmpdir = Path(tmpdir) + ... fp = tmpdir / "test.parquet" + ... csv_fp = tmpdir / "test.csv" + ... df.write_parquet(fp) + ... df.write_csv(csv_fp) + ... fp, reader = get_supported_fp(tmpdir, "test") + ... print(str(fp.relative_to(tmpdir)), reader(fp).collect()) + test.parquet shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ u8 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> import gzip + >>> with TemporaryDirectory() as tmpdir: + ... tmpdir = Path(tmpdir) + ... fp = tmpdir / "test.csv.gz" + ... with gzip.open(fp, mode="wb") as f: + ... with warnings.catch_warnings(): + ... warnings.simplefilter("ignore", category=UserWarning) + ... df.write_csv(f) + ... fp, reader = get_supported_fp(tmpdir, "test") + ... print(str(fp.relative_to(tmpdir)), reader(fp).collect()) + test.csv.gz shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> with TemporaryDirectory() as tmpdir: + ... tmpdir = Path(tmpdir) + ... fp = tmpdir / "test.json" + ... df.write_json(fp) + ... get_supported_fp(tmpdir, "test") # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + FileNotFoundError: No files found with prefix: test and allowed suffixes: + ['.parquet', '.csv.gz', '.csv'] + """ + + for suffix in list(SupportedFileFormats): + fp = root_dir / f"{file_prefix}{suffix.value}" + if fp.exists(): + logger.debug(f"Found file: {str(fp.resolve())}") + return fp, READERS[suffix] + raise FileNotFoundError( + f"No files found with prefix: {file_prefix} and allowed suffixes: " + f"{[x.value for x in SupportedFileFormats]}" + ) diff --git a/src/MEDS_polars_functions/utils.py b/src/MEDS_polars_functions/utils.py index ecd905b..669c0fc 100644 --- a/src/MEDS_polars_functions/utils.py +++ b/src/MEDS_polars_functions/utils.py @@ -175,6 +175,17 @@ def populate_stage( ... ValueError: 'stage7' is not a valid stage name. Options are: stage1, stage2, stage3, stage4, stage5, stage6 + >>> root_config = DictConfig({ + ... "input_dir": "/a/b", + ... "cohort_dir": "/c/d", + ... "stages": ["stage1", "stage2", "stage3", "stage4", "stage5", "stage6"], + ... "stage_configs": {"stage2": {"is_metadata": 34}}, + ... }) + >>> args = [root_config[k] for k in ["input_dir", "cohort_dir", "stages", "stage_configs"]] + >>> populate_stage("stage2", *args) + Traceback (most recent call last): + ... + TypeError: If specified manually, is_metadata must be a boolean. 
     """
     if stage_name not in stages:
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 068a835..4b654d3 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -19,12 +19,14 @@
     CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py"
     MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py"
     AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py"
+    EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py"
 else:
     SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events"
     SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients"
     CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events"
     MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort"
     AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata"
+    EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata"

 import json
 import subprocess
@@ -69,6 +71,20 @@
 1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3
 """

+INPUT_METADATA_FILE = """
+lab_code,title,loinc
+HR,Heart Rate,8867-4
+temp,Body Temperature,8310-5
+"""
+
+DEMO_METADATA_FILE = """
+eye_color,description
+BROWN,"Brown Eyes. The most common eye color."
+BLUE,"Blue Eyes. Less common than brown."
+HAZEL,"Hazel eyes. These are uncommon"
+GREEN,"Green eyes. These are rare."
+"""
+
 EVENT_CFGS_YAML = """
 subjects:
   patient_id_col: MRN
@@ -77,6 +93,9 @@
       - EYE_COLOR
       - col(eye_color)
     timestamp: null
+    _metadata:
+      demo_metadata:
+        description: description
   height:
     code: HEIGHT
     timestamp: null
@@ -101,11 +120,19 @@
       timestamp: col(vitals_date)
       timestamp_format: "%m/%d/%Y, %H:%M:%S"
       numerical_value: HR
+      _metadata:
+        input_metadata:
+          description: {"title": {"lab_code": "HR"}}
+          parent_code: {"LOINC/{loinc}": {"lab_code": "HR"}}
     temp:
       code: TEMP
       timestamp: col(vitals_date)
       timestamp_format: "%m/%d/%Y, %H:%M:%S"
       numerical_value: temp
+      _metadata:
+        input_metadata:
+          description: {"title": {"lab_code": "temp"}}
+          parent_code: {"LOINC/{loinc}": {"lab_code": "temp"}}
 """

 # Test data (expected outputs) -- ALL OF THIS MAY CHANGE IF THE SEED OR DATA CHANGES
@@ -240,6 +267,22 @@ def get_expected_output(df: str) -> pl.DataFrame:
 TEMP,12,4,12,1181.4999999999998,116373.38999999998
 """

+MEDS_OUTPUT_CODE_METADATA_FILE_WITH_DESC = """
+code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code
+,44,4,28,3198.8389005974336,382968.28937288234,,
+ADMISSION//CARDIAC,2,2,0,,,,
+ADMISSION//ORTHOPEDIC,1,1,0,,,,
+ADMISSION//PULMONARY,1,1,0,,,,
+DISCHARGE,4,4,0,,,,
+DOB,4,4,0,,,,
+EYE_COLOR//BLUE,1,1,0,,,"Blue Eyes. Less common than brown.",
+EYE_COLOR//BROWN,1,1,0,,,"Brown Eyes. The most common eye color.",
+EYE_COLOR//HAZEL,2,2,0,,,"Hazel eyes. These are uncommon",
+HEIGHT,4,4,4,656.8389005974336,108056.12937288235,,
+HR,12,4,12,1360.5000000000002,158538.77,"Heart Rate",LOINC/8867-4
+TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5
+"""
+
 SUB_SHARDED_OUTPUTS = {
     "train/0": {
         "subjects": MEDS_OUTPUT_TRAIN_0_SUBJECTS,
@@ -312,9 +355,14 @@ def test_extraction():
     admit_vitals_csv = raw_cohort_dir / "admit_vitals.csv"
     event_cfgs_yaml = raw_cohort_dir / "event_cfgs.yaml"

+    demo_metadata_csv = raw_cohort_dir / "demo_metadata.csv"
+    input_metadata_csv = raw_cohort_dir / "input_metadata.csv"
+
     # Write the CSV files
     subjects_csv.write_text(SUBJECTS_CSV.strip())
     admit_vitals_csv.write_text(ADMIT_VITALS_CSV.strip())
+    demo_metadata_csv.write_text(DEMO_METADATA_FILE.strip())
+    input_metadata_csv.write_text(INPUT_METADATA_FILE.strip())

     # Mix things up -- have one CSV be also in parquet format.
     admit_vitals_parquet = raw_cohort_dir / "admit_vitals.parquet"
@@ -553,3 +601,36 @@ def test_extraction():
         check_column_order=False,
         check_row_order=False,
     )
+
+    stderr, stdout = run_command(
+        EXTRACT_CODE_METADATA_SCRIPT,
+        extraction_config_kwargs,
+        "extract_code_metadata",
+    )
+    all_stderrs.append(stderr)
+    all_stdouts.append(stdout)
+
+    full_stderr = "\n".join(all_stderrs)
+    full_stdout = "\n".join(all_stdouts)
+
+    output_file = MEDS_cohort_dir / "code_metadata.parquet"
+    assert output_file.is_file(), (
+        f"Expected {output_file} to exist: stderr:\n{full_stderr}\nstdout:\n{full_stdout}"
+    )
+
+    got_df = pl.read_parquet(output_file, glob=False)
+
+    want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE_WITH_DESC)).with_columns(
+        pl.col("code").cast(pl.Categorical),
+        pl.col("code/n_occurrences").cast(pl.UInt8),
+        pl.col("code/n_patients").cast(pl.UInt8),
+        pl.col("values/n_occurrences").cast(pl.UInt8),
+        pl.col("values/sum").cast(pl.Float32).fill_null(0),
+        pl.col("values/sum_sqd").cast(pl.Float32).fill_null(0),
+    )
+
+    assert_df_equal(
+        want=want_df,
+        got=got_df,
+        msg="Code metadata with descriptions differs!",
+        check_column_order=False,
+        check_row_order=False,
+    )
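
A note on the matcher-gated `_metadata` entries in EVENT_CFGS_YAML above (e.g. `description: {"title": {"lab_code": "HR"}}`): per the cfg_to_expr doctests earlier, a plain string names a column, a `{...}` template is filled from columns, and a `{output: matcher}` mapping gates the output on a row filter. The sketch below shows how such an entry could be evaluated with polars; it is illustrative only, not the library's actual implementation, and `metadata_df`, `matcher`, `description`, and `parent_code` are local names for this example.

import polars as pl

# The `input_metadata` table from the test fixture above.
metadata_df = pl.DataFrame(
    {
        "lab_code": ["HR", "temp"],
        "title": ["Heart Rate", "Body Temperature"],
        "loinc": ["8867-4", "8310-5"],
    }
)

# The matcher {"lab_code": "HR"} becomes a row predicate.
matcher = pl.col("lab_code") == "HR"

# description: {"title": {"lab_code": "HR"}} -- take the `title` column, but
# only on rows satisfying the matcher (all other rows yield null).
description = pl.when(matcher).then(pl.col("title"))

# parent_code: {"LOINC/{loinc}": {"lab_code": "HR"}} -- fill the string
# template from the `loinc` column, gated on the same matcher.
parent_code = pl.when(matcher).then(pl.format("LOINC/{}", pl.col("loinc")))

print(metadata_df.select(description.alias("description"), parent_code.alias("parent_code")))

On this fixture, the HR row yields ("Heart Rate", "LOINC/8867-4") and the temp row yields nulls, which is consistent with the HR line of MEDS_OUTPUT_CODE_METADATA_FILE_WITH_DESC; the temp event's own `_metadata` entry fills the temp row via its separate matcher.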