From ef121198813c1a04485ef0a0a61072fce63f91f8 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 31 May 2024 08:18:54 -0400 Subject: [PATCH 01/53] Added some starter code for the preprocessing side. --- configs/preprocess.yaml | 34 ++++++---- .../add_time_derived_measurements.py | 22 +++---- .../preprocessing/collect_code_metadata.py | 29 +++------ scripts/preprocessing/filter_codes.py | 62 +++++++++++++++++++ scripts/preprocessing/filter_outliers.py | 62 +++++++++++++++++++ scripts/preprocessing/filter_patients.py | 14 ++--- src/MEDS_polars_functions/code_metadata.py | 4 +- .../filter_measurements.py | 44 +++++++++++++ 8 files changed, 212 insertions(+), 59 deletions(-) create mode 100644 scripts/preprocessing/filter_codes.py create mode 100644 scripts/preprocessing/filter_outliers.py create mode 100644 src/MEDS_polars_functions/filter_measurements.py diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 9b60579..d6d8991 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -6,13 +6,25 @@ defaults: # tokenization. code_modifier_columns: ??? -# Pipeline Structure stages: - - name: filter_patients + - filter_patients + - add_time_derived_measurements + - preliminary_counts + - filter_codes + - fit_outlier_detection + - filter_outliers + - fit_normalization + - normalize + - tokenize + - tensorize + +# Pipeline Structure +stage_configs: + filter_patients: min_events_per_patient: null min_measurements_per_patient: null - - name: add_time_derived_measurements + add_time_derived_measurements: age: dob_code: ??? age_code: "AGE" @@ -20,31 +32,27 @@ stages: time_of_day: bin_endpoints: [6, 12, 18, 24] - - name: preliminary_counts - obs_aggregations: + preliminary_counts: + aggregations: - "code/n_occurrences" - "code/n_patients" - - name: filter_codes + filter_codes: min_code_occurrences: null - - name: fit_outlier_detection + fit_outlier_detection: aggregations: - "values/n_occurrences" - "values/sum" - "values/sum_sqd" - - name: filter_outliers + filter_outliers: stddev_cutoff: 4.5 - - name: fit_normalization + fit_normalization: aggregations: - "code/n_occurrences" - "code/n_patients" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" - - - name: normalization - - name: tokenization - - name: tensorization diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py index 1e01067..5731445 100644 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ b/scripts/preprocessing/add_time_derived_measurements.py @@ -30,26 +30,20 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - output_dir = Path(cfg.stage_dfg.output_dir) + output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + input_dir = Path(cfg.stage_cfg.data_input_dir) - final_cohort_dir = cfg.stage_cfg.data_input_dir / "final_cohort" - filtered_patients_dir = output_dir / "patients_above_length_threshold" - with_time_derived_dir = output_dir / "with_time_derived_measurements" - - if filtered_patients_dir.is_dir(): - logger.info(f"Reading data from filtered cohort directory {str(filtered_patients_dir.resolve())}") - input_dir = filtered_patients_dir - else: - logger.info(f"Reading data from raw cohort directory {str(final_cohort_dir.resolve())}") - input_dir = final_cohort_dir + logger.info(f"Reading data from {str(input_dir.resolve())}") 
patient_splits = list(shards.keys()) random.shuffle(patient_splits) compute_fns = [] - for feature_name, feature_cfg in cfg.time_derived_measurements.items(): + # We use the raw stages object as the induced `stage_cfg` has extra properties like the input and output + # directories. + for feature_name, feature_cfg in cfg.stages[cfg.stage].items(): match feature_name: case "age": compute_fns.append(add_new_events_fntr(age_fntr(feature_cfg))) @@ -62,7 +56,7 @@ def main(cfg: DictConfig): for sp in patient_splits: in_fp = input_dir / f"{sp}.parquet" - out_fp = with_time_derived_dir / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" logger.info( f"Adding time derived measurements to {str(in_fp.resolve())} into {str(out_fp.resolve())}" diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index fa25bcb..07abd4f 100644 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -28,25 +28,12 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir) - output_dir = Path(cfg.output_data_dir) - - shards = json.loads((MEDS_cohort_dir / "splits.json").read_text()) - - final_cohort_dir = MEDS_cohort_dir / "final_cohort" - filtered_patients_dir = output_dir / "patients_above_length_threshold" - with_time_derived_dir = output_dir / "with_time_derived_measurements" - code_metadata_dir = output_dir / f"code_metadata/{cfg.stage}" - - if with_time_derived_dir.is_dir(): - logger.info("Reading data from directory with time-derived: {str(with_time_derived_dir.resolve())}") - input_dir = with_time_derived_dir - if filtered_patients_dir.is_dir(): - logger.info(f"Reading data from filtered cohort directory {str(filtered_patients_dir.resolve())}") - input_dir = filtered_patients_dir - else: - logger.info(f"Reading data from raw cohort directory {str(final_cohort_dir.resolve())}") - input_dir = final_cohort_dir + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + logger.info(f"Reading data from input directory {str(input_dir.resolve())}") + + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) @@ -59,7 +46,7 @@ def main(cfg: DictConfig): all_out_fps = [] for sp in patient_splits: in_fp = input_dir / f"{sp}.parquet" - out_fp = code_metadata_dir / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" all_out_fps.append(out_fp) logger.info( @@ -91,7 +78,7 @@ def main(cfg: DictConfig): reducer_fn = reducer_fntr(cfg, cfg.stage) reduced = reducer_fn(pl.scan_parquet(fp, glob=False) for fp in all_out_fps) - write_lazyframe(reduced, code_metadata_dir / "code_metadata.parquet") + write_lazyframe(reduced, output_dir / "code_metadata.parquet") logger.info(f"Finished reduction in {datetime.now() - start}") diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py new file mode 100644 index 0000000..1f0e318 --- /dev/null +++ b/scripts/preprocessing/filter_codes.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +import json +import random +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.filter_measurements import filter_codes_fntr +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe + + 
+@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((cfg.input_dir / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) + compute_fn = filter_codes_fntr(cfg.stage_cfg, code_metadata) + + for sp in patient_splits: + in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" + + logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") + + rwlock_wrap( + in_fp, + out_fp, + pl.scan_parquet, + write_lazyframe, + compute_fn, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py new file mode 100644 index 0000000..b1914a4 --- /dev/null +++ b/scripts/preprocessing/filter_outliers.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +import json +import random +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.filter_measurements import filter_outliers_fntr +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe + + +@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((cfg.input_dir / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) + compute_fn = filter_outliers_fntr(cfg.stage_cfg, code_metadata) + + for sp in patient_splits: + in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" + + logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") + + rwlock_wrap( + in_fp, + out_fp, + pl.scan_parquet, + write_lazyframe, + compute_fn, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index a2b6308..937d87f 100644 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -30,18 +30,14 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir) - output_dir = Path(cfg.output_data_dir) + input_dir = Path(cfg.stage_cfg.data_input_dir) + 
output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((MEDS_cohort_dir / "splits.json").read_text()) - - final_cohort_dir = MEDS_cohort_dir / "final_cohort" + shards = json.loads((cfg.input_dir / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) - filtered_patients_dir = output_dir / "patients_above_length_threshold" - compute_fns = [] if cfg.min_measurements_per_patient: logger.info( @@ -63,8 +59,8 @@ def main(cfg: DictConfig): ) for sp in patient_splits: - in_fp = final_cohort_dir / f"{sp}.parquet" - out_fp = filtered_patients_dir / f"{sp}.parquet" + in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") diff --git a/src/MEDS_polars_functions/code_metadata.py b/src/MEDS_polars_functions/code_metadata.py index 301dd0b..9f2b2e2 100644 --- a/src/MEDS_polars_functions/code_metadata.py +++ b/src/MEDS_polars_functions/code_metadata.py @@ -242,10 +242,10 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl KeyError: 'Stage name stage5 not found in code_processing_stages in configuration file.' """ - if stage_name not in cfg.code_processing_stages: + if stage_name not in cfg.stages: raise KeyError(f"Stage name {stage_name} not found in code_processing_stages in configuration file.") - aggregations = cfg.code_processing_stages[stage_name] + aggregations = cfg.stages[stage_name].aggregations for agg in aggregations: if agg not in METADATA_FN: raise KeyError( diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py new file mode 100644 index 0000000..6def789 --- /dev/null +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -0,0 +1,44 @@ +"""A polars-to-polars transformation function for filtering patients by sequence length.""" + +from collections.abc import Callable + +import polars as pl +from omegaconf import DictConfig + + +def filter_codes_fntr( + stage_cfg: DictConfig, code_metadata: pl.LazyFrame +) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Filters patient events to only encompass those with a set of permissible codes. + + Args: + df: The input DataFrame. + stage_cfg: The configuration for the code filtering stage. + + Returns: + The processed DataFrame. + + Examples: + >>> raise NotImplementedError + """ + + raise NotImplementedError + + +def filter_outliers_fntr( + stage_cfg: DictConfig, code_metadata: pl.LazyFrame +) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Filters patient events to only encompass those with a set of permissible codes. + + Args: + df: The input DataFrame. + stage_cfg: The configuration for the code filtering stage. + + Returns: + The processed DataFrame. + + Examples: + >>> raise NotImplementedError + """ + + raise NotImplementedError From 3ab9e40b156875bbe77544f0d2b1e2b684b0ce54 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 8 Jun 2024 17:36:32 -0400 Subject: [PATCH 02/53] Fixed code_metadata setup. 
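
The mapper and reducer builders now take the per-stage configuration and an
optional list of code modifier columns directly, instead of the full pipeline
config plus a stage name. A minimal sketch of the new interface, adapted from
the doctests added below (the toy shard and its values are illustrative only,
and in the real pipeline the reducer runs over each shard's written output):

    import polars as pl
    from omegaconf import DictConfig

    from MEDS_polars_functions.code_metadata import mapper_fntr, reducer_fntr

    stage_cfg = DictConfig(
        {"aggregations": ["code/n_patients", "values/n_ints"]}
    )
    mapper = mapper_fntr(stage_cfg, code_modifier_columns=["modifier1"])
    reducer = reducer_fntr(stage_cfg, code_modifier_columns=["modifier1"])

    shard = pl.DataFrame({
        "patient_id": [1, 1, 2],
        "code": pl.Series(["A", "B", "A"], dtype=pl.Categorical),
        "modifier1": [1, 1, 2],
        "numerical_value": [1.0, None, 2.0],
    }).lazy()

    per_shard = mapper(shard).collect()  # one row per (code, modifier1) pair
    combined = reducer(per_shard, per_shard)  # merges per-shard outputs
    print(combined)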
--- .../preprocessing/collect_code_metadata.py | 4 +- src/MEDS_polars_functions/code_metadata.py | 262 +++++++++++------- 2 files changed, 157 insertions(+), 109 deletions(-) diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 07abd4f..40dd27d 100644 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -38,7 +38,7 @@ def main(cfg: DictConfig): patient_splits = list(shards.keys()) random.shuffle(patient_splits) - mapper_fn = mapper_fntr(cfg, cfg.stage) + mapper_fn = mapper_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) start = datetime.now() logger.info("Starting code metadata mapping computation") @@ -75,7 +75,7 @@ def main(cfg: DictConfig): start = datetime.now() logger.info("All map shards complete! Starting code metadata reduction computation.") - reducer_fn = reducer_fntr(cfg, cfg.stage) + reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) reduced = reducer_fn(pl.scan_parquet(fp, glob=False) for fp in all_out_fps) write_lazyframe(reduced, output_dir / "code_metadata.parquet") diff --git a/src/MEDS_polars_functions/code_metadata.py b/src/MEDS_polars_functions/code_metadata.py index 9f2b2e2..a8f2f07 100644 --- a/src/MEDS_polars_functions/code_metadata.py +++ b/src/MEDS_polars_functions/code_metadata.py @@ -6,7 +6,7 @@ import polars as pl import polars.selectors as cs -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig, OmegaConf pl.enable_string_cache() @@ -119,43 +119,108 @@ class MapReducePair(NamedTuple): } -def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl.DataFrame]: +def validate_args_and_get_code_cols( + stage_cfg: DictConfig, code_modifier_columns: list[str] | None +) -> list[str]: + """Validates the stage configuration and code_modifier_columns argument and returns the code group keys. + + Args: + stage_cfg: The configuration object for this stage. It must contain an `aggregations` field that has a + list of aggregations that should be applied in this stage. Each aggregation must be a string in + the `METADATA_FN` enumeration. + code_modifier_columns: A list of column names that should be used in addition to the core `code` + column to group the data before applying the aggregations. If None, only the `code` column will be + used. + + Returns: + A list of column names that should be used to group the data before applying the aggregations. + + Raises: + ValueError: If the stage config either does not contain an aggregations field, contains an empty or + mis-typed aggregations field, or contains an invalid aggregation function. + ValueError: If the code_modifier_columns argument is not a list of strings or None. + + Examples: + >>> no_aggs_cfg = DictConfig({"other_key": "other_value"}) + >>> validate_args_and_get_code_cols(no_aggs_cfg, None) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Stage config must contain an 'aggregations' field. Got: + other_key: other_value + >>> invalid_agg_cfg = DictConfig({"aggregations": ["INVALID"]}) + >>> validate_args_and_get_code_cols(invalid_agg_cfg, None) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. 
Values are: + code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, + values/sum, values/sum_sqd, values/min, values/max + >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> validate_args_and_get_code_cols(valid_cfg, 33) + Traceback (most recent call last): + ... + ValueError: code_modifier_columns must be a list of strings or None. Got 33 + >>> validate_args_and_get_code_cols(valid_cfg, [33]) + Traceback (most recent call last): + ... + ValueError: code_modifier_columns must be a list of strings or None. Got [33] + >>> validate_args_and_get_code_cols(valid_cfg, ["modifier1"]) + ['code', 'modifier1'] + >>> validate_args_and_get_code_cols(valid_cfg, None) + ['code'] + """ + + if "aggregations" not in stage_cfg: + raise ValueError( + f"Stage config must contain an 'aggregations' field. Got:\n{OmegaConf.to_yaml(stage_cfg)}" + ) + + aggregations = stage_cfg.aggregations + for agg in aggregations: + if agg not in METADATA_FN: + raise ValueError( + f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: " + f"{', '.join([fn.value for fn in METADATA_FN])}" + ) + + match code_modifier_columns: + case None: + return ["code"] + case list() | ListConfig() if all(isinstance(col, str) for col in code_modifier_columns): + return ["code"] + code_modifier_columns + case _: + raise ValueError( + f"code_modifier_columns must be a list of strings or None. Got {code_modifier_columns}" + ) + + +def mapper_fntr( + stage_cfg: DictConfig, code_modifier_columns: list[str] | None +) -> Callable[[pl.DataFrame], pl.DataFrame]: """Returns a function that extracts code metadata from a MEDS cohort shard. Args: - cfg: A pre-processing configuration in `OmegaConf` `DictConfig` format. The configuration should have - a field `code_processing_stages` that specifies the metadata aggregations to perform for each - stage. This field should be a dictionary with stage names as keys and lists of metadata - aggregation functions (from the `METADATA_FN` enumeration) as values. - stage_name: The name of the stage in the configuration file that specifies the set of metadata - aggregations to perform in this function. + stage_cfg: The configuration object for this stage. It must contain an `aggregations` field that has a + list of aggregations that should be applied in this stage. Each aggregation must be a string in + the `METADATA_FN` enumeration, and the mapper function is specified in the + `CODE_METADATA_AGGREGATIONS` dictionary. + code_modifier_columns: A list of column names that should be used in addition to the core `code` + column to group the data before applying the aggregations. If None, only the `code` column will be + used. + + Raises: See `validate_args_and_get_code_cols`. Returns: A function that extracts the specified metadata from a MEDS cohort shard after grouping by the specified code & modifier columns - Raises: - KeyError: If the specified stage name is not found in the configuration file. - KeyError: If any specified aggregation function is not an element of the `METADATA_FN` enumeration. - Examples: - >>> cfg = DictConfig({ - ... "code_modifier_columns": ["modifier1"], - ... "code_processing_stages": { - ... "stage1": ["code/n_patients", "values/n_ints"], - ... "stage2": ["code/n_occurrences", "values/sum"], - ... "stage3.A": ["values/n_patients", "values/n_occurrences"], - ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], - ... "stage4": ["INVALID"], - ... } - ... 
}) >>> import numpy as np >>> df = pl.DataFrame({ - ... "code": pl.Series(["A", "B", "A", "B", "C", "A", "C", "B", "D"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], - ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], - ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], - ... "numerical_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, np.NaN, None], + ... "code": pl.Series(["A", "B", "A", "B", "C", "A", "C", "B", "D"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], + ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], + ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], + ... "numerical_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, np.NaN, None], ... }) >>> df shape: (9, 5) @@ -174,7 +239,23 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl │ B ┆ 2 ┆ 6 ┆ 2 ┆ NaN │ │ D ┆ null ┆ 7 ┆ 1 ┆ null │ └──────┴───────────┴──────────────────┴────────────┴─────────────────┘ - >>> mapper = mapper_fntr(cfg, "stage1") + >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> mapper = mapper_fntr(stage_cfg, None) + >>> mapper(df.lazy()).collect() + shape: (4, 3) + ┌──────┬─────────────────┬───────────────┐ + │ code ┆ code/n_patients ┆ values/n_ints │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ u32 ┆ u32 │ + ╞══════╪═════════════════╪═══════════════╡ + │ A ┆ 2 ┆ 1 │ + │ B ┆ 2 ┆ 2 │ + │ C ┆ 2 ┆ 1 │ + │ D ┆ 1 ┆ 0 │ + └──────┴─────────────────┴───────────────┘ + >>> code_modifier_columns = ["modifier1"] + >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> mapper = mapper_fntr(stage_cfg, ListConfig(code_modifier_columns)) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ @@ -188,7 +269,8 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl │ C ┆ 1 ┆ 2 ┆ 1 │ │ D ┆ null ┆ 1 ┆ 0 │ └──────┴───────────┴─────────────────┴───────────────┘ - >>> mapper = mapper_fntr(cfg, "stage2") + >>> stage_cfg = DictConfig({"aggregations": ["code/n_occurrences", "values/sum"]}) + >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬────────────────────┬────────────┐ @@ -202,7 +284,8 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl │ C ┆ 1 ┆ 2 ┆ 12.5 │ │ D ┆ null ┆ 1 ┆ 0.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> mapper = mapper_fntr(cfg, "stage3.A") + >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ @@ -216,7 +299,8 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl │ C ┆ 1 ┆ 2 ┆ 2 │ │ D ┆ null ┆ 0 ┆ 0 │ └──────┴───────────┴───────────────────┴──────────────────────┘ - >>> mapper = mapper_fntr(cfg, "stage3.B") + >>> stage_cfg = DictConfig({"aggregations": ["values/sum_sqd", "values/min", "values/max"]}) + >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) >>> mapper(df.lazy()).collect() shape: (5, 5) ┌──────┬───────────┬────────────────┬────────────┬────────────┐ @@ -230,30 +314,11 @@ def mapper_fntr(cfg: DictConfig, stage_name: str) -> Callable[[pl.DataFrame], pl │ C ┆ 1 ┆ 81.25 ┆ 5.0 ┆ 7.5 │ │ D ┆ null ┆ 0.0 ┆ null ┆ null │ └──────┴───────────┴────────────────┴────────────┴────────────┘ - >>> mapper = mapper_fntr(cfg, "stage4") # doctest: 
+NORMALIZE_WHITESPACE - Traceback (most recent call last): - ... - KeyError: 'Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are: - code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, - values/sum, values/sum_sqd, values/min, values/max' - >>> mapper = mapper_fntr(cfg, "stage5") - Traceback (most recent call last): - ... - KeyError: 'Stage name stage5 not found in code_processing_stages in configuration file.' """ - if stage_name not in cfg.stages: - raise KeyError(f"Stage name {stage_name} not found in code_processing_stages in configuration file.") + code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifier_columns) + aggregations = stage_cfg.aggregations - aggregations = cfg.stages[stage_name].aggregations - for agg in aggregations: - if agg not in METADATA_FN: - raise KeyError( - f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: " - f"{', '.join([fn.value for fn in METADATA_FN])}" - ) - - code_key_columns = ["code"] + cfg.get("code_modifier_columns", []) agg_operations = {agg: CODE_METADATA_AGGREGATIONS[agg].mapper for agg in aggregations} def mapper(df: pl.LazyFrame) -> pl.LazyFrame: @@ -262,43 +327,30 @@ def mapper(df: pl.LazyFrame) -> pl.LazyFrame: return mapper -def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.DataFrame]], pl.DataFrame]: +def reducer_fntr( + stage_cfg: DictConfig, code_modifier_columns: list[str] | None = None +) -> Callable[[Sequence[pl.DataFrame]], pl.DataFrame]: """Returns a function that merges different code metadata files together into an aggregated total. - The functions specified are determined by the TODO field in the configuration file at stage `stage_name`. - The reductions for these aggregation functions are specified in the `CODE_METADATA_AGGREGATIONS` - dictionary. - Args: - cfg: A pre-processing configuration in `OmegaConf` `DictConfig` format. The configuration should have - a field `code_processing_stages` that specifies the metadata aggregations to perform for each - stage. This field should be a dictionary with stage names as keys and lists of metadata - aggregation functions (from the `METADATA_FN` enumeration) as values. - stage_name: The name of the stage in the configuration file that specifies the set of metadata - aggregations to perform in this function. + stage_cfg: The configuration object for this stage. It must contain an `aggregations` field that has a + list of aggregations that should be applied in this stage. Each aggregation must be a string in + the `METADATA_FN` enumeration, and the reduction function is specified in the + `CODE_METADATA_AGGREGATIONS` dictionary. + code_modifier_columns: A list of column names that should be used in addition to the core `code` + column to group the data before applying the aggregations. If None, only the `code` column will be + used. Returns: A function that aggregates the specified metadata columns from different extracted metadata shards into a total view. - Raises: - KeyError: If the specified stage name is not found in the configuration file. - KeyError: If any specified aggregation function is not an element of the `METADATA_FN` enumeration. + Raises: See `validate_args_and_get_code_cols`. Examples: - >>> cfg = DictConfig({ - ... "code_modifier_columns": ["modifier1"], - ... "code_processing_stages": { - ... "stage1": ["code/n_patients", "values/n_ints"], - ... "stage2": ["code/n_occurrences", "values/sum"], - ... 
"stage3.A": ["values/n_patients", "values/n_occurrences"], - ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], - ... "stage4": ["INVALID"], - ... } - ... }) >>> df_1 = pl.DataFrame({ - ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2], + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], ... "code/n_patients": [1, 1, 2, 2], ... "code/n_occurrences": [2, 1, 3, 2], ... "values/n_patients": [1, 1, 2, 2], @@ -311,7 +363,7 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data ... }) >>> df_2 = pl.DataFrame({ ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, None], + ... "modifier1": [1, 2, 1, None], ... "code/n_patients": [3, 3, 4, 4], ... "code/n_occurrences": [10, 11, 8, 11], ... "values/n_patients": [0, 1, 2, 2], @@ -323,11 +375,11 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data ... "values/max": [None, 6.2, 1.0, 1.5], ... }) >>> df_3 = pl.DataFrame({ - ... "code": pl.Series(["D"], dtype=pl.Categorical), + ... "code": pl.Series(["D"], dtype=pl.Categorical), ... "modifier1": [1], - ... "code/n_patients": [2], + ... "code/n_patients": [2], ... "code/n_occurrences": [2], - ... "values/n_patients": [1], + ... "values/n_patients": [1], ... "values/n_occurrences": [3], ... "values/n_ints": [3], ... "values/sum": [2], @@ -335,7 +387,9 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data ... "values/min": [0], ... "values/max": [2], ... }) - >>> reducer = reducer_fntr(cfg, "stage1") + >>> code_modifier_columns = ["modifier1"] + >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) shape: (6, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ @@ -350,7 +404,18 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data │ C ┆ 2 ┆ 2 ┆ 1 │ │ D ┆ 1 ┆ 2 ┆ 3 │ └──────┴───────────┴─────────────────┴───────────────┘ - >>> reducer = reducer_fntr(cfg, "stage2") + >>> cfg = DictConfig({ + ... "code_modifier_columns": ["modifier1"], + ... "code_processing_stages": { + ... "stage1": ["code/n_patients", "values/n_ints"], + ... "stage2": ["code/n_occurrences", "values/sum"], + ... "stage3.A": ["values/n_patients", "values/n_occurrences"], + ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], + ... "stage4": ["INVALID"], + ... } + ... 
}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_occurrences", "values/sum"]}) + >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) shape: (6, 4) ┌──────┬───────────┬────────────────────┬────────────┐ @@ -365,7 +430,8 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data │ C ┆ 2 ┆ 2 ┆ 12.5 │ │ D ┆ 1 ┆ 2 ┆ 2.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> reducer = reducer_fntr(cfg, "stage3.A") + >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) shape: (6, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ @@ -380,7 +446,8 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data │ C ┆ 2 ┆ 2 ┆ 2 │ │ D ┆ 1 ┆ 1 ┆ 3 │ └──────┴───────────┴───────────────────┴──────────────────────┘ - >>> reducer = reducer_fntr(cfg, "stage3.B") + >>> stage_cfg = DictConfig({"aggregations": ["values/sum_sqd", "values/min", "values/max"]}) + >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) shape: (6, 5) ┌──────┬───────────┬────────────────┬────────────┬────────────┐ @@ -399,30 +466,11 @@ def reducer_fntr(cfg: DictConfig, stage_name: str) -> Callable[[Sequence[pl.Data Traceback (most recent call last): ... KeyError: 'Column values/min not found in DataFrame 0 for reduction.' - >>> reducer = reducer_fntr(cfg, "stage4") # doctest: +NORMALIZE_WHITESPACE - Traceback (most recent call last): - ... - KeyError: 'Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are: - code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, - values/sum, values/sum_sqd, values/min, values/max' - >>> reducer = reducer_fntr(cfg, "stage5") - Traceback (most recent call last): - ... - KeyError: 'Stage name stage5 not found in code_processing_stages in configuration file.' """ - if stage_name not in cfg.code_processing_stages: - raise KeyError(f"Stage name {stage_name} not found in code_processing_stages in configuration file.") - - aggregations = cfg.code_processing_stages[stage_name] - for agg in aggregations: - if agg not in METADATA_FN: - raise KeyError( - f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: " - f"{', '.join([fn.value for fn in METADATA_FN])}" - ) + code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifier_columns) + aggregations = stage_cfg.aggregations - code_key_columns = ["code"] + cfg.get("code_modifier_columns", []) agg_operations = { agg: CODE_METADATA_AGGREGATIONS[agg].reducer(cs.matches(f"{agg}/shard_\\d+")) for agg in aggregations } From 5511f9156a5936f7b6178df98e366f07975e44bc Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 8 Jun 2024 18:01:21 -0400 Subject: [PATCH 03/53] Enabled summarizing the entire population in the code metadata map/reduce functions. 
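
When a stage sets `do_summarize_over_all_codes`, the mapper now also emits a
row with a null code (and null modifier columns) that aggregates over every
observation in the shard, and the reducer carries that row through. A minimal
sketch, mirroring the doctests added below (the toy shard is illustrative):

    import polars as pl
    from omegaconf import DictConfig

    from MEDS_polars_functions.code_metadata import mapper_fntr

    stage_cfg = DictConfig({
        "aggregations": ["code/n_patients", "values/n_ints"],
        "do_summarize_over_all_codes": True,
    })
    mapper = mapper_fntr(stage_cfg, None)

    shard = pl.DataFrame({
        "patient_id": [1, 1, 2],
        "code": pl.Series(["A", "B", "A"], dtype=pl.Categorical),
        "numerical_value": [1.0, None, 2.0],
    }).lazy()

    # The first row of the result has a null code and summarizes all rows.
    print(mapper(shard).collect())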
--- configs/preprocess.yaml | 4 +- src/MEDS_polars_functions/code_metadata.py | 94 +++++++++++++++---- .../filter_measurements.py | 23 ++++- 3 files changed, 101 insertions(+), 20 deletions(-) diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index d6d8991..ee57e65 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -36,9 +36,11 @@ stage_configs: aggregations: - "code/n_occurrences" - "code/n_patients" + do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts filter_codes: - min_code_occurrences: null + min_patients_per_code: null + min_occurrences_per_code: null fit_outlier_detection: aggregations: diff --git a/src/MEDS_polars_functions/code_metadata.py b/src/MEDS_polars_functions/code_metadata.py index a8f2f07..6e7f99f 100644 --- a/src/MEDS_polars_functions/code_metadata.py +++ b/src/MEDS_polars_functions/code_metadata.py @@ -211,7 +211,11 @@ def mapper_fntr( Returns: A function that extracts the specified metadata from a MEDS cohort shard after grouping by the - specified code & modifier columns + specified code & modifier columns. **Note**: The output of this function will, if + ``stage_cfg.do_summarize_over_all_codes`` is True, contain the metadata summarizing all observations + across all codes and patients in the shard, with both ``code`` and all ``code_modifier_columns`` set + to `None` in the output dataframe, in the same format as the code/modifier specific rows with non-null + values. Examples: >>> import numpy as np @@ -239,6 +243,24 @@ def mapper_fntr( │ B ┆ 2 ┆ 6 ┆ 2 ┆ NaN │ │ D ┆ null ┆ 7 ┆ 1 ┆ null │ └──────┴───────────┴──────────────────┴────────────┴─────────────────┘ + >>> stage_cfg = DictConfig({ + ... "aggregations": ["code/n_patients", "values/n_ints"], + ... "do_summarize_over_all_codes": True + ... }) + >>> mapper = mapper_fntr(stage_cfg, None) + >>> mapper(df.lazy()).collect() + shape: (5, 3) + ┌──────┬─────────────────┬───────────────┐ + │ code ┆ code/n_patients ┆ values/n_ints │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ u32 ┆ u32 │ + ╞══════╪═════════════════╪═══════════════╡ + │ null ┆ 3 ┆ 4 │ + │ A ┆ 2 ┆ 1 │ + │ B ┆ 2 ┆ 2 │ + │ C ┆ 2 ┆ 1 │ + │ D ┆ 1 ┆ 0 │ + └──────┴─────────────────┴───────────────┘ >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() @@ -284,6 +306,25 @@ def mapper_fntr( │ C ┆ 1 ┆ 2 ┆ 12.5 │ │ D ┆ null ┆ 1 ┆ 0.0 │ └──────┴───────────┴────────────────────┴────────────┘ + >>> stage_cfg = DictConfig({ + ... "aggregations": ["code/n_occurrences", "values/sum"], + ... "do_summarize_over_all_codes": True, + ... 
}) + >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) + >>> mapper(df.lazy()).collect() + shape: (6, 4) + ┌──────┬───────────┬────────────────────┬────────────┐ + │ code ┆ modifier1 ┆ code/n_occurrences ┆ values/sum │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ cat ┆ i64 ┆ u32 ┆ f64 │ + ╞══════╪═══════════╪════════════════════╪════════════╡ + │ null ┆ null ┆ 9 ┆ 26.7 │ + │ A ┆ 1 ┆ 2 ┆ 2.2 │ + │ A ┆ 2 ┆ 1 ┆ 6.0 │ + │ B ┆ 2 ┆ 3 ┆ 6.0 │ + │ C ┆ 1 ┆ 2 ┆ 12.5 │ + │ D ┆ null ┆ 1 ┆ 0.0 │ + └──────┴───────────┴────────────────────┴────────────┘ >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) >>> mapper(df.lazy()).collect() @@ -321,9 +362,24 @@ def mapper_fntr( agg_operations = {agg: CODE_METADATA_AGGREGATIONS[agg].mapper for agg in aggregations} - def mapper(df: pl.LazyFrame) -> pl.LazyFrame: + def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame: return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns) + def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: + return df.select(**agg_operations) + + if stage_cfg.get("do_summarize_over_all_codes", False): + + def mapper(df: pl.LazyFrame) -> pl.LazyFrame: + by_code = by_code_mapper(df) + all_patients = all_patients_mapper(df) + return pl.concat([all_patients, by_code], how="diagonal_relaxed").select( + *code_key_columns, *aggregations + ) + + else: + mapper = by_code_mapper + return mapper @@ -349,17 +405,17 @@ def reducer_fntr( Examples: >>> df_1 = pl.DataFrame({ - ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [1, 1, 2, 2], - ... "code/n_occurrences": [2, 1, 3, 2], - ... "values/n_patients": [1, 1, 2, 2], - ... "values/n_occurrences": [2, 1, 3, 2], - ... "values/n_ints": [0, 1, 3, 1], - ... "values/sum": [2.2, 6.0, 14.0, 12.5], - ... "values/sum_sqd": [2.42, 36.0, 84.0, 81.25], - ... "values/min": [0, -1, 2, 2.], - ... "values/max": [1.1, 6.0, 8.0, 7.5], + ... "code": pl.Series([None, "A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [None, 1, 2, 1, 2], + ... "code/n_patients": [10, 1, 1, 2, 2], + ... "code/n_occurrences": [13, 2, 1, 3, 2], + ... "values/n_patients": [8, 1, 1, 2, 2], + ... "values/n_occurrences": [12, 2, 1, 3, 2], + ... "values/n_ints": [4, 0, 1, 3, 1], + ... "values/sum": [13.2, 2.2, 6.0, 14.0, 12.5], + ... "values/sum_sqd": [21.3, 2.42, 36.0, 84.0, 81.25], + ... "values/min": [-1, 0, -1, 2, 2.], + ... "values/max": [8.0, 1.1, 6.0, 8.0, 7.5], ... }) >>> df_2 = pl.DataFrame({ ... 
"code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), @@ -391,12 +447,13 @@ def reducer_fntr( >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) - shape: (6, 4) + shape: (7, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ cat ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ + │ null ┆ null ┆ 10 ┆ 4 │ │ A ┆ 1 ┆ 4 ┆ 0 │ │ A ┆ 2 ┆ 4 ┆ 2 │ │ B ┆ 1 ┆ 6 ┆ 6 │ @@ -417,12 +474,13 @@ def reducer_fntr( >>> stage_cfg = DictConfig({"aggregations": ["code/n_occurrences", "values/sum"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) - shape: (6, 4) + shape: (7, 4) ┌──────┬───────────┬────────────────────┬────────────┐ │ code ┆ modifier1 ┆ code/n_occurrences ┆ values/sum │ │ --- ┆ --- ┆ --- ┆ --- │ │ cat ┆ i64 ┆ i64 ┆ f64 │ ╞══════╪═══════════╪════════════════════╪════════════╡ + │ null ┆ null ┆ 13 ┆ 13.2 │ │ A ┆ 1 ┆ 12 ┆ 2.2 │ │ A ┆ 2 ┆ 12 ┆ 13.0 │ │ B ┆ 1 ┆ 11 ┆ 28.0 │ @@ -433,12 +491,13 @@ def reducer_fntr( >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) - shape: (6, 4) + shape: (7, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ cat ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ + │ null ┆ null ┆ 8 ┆ 12 │ │ A ┆ 1 ┆ 1 ┆ 2 │ │ A ┆ 2 ┆ 2 ┆ 5 │ │ B ┆ 1 ┆ 4 ┆ 6 │ @@ -449,12 +508,13 @@ def reducer_fntr( >>> stage_cfg = DictConfig({"aggregations": ["values/sum_sqd", "values/min", "values/max"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3) - shape: (6, 5) + shape: (7, 5) ┌──────┬───────────┬────────────────┬────────────┬────────────┐ │ code ┆ modifier1 ┆ values/sum_sqd ┆ values/min ┆ values/max │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ cat ┆ i64 ┆ f64 ┆ f64 ┆ f64 │ ╞══════╪═══════════╪════════════════╪════════════╪════════════╡ + │ null ┆ null ┆ 21.3 ┆ -1.0 ┆ 8.0 │ │ A ┆ 1 ┆ 2.42 ┆ 0.0 ┆ 1.1 │ │ A ┆ 2 ┆ 139.2 ┆ -1.0 ┆ 6.2 │ │ B ┆ 1 ┆ 168.0 ┆ 0.2 ┆ 8.0 │ diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py index 6def789..8b8204c 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -7,7 +7,7 @@ def filter_codes_fntr( - stage_cfg: DictConfig, code_metadata: pl.LazyFrame + stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: """Filters patient events to only encompass those with a set of permissible codes. @@ -19,6 +19,12 @@ def filter_codes_fntr( The processed DataFrame. Examples: + >>> code_metadata_df = pl.DataFrame({ + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], + ... "code/n_patients": [1, 1, 2, 2], + ... "code/n_occurrences": [2, 1, 3, 2], + ... 
}) >>> raise NotImplementedError """ @@ -26,7 +32,7 @@ def filter_codes_fntr( def filter_outliers_fntr( - stage_cfg: DictConfig, code_metadata: pl.LazyFrame + stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: """Filters patient events to only encompass those with a set of permissible codes. @@ -38,6 +44,19 @@ def filter_outliers_fntr( The processed DataFrame. Examples: + >>> code_metadata_df = pl.DataFrame({ + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], + ... "code/n_patients": [1, 1, 2, 2], + ... "code/n_occurrences": [2, 1, 3, 2], + ... "values/n_patients": [1, 1, 2, 2], + ... "values/n_occurrences": [2, 1, 3, 2], + ... "values/n_ints": [0, 1, 3, 1], + ... "values/sum": [2.2, 6.0, 14.0, 12.5], + ... "values/sum_sqd": [2.42, 36.0, 84.0, 81.25], + ... "values/min": [0, -1, 2, 2.], + ... "values/max": [1.1, 6.0, 8.0, 7.5], + ... }) >>> raise NotImplementedError """ From 7b6280da52a0c42a7b9fb263a44366fb65b9704c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 14:01:05 -0400 Subject: [PATCH 04/53] added one test and implemented the basic filter codes logic --- .../filter_measurements.py | 119 ++++++++++++++++-- 1 file changed, 112 insertions(+), 7 deletions(-) diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py index 8b8204c..12e3e7c 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -3,13 +3,15 @@ from collections.abc import Callable import polars as pl + +pl.enable_string_cache() from omegaconf import DictConfig def filter_codes_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Filters patient events to only encompass those with a set of permissible codes. + """Returns a function that filters patient events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -20,15 +22,118 @@ def filter_codes_fntr( Examples: >>> code_metadata_df = pl.DataFrame({ - ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [1, 1, 2, 2], - ... "code/n_occurrences": [2, 1, 3, 2], + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], + ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_occurrences": [4, 5, 3, 2], ... }) - >>> raise NotImplementedError + >>> data = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "code": pl.Series(["A", "B", "A", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 1, 2, 2], + ... 
}).lazy() + >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (2, 3) + ┌────────────┬──────┬───────────┐ + │ patient_id ┆ code ┆ modifier1 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 │ + ╞════════════╪══════╪═══════════╡ + │ 1 ┆ A ┆ 1 │ + │ 1 ┆ B ┆ 1 │ + └────────────┴──────┴───────────┘ + >>> stage_cfg = DictConfig({"min_patients_per_code": 1, "min_occurrences_per_code": 4}) + >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (2, 3) + ┌────────────┬──────┬───────────┐ + │ patient_id ┆ code ┆ modifier1 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 │ + ╞════════════╪══════╪═══════════╡ + │ 1 ┆ A ┆ 1 │ + │ 2 ┆ A ┆ 2 │ + └────────────┴──────┴───────────┘ + >>> stage_cfg = DictConfig({"min_patients_per_code": 1}) + >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (4, 3) + ┌────────────┬──────┬───────────┐ + │ patient_id ┆ code ┆ modifier1 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 │ + ╞════════════╪══════╪═══════════╡ + │ 1 ┆ A ┆ 1 │ + │ 1 ┆ B ┆ 1 │ + │ 2 ┆ A ┆ 2 │ + │ 2 ┆ C ┆ 2 │ + └────────────┴──────┴───────────┘ + >>> stage_cfg = DictConfig({"min_patients_per_code": None, "min_occurrences_per_code": None}) + >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (4, 3) + ┌────────────┬──────┬───────────┐ + │ patient_id ┆ code ┆ modifier1 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 │ + ╞════════════╪══════╪═══════════╡ + │ 1 ┆ A ┆ 1 │ + │ 1 ┆ B ┆ 1 │ + │ 2 ┆ A ┆ 2 │ + │ 2 ┆ C ┆ 2 │ + └────────────┴──────┴───────────┘ + >>> stage_cfg = DictConfig({"min_occurrences_per_code": 5}) + >>> fn = filter_codes_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (1, 3) + ┌────────────┬──────┬───────────┐ + │ patient_id ┆ code ┆ modifier1 │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 │ + ╞════════════╪══════╪═══════════╡ + │ 2 ┆ A ┆ 2 │ + └────────────┴──────┴───────────┘ """ - raise NotImplementedError + min_patients_per_code = stage_cfg.get("min_patients_per_code", None) + min_occurrences_per_code = stage_cfg.get("min_occurrences_per_code", None) + + filter_exprs = [] + if min_patients_per_code is not None: + filter_exprs.append(pl.col("code/n_patients") >= min_patients_per_code) + if min_occurrences_per_code is not None: + filter_exprs.append(pl.col("code/n_occurrences") >= min_occurrences_per_code) + + if not filter_exprs: + return lambda df: df + + join_cols = ["code"] + if code_modifier_columns: + join_cols.extend(code_modifier_columns) + + allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy() + + def filter_codes(df: pl.LazyFrame) -> pl.LazyFrame: + f"""Filters patient events to only encompass those with a set of permissible codes. + + In particular, this function filters the DataFrame to only include (code, modifier) pairs that have + at least {min_patients_per_code} patients and {min_occurrences_per_code} occurrences. 
+ """ + + idx_col = "_row_idx" + while idx_col in df.columns: + idx_col = f"_{idx_col}" + + return ( + df.with_row_count(idx_col) + .join(allowed_code_metadata, on=join_cols, how="inner") + .sort(idx_col) + .drop(idx_col) + ) + + return filter_codes def filter_outliers_fntr( From cb0b2d39febd751e0cf8623001498b031c00883e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 14:38:24 -0400 Subject: [PATCH 05/53] Added the outlier filtering function. --- .../filter_measurements.py | 89 +++++++++++++++---- 1 file changed, 74 insertions(+), 15 deletions(-) diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py index 12e3e7c..29cbee2 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -115,7 +115,7 @@ def filter_codes_fntr( allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy() - def filter_codes(df: pl.LazyFrame) -> pl.LazyFrame: + def filter_codes_fn(df: pl.LazyFrame) -> pl.LazyFrame: f"""Filters patient events to only encompass those with a set of permissible codes. In particular, this function filters the DataFrame to only include (code, modifier) pairs that have @@ -133,7 +133,7 @@ def filter_codes(df: pl.LazyFrame) -> pl.LazyFrame: .drop(idx_col) ) - return filter_codes + return filter_codes_fn def filter_outliers_fntr( @@ -150,19 +150,78 @@ def filter_outliers_fntr( Examples: >>> code_metadata_df = pl.DataFrame({ - ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [1, 1, 2, 2], - ... "code/n_occurrences": [2, 1, 3, 2], - ... "values/n_patients": [1, 1, 2, 2], - ... "values/n_occurrences": [2, 1, 3, 2], - ... "values/n_ints": [0, 1, 3, 1], - ... "values/sum": [2.2, 6.0, 14.0, 12.5], - ... "values/sum_sqd": [2.42, 36.0, 84.0, 81.25], - ... "values/min": [0, -1, 2, 2.], - ... "values/max": [1.1, 6.0, 8.0, 7.5], + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], + ... "values/n_occurrences": [3, 1, 3, 2], + ... "values/sum": [0.0, 4.0, 12.0, 2.0], + ... "values/sum_sqd": [27.0, 16.0, 75.0, 4.0], + ... # for clarity: ----- mean = [0.0, 4.0, 4.0, 1.0] + ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] ... }) - >>> raise NotImplementedError + >>> data = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "code": pl.Series(["A", "B", "A", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 1, 2, 2], + ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] + ... # for clarity: stddev [3.0, 3.0, 0.0, 1.0] + ... "numerical_value": [15., 16., 3.9, 1.0], + ... 
}).lazy() + >>> stage_cfg = DictConfig({"stddev_cutoff": 4.5}) + >>> fn = filter_outliers_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (4, 5) + ┌────────────┬──────┬───────────┬─────────────────┬───────────────────────────┐ + │ patient_id ┆ code ┆ modifier1 ┆ numerical_value ┆ numerical_value/is_inlier │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 ┆ f64 ┆ bool │ + ╞════════════╪══════╪═══════════╪═════════════════╪═══════════════════════════╡ + │ 1 ┆ A ┆ 1 ┆ null ┆ false │ + │ 1 ┆ B ┆ 1 ┆ 16.0 ┆ true │ + │ 2 ┆ A ┆ 2 ┆ null ┆ false │ + │ 2 ┆ C ┆ 2 ┆ 1.0 ┆ true │ + └────────────┴──────┴───────────┴─────────────────┴───────────────────────────┘ """ - raise NotImplementedError + stddev_cutoff = stage_cfg.get("stddev_cutoff", None) + if stddev_cutoff is None: + return lambda df: df + + join_cols = ["code"] + if code_modifier_columns: + join_cols.extend(code_modifier_columns) + + cols_to_select = ["code"] + if code_modifier_columns: + cols_to_select.extend(code_modifier_columns) + + mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") + stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 + if "values/mean" not in code_metadata.columns: + cols_to_select.append(mean_col.alias("values/mean")) + if "values/stddev" not in code_metadata.columns: + cols_to_select.append(stddev_col.alias("values/stddev")) + + code_metadata = code_metadata.lazy().select(cols_to_select) + + def filter_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: + f"""Filters out outlier numerical values from patient events. + + In particular, this function filters the DataFrame to only include numerical values that are within + {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. + """ + + val = pl.col("numerical_value") + mean = pl.col("values/mean") + stddev = pl.col("values/stddev") + filter_expr = (val - mean).abs() <= stddev_cutoff * stddev + + return ( + df.join(code_metadata, on=join_cols, how="left") + .with_columns( + filter_expr.alias("numerical_value/is_inlier"), + pl.when(filter_expr).then(pl.col("numerical_value")).alias("numerical_value"), + ) + .drop("values/mean", "values/stddev") + ) + + return filter_outliers_fn From 70e3524e3b90e0be763218c6f3d520da73bb3201 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 15:00:33 -0400 Subject: [PATCH 06/53] Added a code metadata extraction step for core MEDS extraction -- this will also facilitate joining with ontologies and downstream usage (and it tests that, more generally pre-processing oriented utility). 
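
After this stage runs, the aggregated metadata is written to
`code_metadata.parquet` in the stage's output directory (the cohort directory,
per the config below), with one row per code plus a null-code row holding
population-wide totals, since `do_summarize_over_all_codes` is enabled. A rough
sketch of the kind of downstream use this enables (the cohort path and the
count threshold are illustrative placeholders, not part of this patch):

    import polars as pl

    cohort_dir = "path/to/MEDS/cohort"  # illustrative placeholder
    metadata = pl.scan_parquet(f"{cohort_dir}/code_metadata.parquet")

    # Population-wide totals live in the single row with a null code.
    totals = metadata.filter(pl.col("code").is_null()).collect()

    # Per-code counts support building vocabularies or joining ontologies.
    common_codes = (
        metadata.filter(pl.col("code/n_patients") >= 10)
        .select("code", "code/n_patients", "code/n_occurrences")
        .collect()
    )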
--- configs/extraction.yaml | 21 +++++ scripts/__init__.py | 0 scripts/extraction/__init__.py | 0 scripts/extraction/collect_code_metadata.py | 84 +++++++++++++++++++ scripts/preprocessing/__init__.py | 0 .../add_time_derived_measurements.py | 2 +- .../preprocessing/collect_code_metadata.py | 8 +- scripts/preprocessing/filter_codes.py | 2 +- scripts/preprocessing/filter_outliers.py | 2 +- scripts/preprocessing/filter_patients.py | 2 +- tests/test_extraction.py | 19 ++++- 11 files changed, 130 insertions(+), 10 deletions(-) create mode 100644 scripts/__init__.py create mode 100644 scripts/extraction/__init__.py create mode 100644 scripts/extraction/collect_code_metadata.py create mode 100644 scripts/preprocessing/__init__.py diff --git a/configs/extraction.yaml b/configs/extraction.yaml index 1a1c0dd..6359b09 100644 --- a/configs/extraction.yaml +++ b/configs/extraction.yaml @@ -16,12 +16,15 @@ description: |- # The event conversion configuration file is used throughout the pipeline to define the events to extract. event_conversion_config_fp: ??? +# The code modifier columns are in this pipeline only used in the collect_code_metadata stage. +code_modifiers: null stages: - shard_events - split_and_shard_patients - convert_to_sharded_events - merge_to_MEDS_cohort + - collect_code_metadata stage_configs: shard_events: @@ -32,6 +35,7 @@ stage_configs: files are csvs) row_chunksize: 200000000 infer_schema_length: 10000 + split_and_shard_patients: description: |- This stage splits the patients into training, tuning, and held-out sets, and further splits those sets @@ -54,6 +58,7 @@ stage_configs: train: 0.8 tuning: 0.1 held_out: 0.1 + merge_to_MEDS_cohort: description: |- This stage splits the patients into training, tuning, and held-out sets, and further splits those sets @@ -65,3 +70,19 @@ stage_configs: - `split_fracs`: The fraction of patients to include in the IID training, tuning, and held-out sets. output_dir: ${cohort_dir}/final_cohort unique_by: "*" + + collect_code_metadata: + description: |- + This stage collects some descriptive metadata about the codes in the cohort. Arguments: + - `aggregations`: The aggregations to compute over the codes. Defaults to counts of code occurrences, + counts of patients with the code, and counts of value occurrences per code, as well as the sum and + sum of squares of values (for use in computing means and variances). 
+ aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" + do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts + is_metadata: True + output_dir: ${cohort_dir} diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/extraction/__init__.py b/scripts/extraction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/extraction/collect_code_metadata.py b/scripts/extraction/collect_code_metadata.py new file mode 100644 index 0000000..8986c12 --- /dev/null +++ b/scripts/extraction/collect_code_metadata.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + +import json +import random +import time +from datetime import datetime +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.code_metadata import mapper_fntr, reducer_fntr +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe + + +@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") +def main(cfg: DictConfig): + """Computes code metadata.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + mapper_fn = mapper_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) + + start = datetime.now() + logger.info("Starting code metadata mapping computation") + + all_out_fps = [] + for sp in patient_splits: + in_fp = input_dir / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" + all_out_fps.append(out_fp) + + logger.info( + f"Computing code metadata for {str(in_fp.resolve())} and storing to {str(out_fp.resolve())}" + ) + + rwlock_wrap( + in_fp, + out_fp, + pl.scan_parquet, + write_lazyframe, + mapper_fn, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Finished mapping in {datetime.now() - start}") + + if cfg.worker != 1: + return + + while not all(fp.is_file() for fp in all_out_fps): + logger.info("Waiting to begin reduction for all files to be written...") + time.sleep(cfg.polling_time) + + start = datetime.now() + logger.info("All map shards complete! 
Starting code metadata reduction computation.") + reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) + + reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) + write_lazyframe(reduced, output_dir / "code_metadata.parquet") + logger.info(f"Finished reduction in {datetime.now() - start}") + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocessing/__init__.py b/scripts/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py index 5731445..0eb5f5c 100644 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ b/scripts/preprocessing/add_time_derived_measurements.py @@ -18,7 +18,7 @@ from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe -@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): """Adds time-derived measurements to a MEDS cohort as separate observations at each unique timestamp.""" diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 40dd27d..6f0defa 100644 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -16,7 +16,7 @@ from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe -@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): """Computes code metadata.""" @@ -31,9 +31,7 @@ def main(cfg: DictConfig): input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - logger.info(f"Reading data from input directory {str(input_dir.resolve())}") - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) @@ -77,7 +75,7 @@ def main(cfg: DictConfig): logger.info("All map shards complete! 
Starting code metadata reduction computation.") reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - reduced = reducer_fn(pl.scan_parquet(fp, glob=False) for fp in all_out_fps) + reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) write_lazyframe(reduced, output_dir / "code_metadata.parquet") logger.info(f"Finished reduction in {datetime.now() - start}") diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py index 1f0e318..99dbaa9 100644 --- a/scripts/preprocessing/filter_codes.py +++ b/scripts/preprocessing/filter_codes.py @@ -14,7 +14,7 @@ from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe -@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): """TODO.""" diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py index b1914a4..ce436b7 100644 --- a/scripts/preprocessing/filter_outliers.py +++ b/scripts/preprocessing/filter_outliers.py @@ -14,7 +14,7 @@ from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe -@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): """TODO.""" diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index 937d87f..3a7ded5 100644 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -18,7 +18,7 @@ from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe -@hydra.main(version_base=None, config_path="configs", config_name="preprocess") +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): """TODO.""" diff --git a/tests/test_extraction.py b/tests/test_extraction.py index 9343d17..d221a30 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -434,4 +434,21 @@ def test_extraction(): print(f"stdout:\n{full_stdout}") raise e - logger.warning("Only checked the train/0 split for now. TODO: add the rest of the splits.") + logger.warning("Only checked the train/0 split for now. 
TODO: add the rest of the splits.") + + # Step 4: Merge to the final output + stderr, stdout = run_command( + extraction_root / "collect_code_metadata.py", + extraction_config_kwargs, + "collect_code_metadata", + ) + all_stderrs.append(stderr) + all_stdouts.append(stdout) + + full_stderr = "\n".join(all_stderrs) + full_stdout = "\n".join(all_stdouts) + + output_file = MEDS_cohort_dir / "code_metadata.parquet" + assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" + + logger.warning("Didn't check contents of code metadata!") From 92ec6da1c828cbb2e26e2c4baf652a0fd468eb84 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 15:06:09 -0400 Subject: [PATCH 07/53] Made script executable --- scripts/extraction/collect_code_metadata.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/extraction/collect_code_metadata.py diff --git a/scripts/extraction/collect_code_metadata.py b/scripts/extraction/collect_code_metadata.py old mode 100644 new mode 100755 From f9a10fdb41c26997b6df27e87050ae8e74832848 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 15:16:00 -0400 Subject: [PATCH 08/53] Set up a mapper output directory -- not sure if this is the best policy yet. --- configs/extraction.yaml | 3 ++- scripts/extraction/collect_code_metadata.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/extraction.yaml b/configs/extraction.yaml index 6359b09..a0939e8 100644 --- a/configs/extraction.yaml +++ b/configs/extraction.yaml @@ -85,4 +85,5 @@ stage_configs: - "values/sum_sqd" do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts is_metadata: True - output_dir: ${cohort_dir} + mapper_output_dir: "${cohort_dir}/code_metadata" + output_dir: "${cohort_dir}" diff --git a/scripts/extraction/collect_code_metadata.py b/scripts/extraction/collect_code_metadata.py index 8986c12..97f1f69 100755 --- a/scripts/extraction/collect_code_metadata.py +++ b/scripts/extraction/collect_code_metadata.py @@ -30,6 +30,7 @@ def main(cfg: DictConfig): input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) + mapper_output_dir = Path(cfg.stage_cfg.mapper_output_dir) shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) @@ -44,7 +45,7 @@ def main(cfg: DictConfig): all_out_fps = [] for sp in patient_splits: in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" + out_fp = mapper_output_dir / f"{sp}.parquet" all_out_fps.append(out_fp) logger.info( From 320c8b4b952d2a6e827d90e9d1e716e3b341577f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:10:41 -0400 Subject: [PATCH 09/53] made scripts executable --- scripts/preprocessing/add_time_derived_measurements.py | 0 scripts/preprocessing/collect_code_metadata.py | 0 scripts/preprocessing/filter_codes.py | 0 scripts/preprocessing/filter_outliers.py | 0 scripts/preprocessing/filter_patients.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/preprocessing/add_time_derived_measurements.py mode change 100644 => 100755 scripts/preprocessing/collect_code_metadata.py mode change 100644 => 100755 scripts/preprocessing/filter_codes.py mode change 100644 => 100755 scripts/preprocessing/filter_outliers.py mode change 100644 => 100755 scripts/preprocessing/filter_patients.py diff --git a/scripts/preprocessing/add_time_derived_measurements.py 
b/scripts/preprocessing/add_time_derived_measurements.py old mode 100644 new mode 100755 diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py old mode 100644 new mode 100755 diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py old mode 100644 new mode 100755 diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py old mode 100644 new mode 100755 diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py old mode 100644 new mode 100755 From 9f487938102673afb37aaae4e984c29f6d4eec80 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:14:07 -0400 Subject: [PATCH 10/53] fixed a small typo (maybe?) --- scripts/preprocessing/filter_patients.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index 3a7ded5..5158694 100755 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -30,10 +30,11 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((cfg.input_dir / "splits.json").read_text()) + shards = json.loads((metadata_input_dir / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) From e28ef244af2b0da8f27affe82b35e5821c28d4bb Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:16:51 -0400 Subject: [PATCH 11/53] Fixed an issue with using cfg instead of stage_cfg --- scripts/preprocessing/filter_patients.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index 5158694..362afa7 100755 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -40,23 +40,23 @@ def main(cfg: DictConfig): random.shuffle(patient_splits) compute_fns = [] - if cfg.min_measurements_per_patient: + if cfg.stage_cfg.min_measurements_per_patient: logger.info( - f"Filtering patients with fewer than {cfg.min_measurements_per_patient} measurements " + f"Filtering patients with fewer than {cfg.stage_cfg.min_measurements_per_patient} measurements " "(observations of any kind)." ) compute_fns.append( partial( filter_patients_by_num_measurements, - min_measurements_per_patient=cfg.min_measurements_per_patient, + min_measurements_per_patient=cfg.stage_cfg.min_measurements_per_patient, ) ) - if cfg.min_events_per_patient: + if cfg.stage_cfg.min_events_per_patient: logger.info( - f"Filtering patients with fewer than {cfg.min_events_per_patient} events (unique timepoints)." + f"Filtering patients with fewer than {cfg.stage_cfg.min_events_per_patient} events (unique timepoints)." 
) compute_fns.append( - partial(filter_patients_by_num_events, min_events_per_patient=cfg.min_events_per_patient) + partial(filter_patients_by_num_events, min_events_per_patient=cfg.stage_cfg.min_events_per_patient) ) for sp in patient_splits: From 3821d302bff9329e6f61014c14718fc37b7f22b1 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:21:28 -0400 Subject: [PATCH 12/53] Fixing some other typos --- scripts/preprocessing/add_time_derived_measurements.py | 2 +- scripts/preprocessing/filter_codes.py | 2 +- scripts/preprocessing/filter_outliers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py index 0eb5f5c..09fff0a 100755 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ b/scripts/preprocessing/add_time_derived_measurements.py @@ -32,7 +32,7 @@ def main(cfg: DictConfig): output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) input_dir = Path(cfg.stage_cfg.data_input_dir) logger.info(f"Reading data from {str(input_dir.resolve())}") diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py index 99dbaa9..fe3bf20 100755 --- a/scripts/preprocessing/filter_codes.py +++ b/scripts/preprocessing/filter_codes.py @@ -30,7 +30,7 @@ def main(cfg: DictConfig): metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((cfg.input_dir / "splits.json").read_text()) + shards = json.loads((metadata_input_dir / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py index ce436b7..46214e0 100755 --- a/scripts/preprocessing/filter_outliers.py +++ b/scripts/preprocessing/filter_outliers.py @@ -30,7 +30,7 @@ def main(cfg: DictConfig): metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((cfg.input_dir / "splits.json").read_text()) + shards = json.loads((metadata_input_dir / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) From c1e7b485a6f9932036485abe07462ba320161763 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:23:13 -0400 Subject: [PATCH 13/53] Fixing anohter typo in time-derived --- scripts/preprocessing/add_time_derived_measurements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py index 09fff0a..3d585b9 100755 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ b/scripts/preprocessing/add_time_derived_measurements.py @@ -43,7 +43,7 @@ def main(cfg: DictConfig): compute_fns = [] # We use the raw stages object as the induced `stage_cfg` has extra properties like the input and output # directories. 
- for feature_name, feature_cfg in cfg.stages[cfg.stage].items(): + for feature_name, feature_cfg in cfg.stage_cfg.items(): match feature_name: case "age": compute_fns.append(add_new_events_fntr(age_fntr(feature_cfg))) From 719e1d0c8556fd5f1aceb23726ee01360a0ba6dd Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 16:44:47 -0400 Subject: [PATCH 14/53] Some small corrections and documentation for getting time derived to work. --- MIMIC-IV_Example/README.md | 19 +++++++++++++++++++ configs/preprocess.yaml | 5 +++-- .../add_time_derived_measurements.py | 4 ++++ src/MEDS_polars_functions/utils.py | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 406f1f2..d804d2d 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -193,6 +193,25 @@ and performance is not necessary; however, for larger datasets, it can be. event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` +5. (Optional) Generate preliminary code statistics and merge to external metadata. + +## Pre-processing for a model +To run the pre-processing steps for a model, consider the sample script provided here: + +1. Filter patients to only those with at least 32 events (unique timepoints): +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines +❯ ./scripts/preprocessing/filter_patients.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage_configs.filter_patients.min_events_per_patient=32 +``` + +2. Add time-derived measurements (age and time-of-day): +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 3s +❯ ./scripts/preprocessing/add_time_derived_measurements.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DI +R/test" code_modifier_columns=null stage_configs.add_time_derived_measurements.age.DOB_code="DOB" +``` + + ## Limitations / TO-DOs: Currently, some tables are ignored, including: diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 088541e..d52f573 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -27,11 +27,12 @@ stage_configs: add_time_derived_measurements: age: - dob_code: ??? + DOB_code: ??? 
age_code: "AGE" age_unit: "years" time_of_day: - bin_endpoints: [6, 12, 18, 24] + time_of_day_code: "TIME_OF_DAY" + endpoints: [6, 12, 18, 24] preliminary_counts: aggregations: diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py index 3d585b9..bd1dded 100755 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ b/scripts/preprocessing/add_time_derived_measurements.py @@ -17,6 +17,8 @@ ) from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe +INFERRED_STAGE_KEYS = {"is_metadata", "data_input_dir", "metadata_input_dir", "output_dir"} + @hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") def main(cfg: DictConfig): @@ -49,6 +51,8 @@ def main(cfg: DictConfig): compute_fns.append(add_new_events_fntr(age_fntr(feature_cfg))) case "time_of_day": compute_fns.append(add_new_events_fntr(time_of_day_fntr(feature_cfg))) + case str() if feature_name in INFERRED_STAGE_KEYS: + continue case _: raise ValueError(f"Unknown time-derived measurement: {feature_name}") diff --git a/src/MEDS_polars_functions/utils.py b/src/MEDS_polars_functions/utils.py index a307bae..045e053 100644 --- a/src/MEDS_polars_functions/utils.py +++ b/src/MEDS_polars_functions/utils.py @@ -150,6 +150,9 @@ def populate_stage( "output_dir": os.path.join(cohort_dir, stage_name), } + if "is_metadata" in stage and not isinstance(stage["is_metadata"], bool): + raise TypeError(f"If specified manually, is_metadata must be a boolean. Got {stage['is_metadata']}") + out = {**stage} for key, val in inferred_keys.items(): if key not in out or out[key] is None: From a03862fd09cb4e81e40ae9440ee10c6e8d8978ab Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 17:16:45 -0400 Subject: [PATCH 15/53] Fixing various typos -- the metadata_input_dir change was actually wrong; shards should come from the raw input dir for pre-processing jobs. --- MIMIC-IV_Example/README.md | 6 ++++++ scripts/preprocessing/filter_codes.py | 4 ++-- scripts/preprocessing/filter_outliers.py | 4 ++-- scripts/preprocessing/filter_patients.py | 4 ++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index d804d2d..a568f52 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -211,6 +211,12 @@ mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps R/test" code_modifier_columns=null stage_configs.add_time_derived_measurements.age.DOB_code="DOB" ``` +3. 
Get preliminary counts for code filtering: +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines +❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage="preliminary_counts" +``` + ## Limitations / TO-DOs: diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py index fe3bf20..d086f5e 100755 --- a/scripts/preprocessing/filter_codes.py +++ b/scripts/preprocessing/filter_codes.py @@ -30,7 +30,7 @@ def main(cfg: DictConfig): metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((metadata_input_dir / "splits.json").read_text()) + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) @@ -39,7 +39,7 @@ def main(cfg: DictConfig): compute_fn = filter_codes_fntr(cfg.stage_cfg, code_metadata) for sp in patient_splits: - in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + in_fp = input_dir / f"{sp}.parquet" out_fp = output_dir / f"{sp}.parquet" logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py index 46214e0..1643e07 100755 --- a/scripts/preprocessing/filter_outliers.py +++ b/scripts/preprocessing/filter_outliers.py @@ -30,7 +30,7 @@ def main(cfg: DictConfig): metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((metadata_input_dir / "splits.json").read_text()) + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) @@ -39,7 +39,7 @@ def main(cfg: DictConfig): compute_fn = filter_outliers_fntr(cfg.stage_cfg, code_metadata) for sp in patient_splits: - in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + in_fp = input_dir / f"{sp}.parquet" out_fp = output_dir / f"{sp}.parquet" logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index 362afa7..1beef8e 100755 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -34,7 +34,7 @@ def main(cfg: DictConfig): input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((metadata_input_dir / "splits.json").read_text()) + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) @@ -60,7 +60,7 @@ def main(cfg: DictConfig): ) for sp in patient_splits: - in_fp = input_dir / "final_cohort" / f"{sp}.parquet" + in_fp = input_dir / f"{sp}.parquet" out_fp = output_dir / f"{sp}.parquet" logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") From 01f3f789c5e056f290abc483c26bfae05dd29116 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 18:17:23 -0400 Subject: [PATCH 16/53] Fixed some more typos and added more working or at least running commands to the MIMIC example. 
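
All of the preprocessing scripts touched in this series share the same map-over-shards skeleton, which these fixes keep nudging toward. Roughly (a sketch only, reusing the helpers these scripts already import; the identity `compute_fn` below is a placeholder for each stage's own transformation):

```python
import json
import random
from pathlib import Path

import hydra
import polars as pl
from loguru import logger
from omegaconf import DictConfig

from MEDS_polars_functions.mapper import wrap as rwlock_wrap
from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe


@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess")
def main(cfg: DictConfig):
    hydra_loguru_init()

    input_dir = Path(cfg.stage_cfg.data_input_dir)
    output_dir = Path(cfg.stage_cfg.output_dir)

    # For pre-processing stages, shard membership is read from the raw input directory.
    shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text())

    patient_splits = list(shards.keys())
    random.shuffle(patient_splits)  # so parallel workers start on different shards

    def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame:
        return df  # placeholder; each stage supplies its own transformation(s)

    for sp in patient_splits:
        in_fp = input_dir / f"{sp}.parquet"
        out_fp = output_dir / f"{sp}.parquet"
        logger.info(f"Mapping {str(in_fp.resolve())} into {str(out_fp.resolve())}")
        # rwlock_wrap handles locking and caching so concurrent workers don't clobber each other.
        rwlock_wrap(
            in_fp,
            out_fp,
            pl.scan_parquet,
            write_lazyframe,
            compute_fn,
            do_return=False,
            cache_intermediate=False,
            do_overwrite=cfg.do_overwrite,
        )


if __name__ == "__main__":
    main()
```
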
--- MIMIC-IV_Example/README.md | 24 +++++++++++++++++++ .../preprocessing/collect_code_metadata.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index a568f52..f51bf98 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -217,6 +217,30 @@ mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps ❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage="preliminary_counts" ``` +4. Filter codes: +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s +❯ ./scripts/preprocessing/filter_codes.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modi +fier_columns=null stage_configs.filter_codes.min_patients_per_code=128 stage_configs.filter_codes.min_occurrences_per_code=256 +``` + +5. Get outlier detection params: +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 19m57s +❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage=fit_outlier_detection +``` + +6. Filter outliers: +```bash +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 5m14s +❯ ./scripts/preprocessing/filter_outliers.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null +``` + +7. Fit normalization parameters: +```bash + +``` + ## Limitations / TO-DOs: diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 6f0defa..71a6b27 100755 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -31,7 +31,7 @@ def main(cfg: DictConfig): input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) patient_splits = list(shards.keys()) random.shuffle(patient_splits) From 6123087e115f145e95208232084221cf7b538ff7 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 13 Jun 2024 18:33:35 -0400 Subject: [PATCH 17/53] Added command for getting normalization parameters --- MIMIC-IV_Example/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index f51bf98..d939212 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -238,7 +238,8 @@ mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps 7. 
Fit normalization parameters: ```bash - +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 16m25s +❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage=fit_normalization ``` From 148fdd51fcd285931312ac60817a48efd0563441 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 14 Jun 2024 16:15:05 -0400 Subject: [PATCH 18/53] Added code to get lexicographic code indices for vocabulary tokenization. --- MIMIC-IV_Example/README.md | 23 ++- scripts/preprocessing/filter_patients.py | 8 +- src/MEDS_polars_functions/get_vocabulary.py | 185 ++++++++++++++++++++ src/MEDS_polars_functions/utils.py | 2 +- terminology.md | 8 + 5 files changed, 214 insertions(+), 12 deletions(-) create mode 100644 src/MEDS_polars_functions/get_vocabulary.py create mode 100644 terminology.md diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index d939212..6bdba6f 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -196,53 +196,60 @@ and performance is not necessary; however, for larger datasets, it can be. 5. (Optional) Generate preliminary code statistics and merge to external metadata. ## Pre-processing for a model + To run the pre-processing steps for a model, consider the sample script provided here: 1. Filter patients to only those with at least 32 events (unique timepoints): + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines ❯ ./scripts/preprocessing/filter_patients.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage_configs.filter_patients.min_events_per_patient=32 ``` 2. Add time-derived measurements (age and time-of-day): + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 3s +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 3s ❯ ./scripts/preprocessing/add_time_derived_measurements.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DI R/test" code_modifier_columns=null stage_configs.add_time_derived_measurements.age.DOB_code="DOB" ``` 3. Get preliminary counts for code filtering: + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines ❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage="preliminary_counts" ``` 4. 
Filter codes: + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s ❯ ./scripts/preprocessing/filter_codes.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modi fier_columns=null stage_configs.filter_codes.min_patients_per_code=128 stage_configs.filter_codes.min_occurrences_per_code=256 ``` 5. Get outlier detection params: + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 19m57s +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 19m57s ❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage=fit_outlier_detection ``` 6. Filter outliers: + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 5m14s +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 5m14s ❯ ./scripts/preprocessing/filter_outliers.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null ``` 7. Fit normalization parameters: + ```bash -mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 16m25s +mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 16m25s ❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage=fit_normalization ``` - ## Limitations / TO-DOs: Currently, some tables are ignored, including: diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py index 1beef8e..602ae20 100755 --- a/scripts/preprocessing/filter_patients.py +++ b/scripts/preprocessing/filter_patients.py @@ -30,7 +30,6 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) @@ -53,10 +52,13 @@ def main(cfg: DictConfig): ) if cfg.stage_cfg.min_events_per_patient: logger.info( - f"Filtering patients with fewer than {cfg.stage_cfg.min_events_per_patient} events (unique timepoints)." + f"Filtering patients with fewer than {cfg.stage_cfg.min_events_per_patient} events " + "(unique timepoints)." 
) compute_fns.append( - partial(filter_patients_by_num_events, min_events_per_patient=cfg.stage_cfg.min_events_per_patient) + partial( + filter_patients_by_num_events, min_events_per_patient=cfg.stage_cfg.min_events_per_patient + ) ) for sp in patient_splits: diff --git a/src/MEDS_polars_functions/get_vocabulary.py b/src/MEDS_polars_functions/get_vocabulary.py new file mode 100644 index 0000000..318b9e1 --- /dev/null +++ b/src/MEDS_polars_functions/get_vocabulary.py @@ -0,0 +1,185 @@ +"""Simple helper functions to define a consistent code vocabulary for normalizing a MEDS dataset.""" + +from collections.abc import Callable +from enum import StrEnum + +import polars as pl + + +class VOCABULARY_ORDERING(StrEnum): + """Enumeration of different ways a vocabulary order can be selected. + + These are stored as a `StrEnum` so that they can be easily specified by the user in a configuration file + or on the command line. + + Currently, only one ordering method is supported, but others can be added, such as a frequency-based + ordering so that the most frequent codes have the smallest indices. + + Args: + "lexicographic": Assigns vocabulary indices to codes and code modifiers via a lexicographic order. + """ + + LEXICOGRAPHIC = "lexicographic" + + +INDEX_ASSIGNMENT_FN = Callable[[pl.DataFrame, list[str]], pl.DataFrame] + + +def validate_code_metadata(code_metadata: pl.DataFrame, code_modifiers: list[str]): + """Validate the code metadata has the requisite columns and is unique. + + Args: + code_metadata: Metadata about the codes in the MEDS dataset, with a column `code` and a collection + of code modifier columns. + code_modifiers: The names of the code modifier columns in the `code_metadata` dataset. + + Raises: + KeyError: If the `code_metadata` dataset does not contain the specified `code_modifiers` or `code` + columns. + ValueError: If the `code_metadata` dataset is not unique on the `code` and `code_modifiers` columns. + + Examples: + >>> code_metadata = pl.DataFrame({ + ... "code": pl.Series(["A", "B", "A", "A"], dtype=pl.Categorical), + ... "modifier1": ["X", "D", "Z", "Z"], + ... "modifier2": [None, None, None, 3], + ... }) + >>> validate_code_metadata(code_metadata, ["modifier1", "modifier2"]) + >>> # This returns None in the absence of an exception. + >>> code_metadata = pl.DataFrame({ + ... "code": pl.Series(["A", "B", "A", "A"], dtype=pl.Categorical), + ... "modifier1": ["X", "D", "Z", "Z"], + ... "modifier2": [None, None, None, 3], + ... }) + >>> validate_code_metadata(code_metadata, ["modifier1", "modifier2", "missing_modifier"]) + Traceback (most recent call last): + ... + KeyError: "The following columns are not present in the code metadata: 'missing_modifier'." + >>> code_metadata = pl.DataFrame({ + ... "code": pl.Series(["A", "B", "A", "A", "B", "B"], dtype=pl.Categorical), + ... "modifier1": ["X", "D", "Z", "Z", "Y", "Y"], + ... "modifier2": [None, None, None, None, 2, 1], + ... }) + >>> validate_code_metadata(code_metadata, ["modifier1", "modifier2"]) + Traceback (most recent call last): + ... 
+ ValueError: The code and code modifiers are not unique: + shape: (1, 4) + ┌──────┬───────────┬───────────┬───────┐ + │ code ┆ modifier1 ┆ modifier2 ┆ count │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ cat ┆ str ┆ i64 ┆ u32 │ + ╞══════╪═══════════╪═══════════╪═══════╡ + │ A ┆ Z ┆ null ┆ 2 │ + └──────┴───────────┴───────────┴───────┘ + """ + + cols = ["code"] + code_modifiers + + # Check that the code and code modifiers are present in the code metadata + if not set(cols).issubset(code_metadata.columns): + missing_cols = set(cols) - set(code_metadata.columns) + missing_cols_str = "', '".join(missing_cols) + raise KeyError(f"The following columns are not present in the code metadata: '{missing_cols_str}'.") + + # Check that the code and code modifiers are unique + n_unique_codes = code_metadata.n_unique(cols) + n_total_rows = len(code_metadata) + + if n_unique_codes != n_total_rows: + code_counts = code_metadata.groupby(cols).agg(pl.len().alias("count")).sort("count", descending=True) + extra_codes = code_counts.filter(pl.col("count") > 1) + raise ValueError(f"The code and code modifiers are not unique:\n{extra_codes.head(100)}") + + +def lexicographic_indices(code_metadata: pl.DataFrame, code_modifiers: list[str]) -> pl.DataFrame: + """Assign vocabulary indices to codes and code modifiers via a lexicographic order. + + Args: + code_metadata: Metadata about the codes in the MEDS dataset, with a column `code` and a collection + of code modifier columns. + code_modifiers: The names of the code modifier columns in the `code_metadata` dataset. Each of these + columns should be lexicographically orderable. + + Returns: + The code metadata dataframe with an additional column added that hasvocabulary token indices to the + code + modifier unique combinations in the `code_metadata` dataset. The expression will be aliased to + "code/vocab_index", and will be of the smallest unsigned dtype possible given the number of included + vocabulary elements. The given vocabulary indices will correspond to the following order: + - The index `0` will be assigned to a sentinel, `"UNK"` code for codes/modifiers not present in the + vocabulary. + - The remaining indices will be assigned to the unique code + modifier combinations such that + sorting the code + modifier combinations by the assigned index will result in a lexicographically + ordered list of codes + modifiers, sorting first by code and subsequently by each modifier column, + in order of specification. The sort will go from smallest lexiographic value to largest (e.g., be + an ascending sort). `null` values in the modifier columns (`null`s are disallowed in the code + columns) will be treated as the smallest possible value in the lexicographic order. + + Raises: + KeyError: If the `code_metadata` dataset does not contain the specified `code_modifiers` or `code` + columns. + ValueError: If the `code_metadata` dataset is not unique on the `code` and `code_modifiers` columns. + ValueError: If the `code` and `code_modifier` columns are not all lexicographically orderable. + + Examples: + >>> code_metadata = pl.DataFrame({ + ... "code": pl.Series(["A", "B", "A", "A", "B", "B"], dtype=pl.Categorical), + ... "modifier1": ["X", "D", None, "Z", "Y", "Y"], + ... "modifier2": [None, None, None, None, 2, 1], + ... 
}) + >>> code_modifiers = ["modifier1", "modifier2"] + >>> expr = lexicographic_indices(code_metadata, code_modifiers) + >>> code_metadata.with_columns(expr) + shape: (6, 4) + ┌──────┬───────────┬───────────┬──────────────────┐ + │ code ┆ modifier1 ┆ modifier2 ┆ code/vocab_index │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ cat ┆ str ┆ i64 ┆ u8 │ + ╞══════╪═══════════╪═══════════╪══════════════════╡ + │ A ┆ X ┆ null ┆ 2 │ + │ B ┆ D ┆ null ┆ 4 │ + │ A ┆ null ┆ null ┆ 1 │ + │ A ┆ Z ┆ null ┆ 3 │ + │ B ┆ Y ┆ 2 ┆ 6 │ + │ B ┆ Y ┆ 1 ┆ 5 │ + └──────┴───────────┴───────────┴──────────────────┘ + """ + + validate_code_metadata(code_metadata, code_modifiers) + + # We'll perform this sort in three steps. To guide this, consider as an example that our set of codes is + # ["B", "D", "A", "C"]. For this, we want to produce a set of indices that correspond to the lexicographic + # order of the codes, starting at 1; namely, [2, 4, 1, 3]. This is because "B" is the 2nd letter, "D" the + # fourth, and so on. + # 1. First, we'll use an `pl.arg_sort_by` to produce a set of indices that, were we to select the rows in + # that order, would give us the codes in lexicographic order. This is _not_ the final order -- it tells + # us what row we'd need to put in each position to _get_ the codes in sorted order. + # E.g., if our codes were ["B", "D", "A", "C"], the result of this step would be [2, 0, 3, 1], because + # we'd need to get the 2nd row to have the first lexicographically ordered code, the 0th row to have + # the second, and so on. + # 2. Second, we'll use _another_ `pl.arg_sort_by` to identify the row indices that would sort the very + # sort indices we just produced. This works because the index of the destination each row would have in + # the final sorted array is exactly the position that that row's index appears in the sort indices + # produced in step 1, by definition. And, in the second arg-sort, when we ask "which row do we need to + # grab to fill slot $j$ in this array with the sorted element that would belong at position $j$ of this + # set of numbers between $0$ and $N-1$, we are really asking which row has $j$ in it now, which is + # exactly the lexicographically ordered index. + # 3. Finally, third, we will add one (to start at one) and shrink the dtype. + # + # Note that we use this algorithm over something like just sorting the whole dataframe once then assigning + # integer indices is that this approach merely assigns indices, and does not change the order of the + # dataframe, and similarly does not require actually touching any of the memory of the dataframe. Though, + # admittedly, it is not clear how significant this choice is in practice. 
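+    # Concretely, finishing the ["B", "D", "A", "C"] example: step 1 yields [2, 0, 3, 1]; arg-sorting
+    # those indices again yields [1, 3, 0, 2]; and adding one gives [2, 4, 1, 3] -- exactly the
+    # 1-indexed lexicographic ranks we wanted for "B", "D", "A", and "C".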
+ + sort_cols = ["code"] + code_modifiers + + return code_metadata.with_columns( + (pl.arg_sort_by(pl.arg_sort_by(sort_cols, descending=False, nulls_last=False)) + 1) + .shrink_dtype() + .alias("code/vocab_index") + ) + + +VOCABULARY_ORDERING_METHODS: dict[VOCABULARY_ORDERING, INDEX_ASSIGNMENT_FN] = { + VOCABULARY_ORDERING.LEXICOGRAPHIC: lexicographic_indices, +} diff --git a/src/MEDS_polars_functions/utils.py b/src/MEDS_polars_functions/utils.py index 045e053..4f9ff7e 100644 --- a/src/MEDS_polars_functions/utils.py +++ b/src/MEDS_polars_functions/utils.py @@ -150,7 +150,7 @@ def populate_stage( "output_dir": os.path.join(cohort_dir, stage_name), } - if "is_metadata" in stage and not isinstance(stage["is_metadata"], bool): + if "is_metadata" in stage and not isinstance(stage["is_metadata"], (bool, type(None))): raise TypeError(f"If specified manually, is_metadata must be a boolean. Got {stage['is_metadata']}") out = {**stage} diff --git a/terminology.md b/terminology.md new file mode 100644 index 0000000..0fc57e7 --- /dev/null +++ b/terminology.md @@ -0,0 +1,8 @@ +# Canonical Definitions for MEDS Terminology Elements + +#### "vocabulary index" or "code index" + +The integer index (starting from 0, which will always correspond to an `"UNK"` vocabulary element) that +uniquely identifies where in the ordered list of vocabulary elements a given element is located. This will be +used as an integral or positional encoding of the vocabulary element for things like embedding matrices, +output layer logit identification, etc. From 935102e02f03a20f81ecb51738f558a68ad45b8a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 14 Jun 2024 16:58:01 -0400 Subject: [PATCH 19/53] Added (yet untested) script for fitting vocabulary indices and corrected a bug in collection of code metadata --- configs/preprocess.yaml | 8 +++ scripts/extraction/collect_code_metadata.py | 4 ++ .../preprocessing/collect_code_metadata.py | 4 ++ .../preprocessing/fit_vocabulary_indices.py | 56 +++++++++++++++++++ src/MEDS_polars_functions/normalization.py | 14 ++--- 5 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 scripts/preprocessing/fit_vocabulary_indices.py diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index d52f573..9936a5b 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -15,6 +15,7 @@ stages: - fit_outlier_detection - filter_outliers - fit_normalization + - fit_vocabulary_indices - normalize - tokenize - tensorize @@ -60,3 +61,10 @@ stage_configs: - "values/n_occurrences" - "values/sum" - "values/sum_sqd" + + fit_vocabulary_indices: + is_metadata: true + ordering_method: "lexicographic" + output_dir: "${cohort_dir}" + + normalize: diff --git a/scripts/extraction/collect_code_metadata.py b/scripts/extraction/collect_code_metadata.py index 97f1f69..9fb020c 100755 --- a/scripts/extraction/collect_code_metadata.py +++ b/scripts/extraction/collect_code_metadata.py @@ -34,6 +34,10 @@ def main(cfg: DictConfig): shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) + examine_splits = [f"{sp}/" for sp in cfg.stage_cfg.get("examine_splits", ["train"])] + logger.info(f"Computing metadata over shards with any prefix in {examine_splits}") + shards = {k: v for k, v in shards.items() if any(k.startswith(prefix) for prefix in examine_splits)} + patient_splits = list(shards.keys()) random.shuffle(patient_splits) diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 
71a6b27..514e871 100755 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -33,6 +33,10 @@ def main(cfg: DictConfig): shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + examine_splits = [f"{sp}/" for sp in cfg.stage_cfg.get("examine_splits", ["train"])] + logger.info(f"Computing metadata over shards with any prefix in {examine_splits}") + shards = {k: v for k, v in shards.items() if any(k.startswith(prefix) for prefix in examine_splits)} + patient_splits = list(shards.keys()) random.shuffle(patient_splits) diff --git a/scripts/preprocessing/fit_vocabulary_indices.py b/scripts/preprocessing/fit_vocabulary_indices.py new file mode 100644 index 0000000..aa601a9 --- /dev/null +++ b/scripts/preprocessing/fit_vocabulary_indices.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python + +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.get_vocabulary import ( + VOCABULARY_ORDERING, + VOCABULARY_ORDERING_METHODS, +) +from MEDS_polars_functions.utils import hydra_loguru_init + + +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) + + ordering_method = cfg.stage_cfg.get("ordering_method", VOCABULARY_ORDERING.LEXICOGRAPHIC) + + if ordering_method not in VOCABULARY_ORDERING_METHODS: + raise ValueError( + f"Invalid ordering method: {ordering_method}. " + f"Expected one of {', '.join(VOCABULARY_ORDERING_METHODS.keys())}" + ) + + logger.info(f"Assigning code vocabulary indices via a {ordering_method} order.") + ordering_fn = VOCABULARY_ORDERING_METHODS[ordering_method] + + code_metadata = ordering_fn(code_metadata, cfg.code_modifier_columns) + + output_fp = output_dir / "code_metadata.parquet" + logger.info(f"Indices assigned. Writing to {output_fp}") + + code_metadata.write_parquet(output_fp, use_pyarrow=True) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/normalization.py b/src/MEDS_polars_functions/normalization.py index fa03945..d1b9175 100644 --- a/src/MEDS_polars_functions/normalization.py +++ b/src/MEDS_polars_functions/normalization.py @@ -17,14 +17,14 @@ def normalize( In addition, the `code_metadata` dataset should contain information about the codes in the MEDS dataset, including: - `code` (`categorical`) - - `code/vocab_id` (`int`) + - `code/vocab_index` (`int`) - `value/mean` (`float`) - `value/std` (`float`) The `value/*` functions will be used to normalize the code numerical values to have a mean of 0 and a standard deviation of 1. The output dataframe will further be filtered to only contain rows where the `code` in the MEDS dataset appears in the `code_metadata` dataset, and the output `code` column will be - converted to the `code/vocab_id` integral ID from the `code_metadata` dataset. + converted to the `code/vocab_index` integral ID from the `code_metadata` dataset. 
This function can further be customized by specifying additional columns to join on, via the `extra_join_columns` parameter, which must appear in both the MEDS dataset and the code metadata. These @@ -68,13 +68,13 @@ def normalize( >>> code_metadata = pl.DataFrame( ... { ... "code": ["lab//A", "lab//C", "dx//B", "dx//E", "lab//F"], - ... "code/vocab_id": [0, 2, 3, 4, 5], + ... "code/vocab_index": [0, 2, 3, 4, 5], ... "value/mean": [2.0, None, None, None, 3], ... "value/std": [0.5, None, None, None, 0.2], ... }, ... schema = { ... "code": pl.Categorical(ordering='physical'), - ... "code/vocab_id": pl.UInt32, + ... "code/vocab_index": pl.UInt32, ... "value/mean": pl.Float64, ... "value/std": pl.Float64, ... }, @@ -121,14 +121,14 @@ def normalize( ... { ... "code": ["lab//A", "lab//A", "lab//C", "dx//B", "dx//E", "lab//F"], ... "unit": ["mg/dL", "g/dL", None, None, None, None], - ... "code/vocab_id": [0, 1, 2, 3, 4, 5], + ... "code/vocab_index": [0, 1, 2, 3, 4, 5], ... "value/mean": [2.0, 3.0, None, None, None, 3], ... "value/std": [0.5, 2.0, None, None, None, 0.2], ... }, ... schema = { ... "code": pl.Categorical(ordering='physical'), ... "unit": pl.Utf8, - ... "code/vocab_id": pl.UInt32, + ... "code/vocab_index": pl.UInt32, ... "value/mean": pl.Float64, ... "value/std": pl.Float64, ... }, @@ -160,6 +160,6 @@ def normalize( ).select( "patient_id", "timestamp", - pl.col("code/vocab_id").alias("code"), + pl.col("code/vocab_index").alias("code"), ((pl.col("numerical_value") - pl.col("value/mean")) / pl.col("value/std")).alias("numerical_value"), ) From 3f32271d86129182d27e92da0418790d20a4ed1d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 14 Jun 2024 17:14:36 -0400 Subject: [PATCH 20/53] documentation --- preprocessing_operation_prototypes.md | 68 +++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 preprocessing_operation_prototypes.md diff --git a/preprocessing_operation_prototypes.md b/preprocessing_operation_prototypes.md new file mode 100644 index 0000000..972f472 --- /dev/null +++ b/preprocessing_operation_prototypes.md @@ -0,0 +1,68 @@ +# MEDS Pre-processing Operation Prototypes STILL IN PROGRESS + +To support communal development and sharing of pre-processing operations, MEDS defines a set of core operation +"prototypes", which are extensible, reusable operations that can be applied to MEDS datasets in a variety of +circumstances to accomplish diverse, yet common, pre-processing tasks. The intent with these prototypes is +both that users can leverage these pre-built operations to quickly and easily accomplish common tasks, and +that they can _extend_ the set of supported operations within the broader framework using these prototypes to +support new operations to accelerate their own development and streamline the adoption of their innovations by +the broader community. + +## Core Prototypes + +### Collect & Aggregate Metadata + +#### `collect_code_metadata` + +This prototype is for summarizing MEDS data by code (and code modifier columns) and collecting aggregate +information across diverse axes over the entire dataset (or a subset of shards of the data, such as all those +shards in the train set). + +TODO: Describe the operation in more detail. + +### Filter the dataset + +#### `remove_patients` + +For removing patients who fail to meet some criteria from the dataset. + +#### `remove_measurements` + +For removing measurements that fail to meet some criteria from the dataset. + +### Uncategorized as of yet. 
+ +#### `occlude_outliers` + +For occluding (setting to `None` or `np.NaN`) features observed about events within the data. This is +typically used for removing outlier numerical values, but could in theory be used on other features as well, +though that would require additional development. + +#### `reorder_measurements` + +Some pipelines desire a specific order of measurements within the broader per-patient event order (meaning the +order as implied by unique timestamps). + +#### `extract_numeric_values` + +These prototypes are for extracting numeric values from other columns in the dataset, most notably `text` or +`categorical` value columns. + +#### `extract_categorical_values` + +These prototypes are for extracting numeric values from other columns in the dataset, most notably `text` or +`categorical` value columns. + +## Possible Future Prototypes + +### `remove_events` + +For filtering unique timestamps based on some criteria. + +### `filter_to_cohort` + +For filtering the dataset to only include data matching some cohort specification, as defined by a dataframe +of patient IDs and start/end timestamps. This is not currently occluded as it can often happen trivially +during the dataloading stage for machine learning models, but for true supervised training, it may be useful +so that train-set pre-processing parameters can be fit specific to the cohort-specific train set, rather than +the general train set. From e8a01b44744161fe6d2a2b09e09e2f678edfef89 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 13:12:52 -0400 Subject: [PATCH 21/53] Pre-processing prototypes documentation improvements --- preprocessing_operation_prototypes.md | 226 ++++++++++++++++++++++++-- 1 file changed, 210 insertions(+), 16 deletions(-) diff --git a/preprocessing_operation_prototypes.md b/preprocessing_operation_prototypes.md index 972f472..78479b5 100644 --- a/preprocessing_operation_prototypes.md +++ b/preprocessing_operation_prototypes.md @@ -8,35 +8,229 @@ that they can _extend_ the set of supported operations within the broader framew support new operations to accelerate their own development and streamline the adoption of their innovations by the broader community. +Note that, pursuant to the [core MEDS terminology](terminology.md), we will use "code" to refer to the unique +sets (with `null`s allowed) of `code` and all `code_modifier` columns. All operations should be presumed to be +potentially parametrized by the datasets list of code modifier columns. + ## Core Prototypes -### Collect & Aggregate Metadata +### Transform Codes (just codes, not patient data!) -#### `collect_code_metadata` +This operation is used to perform a static re-mapping over the allowed codes in a MEDS dataset, typically in +preparation for mapping that transformation out across the patient data by code. -This prototype is for summarizing MEDS data by code (and code modifier columns) and collecting aggregate -information across diverse axes over the entire dataset (or a subset of shards of the data, such as all those -shards in the train set). +##### Operation Steps -TODO: Describe the operation in more detail. +1. Add new information or transform existing columns in an existing `code_metadata.parquet` file. Note that + `code` or `code_modifier` columns should _not_ be modified in this step as that will break the linkage + with the patient data. -### Filter the dataset +##### Parameters -#### `remove_patients` +1. What function should be applied to each code row. 
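+
+A minimal sketch of what one such function looks like in practice (this particular transform is
+hypothetical and is not among the supported operations listed below; the signature simply mirrors the
+functions in `src/MEDS_polars_functions/get_vocabulary.py`):
+
+```python
+import polars as pl
+
+
+def tag_lab_codes(code_metadata: pl.DataFrame, code_modifiers: list[str]) -> pl.DataFrame:
+    """Hypothetical example: adds a derived column, leaving `code` and code modifier columns untouched."""
+    return code_metadata.with_columns(
+        pl.col("code").cast(pl.Utf8).str.starts_with("lab//").alias("code/is_lab")
+    )
+```
+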
-For removing patients who fail to meet some criteria from the dataset. +##### Status -#### `remove_measurements` +Individual functions are supported, but the operation as a prototypical paradigm is not yet implemented. -For removing measurements that fail to meet some criteria from the dataset. +##### Currently Supported Operations -### Uncategorized as of yet. +Functions: + +1. Assign vocabulary indices to codes. See `src/MEDS_polars_functions/get_vocabulary.py` + +### Collect metadata about code realizations in patient data + +This operation is used to produce summary information about the realizations of any unique code in the data, +such as the number of times a code occurs, the mean or variance of numerical values that are associated with a +code, etc. This operation can be applied over all data, or across patient groups or cohorts (saved into +separate files per patient group -- each output file is only grouped by code, not by patient group for +simplicity). + +##### Operation Steps + +1. Per-shard, filter the pateint data to satisfy desired set of patient or other data critieria. +2. Per-shard, group by code and collect some aggregate statistics. Optionally also compute statistics across + all codes. +3. Reduce the per-shard aggregate files into a unified `code_metadata.parquet` file. +4. Optionally merge with static per-code metadata from prior steps. + +##### Parameters + +1. What (if any) patient data filters should be applied prior to aggregation. +2. What aggregation functions should be applied to each code. Each aggregation function must specify both a + _mapper_ function that computes aggregate data on a per-shard basis and a _reducer_ function that + combines different shards together into a single, unified metadata file. +3. Whether or not aggregation functions should be computed over all raw data (the "null" code case). + +##### Status + +This operation is partially implemented as a prototype, but is not yet fully functional, as it lacks support +for patient-data filters prior to aggregation. + +##### Currently supported operations + +Patient Filters: **None** + +Functions: + +1. Various aggregation functions; see `src/MEDS_polars_functions/code_metadata.py` for a list of supported + functions. + +##### Planned Future Operations + +None at this time. To request a new operation, please open a GitHub issue. + +### Filtering the Dataset + +A note on terminology: We will use the term "removing data" to refer to operations that fully drop data from +the record, retaining no notion of the corresponding data occurring in the dataset. Operations that remove +data will result in smaller overall datasets (either in number of patients or number of measurements). We will +use the term "occluding data" to refer to operations that set data to `UNK`, `None`, or `np.NaN`, but retain +that there was _some_ data in the dataset originally. Operations that occlude data will result in the same +size dataset in terms of number of patients, number of measurements, etc., but will not have the same degree +of data granularity or information content. Occlud operations will typically *not* be reversible, but will +include a boolean indicator identifying that data was definitively occluded. + +There are a few modes of filtering data from MEDS datasets that are configured as separate prototypes. These +include: + +1. Filtering patients wholesale based on aggregate, patient-level criteria (e.g., number of events, etc.) +2. 
Filtering the data to only include patient data that matches some cohort specification (meaning removing + data that is not within pre-identified ranges of time on a per-patient basis). +3. Filtering individual measurements from the data based on some criteria (e.g., removing measurements that + have codes that are not included in the overall vocabulary, etc.). +4. Occluding features from individual measurements from the data based on some criteria (e.g., occluding + outlier numerical values or infrequent codes, etc.) + +#### Filtering Patients + +##### Operation Steps + +1. Per-shard, aggregate data per-patient and compute some aggregate criteria. +2. Remove all data corresponding to patients on the basis of the resulting criteria. +3. Return the filtered dataset, in the same format as the original, but with only the remaining patients. + +##### Parameters + +1. What aggregation functions should be applied to each patient. +2. What criteria should be used based on those aggregations to filter patients. + +These parameters may be specified with a single variable (e.g., `min_events_per_patient` indicates we need to +compute the number of unique timepoints per patient and impose a minimum threshold on that number). + +##### Status + +This operation is only implemented through two concrete functions, not a generalizable prototype in +`src/MEDS_polars_functions/filter_patients_by_length.py`. + +##### Currently supported operations + +1. Filtering patients by the number of events (unique timepoints) in their record. +2. Filtering patients by the number of measurements in their record. + +##### Planned Future Operations + +None at this time. To request a new operation, please open a GitHub issue. -#### `occlude_outliers` +#### Filtering Measurements -For occluding (setting to `None` or `np.NaN`) features observed about events within the data. This is -typically used for removing outlier numerical values, but could in theory be used on other features as well, -though that would require additional development. +This operation assumes that any requisite aggregate, per-code information is pre-computed and can be joined in +via a `code_metadata.parquet` file. + +##### Operation Steps + +1. Per-shard, join the data, if necessary, to the provided, global `code_metadata.parquet` file. +2. Apply row-based criteria to each measurement to determine if it should be retained or removed. +3. Return the filtered dataset, in the same format as the original, but with only the measurements to be + retained. + +##### Parameters + +1. What criteria should be used to filter measurements. +2. What, if any, columns in the `code_metadata.parquet` file should be joined in to the data. + +##### Status + +This operation is supported as a partial prototype, through the +`src/MEDS_polars_functions/filter_measurements.py` file. It needs extension to reach a full prototype status, +but supports such extension relatively natively. + +##### Currently supported operations + +Currently, measurements can be filtered on the basis of `min_patients_per_code` and `min_occurrences_per_code` +thresholds, which are read from the `code_metadata.parquet` file via the `code/n_patients` and +`code/n_occurrences` columns, respectively. + +##### Planned Future Operations + +None at this time. To request a new operation, please open a GitHub issue. + +#### Occluding Features within Measurements + +This operation assumes that any requisite aggregate, per-code information is pre-computed and can be joined in +via a `code_metadata.parquet` file. 
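As a rough illustration of that join-then-occlude pattern, the sketch below assumes the `code_metadata.parquet` file already carries pre-computed `values/mean` and `values/std` columns; the `occlude_outliers` helper is illustrative and is not the repository's exact implementation.

```python
# A simplified sketch of occlusion: join per-code statistics onto the patient data,
# null out numerical values that fall too far from the code's mean, and record an
# indicator column noting which values were occluded.
import polars as pl


def occlude_outliers(
    df: pl.LazyFrame, code_metadata: pl.LazyFrame, stddev_cutoff: float = 4.5
) -> pl.LazyFrame:
    """Sets numerical values beyond `stddev_cutoff` standard deviations to null."""
    stats = code_metadata.select("code", "values/mean", "values/std")
    is_inlier = (
        (pl.col("numerical_value") - pl.col("values/mean")).abs()
        <= stddev_cutoff * pl.col("values/std")
    )
    return (
        df.join(stats, on="code", how="left")
        .with_columns(
            is_inlier.alias("numerical_value/is_inlier"),
            # Values flagged as outliers become null; the boolean indicator above
            # records that a value was originally observed but occluded.
            pl.when(is_inlier).then(pl.col("numerical_value")).alias("numerical_value"),
        )
        .drop("values/mean", "values/std")
    )
```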
+ +**TODO**: Should this operation be considered a realization of the "feature transformation" prototype, instead +of a special case in this section? + +##### Operation Steps + +1. Per-shard, join the data, if necessary, to the provided, global `code_metadata.parquet` file. +2. Apply row-based criteria to each measurement to determine if individual features should be occluded or + retained in full granularity. +3. Set occluded data to the occlusion target (typically `"UNK"`, `None`, or `np.NaN`) and add an indicator + column indicating occlusion status. + +##### Parameters + +1. What criteria should be used to occlude features. + \- Relatedly, what occlusion value should be used for occluded features. + \- Relatedly, what the name of the occlusion column should be (can be set by default for features). +2. What, if any, columns in the `code_metadata.parquet` file should be joined in to the data. + +##### Status + +This operation is only supported through the single `filter_outliers_fntr` function in +`src/MEDS_polars_functions/filter_measurements.py`. It is not yet a general prototype. + +##### Currently supported operations + +1. Occluding numerical values if they take a value more distant from the code's mean by a specified number + of standard deviations. + +## Requesting New Prototypes + +To request or suggest a new prototypical paradigm, please open a GitHub issue. In that issue, please include a +description of the desired operation in the format used for the operations above, following the below +template: + +```markdown +### NAME + +Describe the operation in natural language. + +##### Operation Steps +Describe the rough API that this operation would take, as a configurable prototype. + +##### Parameters +Describe how this operation would be controlled in pipelines by the user. This will ultimately map into +configuration parameters. + +##### Status +Describe the current status of this operation. It may, generally speaking, either be fully unsupported, have +realizations of select funcctions supported, but not a general prototype, or be supported either partially or +fully as a prototype. + +##### Currently supported operations +Describe what specific realizations of this operation as a prototypes are (e.g., what options the user can +select to realize different functions within this prototype). + +##### Planned Future Operations +ADD TEXT HERE +``` + +### Uncategorized as of yet. #### `reorder_measurements` From d0a4a0c3849411dae72a8272bb80eb95700092f8 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 13:46:16 -0400 Subject: [PATCH 22/53] Further extensions to documentation. --- preprocessing_operation_prototypes.md | 96 ++++++++++++++------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/preprocessing_operation_prototypes.md b/preprocessing_operation_prototypes.md index 78479b5..71098d9 100644 --- a/preprocessing_operation_prototypes.md +++ b/preprocessing_operation_prototypes.md @@ -1,5 +1,8 @@ # MEDS Pre-processing Operation Prototypes STILL IN PROGRESS +**NOTE**: This document is currently aspirational, not yet implemented. Some functions in these patterns are +implemented, but not universally. + To support communal development and sharing of pre-processing operations, MEDS defines a set of core operation "prototypes", which are extensible, reusable operations that can be applied to MEDS datasets in a variety of circumstances to accomplish diverse, yet common, pre-processing tasks. 
The intent with these prototypes is @@ -12,7 +15,24 @@ Note that, pursuant to the [core MEDS terminology](terminology.md), we will use sets (with `null`s allowed) of `code` and all `code_modifier` columns. All operations should be presumed to be potentially parametrized by the datasets list of code modifier columns. -## Core Prototypes +A note on terminology: We will use the term "removing data" to refer to operations that fully drop data from +the record, retaining no notion of the corresponding data occurring in the dataset. Operations that remove +data will result in smaller overall datasets (either in number of patients or number of measurements). We will +use the term "occluding data" to refer to operations that set data to `UNK`, `None`, or `np.NaN`, but retain +that there was _some_ data in the dataset originally. Operations that occlude data will result in the same +size dataset in terms of number of patients, number of measurements, etc., but will not have the same degree +of data granularity or information content. Occlud operations will typically *not* be reversible, but will +include a boolean indicator identifying that data was definitively occluded. + +## Filtering Prototypes (a.k.a. Match and Revise) + +A subset of the prototypes listed below can be modified to only be applied to a subset of the data. These +subsets can be based on patient level criteria (e.g., patients who meet certain criteria) or via code filters +(e.g., to only apply a certain value extraction regex to codes that match a certain pattern), with the +results being merged into the output dataset in a consistent manner. Currently, these capabilities are only +planned, not yet implemented. + +## Prototypes ### Transform Codes (just codes, not patient data!) @@ -83,15 +103,6 @@ None at this time. To request a new operation, please open a GitHub issue. ### Filtering the Dataset -A note on terminology: We will use the term "removing data" to refer to operations that fully drop data from -the record, retaining no notion of the corresponding data occurring in the dataset. Operations that remove -data will result in smaller overall datasets (either in number of patients or number of measurements). We will -use the term "occluding data" to refer to operations that set data to `UNK`, `None`, or `np.NaN`, but retain -that there was _some_ data in the dataset originally. Operations that occlude data will result in the same -size dataset in terms of number of patients, number of measurements, etc., but will not have the same degree -of data granularity or information content. Occlud operations will typically *not* be reversible, but will -include a boolean indicator identifying that data was definitively occluded. - There are a few modes of filtering data from MEDS datasets that are configured as separate prototypes. These include: @@ -100,8 +111,6 @@ include: data that is not within pre-identified ranges of time on a per-patient basis). 3. Filtering individual measurements from the data based on some criteria (e.g., removing measurements that have codes that are not included in the overall vocabulary, etc.). -4. Occluding features from individual measurements from the data based on some criteria (e.g., occluding - outlier numerical values or infrequent codes, etc.) #### Filtering Patients @@ -166,13 +175,26 @@ thresholds, which are read from the `code_metadata.parquet` file via the `code/n None at this time. To request a new operation, please open a GitHub issue. 
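For illustration, a minimal sketch of this join-and-threshold pattern is shown below; the `filter_measurements` helper and its keyword arguments are assumptions for exposition and do not mirror the exact signature in `src/MEDS_polars_functions/filter_measurements.py`.

```python
# A minimal sketch of measurement filtering: restrict the patient data to codes whose
# pre-computed counts in `code_metadata` meet the configured thresholds.
import polars as pl


def filter_measurements(
    df: pl.LazyFrame,
    code_metadata: pl.LazyFrame,
    min_patients_per_code: int | None = None,
    min_occurrences_per_code: int | None = None,
) -> pl.LazyFrame:
    """Drops measurements whose codes fail the per-code count thresholds."""
    criteria = []
    if min_patients_per_code is not None:
        criteria.append(pl.col("code/n_patients") >= min_patients_per_code)
    if min_occurrences_per_code is not None:
        criteria.append(pl.col("code/n_occurrences") >= min_occurrences_per_code)

    if not criteria:
        return df  # No thresholds configured; return the data unchanged.

    allowed_codes = code_metadata.filter(pl.all_horizontal(criteria)).select("code")
    # An inner join keeps only measurements whose code passes every threshold.
    return df.join(allowed_codes, on="code", how="inner")
```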
+### Transforming Features within Measurements + +These prototypes or functional patterns are for transforming features within measurements. Critically, they +leave the output dataset in the same length and in the same order as the input dataset, and only transform +features. For operations that change the length or order (within the mandated `patient_id` and `timepoint` +order), see the "Transforming Measurements within Events" section. + +**TODO** Add or merge in the following: + +1. Normalizing numerical values (this is currently implemented with `normalization.py`). +2. Extract numerical values from text (e.g., extracting a number from a string). + #### Occluding Features within Measurements This operation assumes that any requisite aggregate, per-code information is pre-computed and can be joined in via a `code_metadata.parquet` file. -**TODO**: Should this operation be considered a realization of the "feature transformation" prototype, instead -of a special case in this section? +**TODO** This is not really a prototype, but is really a single function, or a subset of a prototype. IT has +functionally the same API as numerical value normalization, with the modification that the indicator columns +are added and this function is not reversible. ##### Operation Steps @@ -185,8 +207,8 @@ of a special case in this section? ##### Parameters 1. What criteria should be used to occlude features. - \- Relatedly, what occlusion value should be used for occluded features. - \- Relatedly, what the name of the occlusion column should be (can be set by default for features). + - Relatedly, what occlusion value should be used for occluded features. + - Relatedly, what the name of the occlusion column should be (can be set by default for features). 2. What, if any, columns in the `code_metadata.parquet` file should be joined in to the data. ##### Status @@ -199,6 +221,17 @@ This operation is only supported through the single `filter_outliers_fntr` funct 1. Occluding numerical values if they take a value more distant from the code's mean by a specified number of standard deviations. +### Transforming Measurements within Events + +These aren't implemented yet, but are planned: + +1. Re-order measurements within the event ordering. +2. Split measurements into multiple measurements in a particular order and via a particular functional form. + E.g., + - Performing ontology expansion + - Splitting a multi-faceted measurement (e.g., blood pressure recorded as `"120/80"`) into multiple + measurements (e.g., a systolic and diastolic blood pressure measurement with values `120` and `80`). + ## Requesting New Prototypes To request or suggest a new prototypical paradigm, please open a GitHub issue. In that issue, please include a @@ -229,34 +262,3 @@ select to realize different functions within this prototype). ##### Planned Future Operations ADD TEXT HERE ``` - -### Uncategorized as of yet. - -#### `reorder_measurements` - -Some pipelines desire a specific order of measurements within the broader per-patient event order (meaning the -order as implied by unique timestamps). - -#### `extract_numeric_values` - -These prototypes are for extracting numeric values from other columns in the dataset, most notably `text` or -`categorical` value columns. - -#### `extract_categorical_values` - -These prototypes are for extracting numeric values from other columns in the dataset, most notably `text` or -`categorical` value columns. 
- -## Possible Future Prototypes - -### `remove_events` - -For filtering unique timestamps based on some criteria. - -### `filter_to_cohort` - -For filtering the dataset to only include data matching some cohort specification, as defined by a dataframe -of patient IDs and start/end timestamps. This is not currently occluded as it can often happen trivially -during the dataloading stage for machine learning models, but for true supervised training, it may be useful -so that train-set pre-processing parameters can be fit specific to the cohort-specific train set, rather than -the general train set. From 2bb28637e14b5065a4fa9b5be74c1dc46ff7102c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 13:54:15 -0400 Subject: [PATCH 23/53] Added (untested) code metadata type shrinking --- scripts/preprocessing/collect_code_metadata.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 514e871..191894f 100755 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -8,6 +8,7 @@ import hydra import polars as pl +import polars.selectors as cs from loguru import logger from omegaconf import DictConfig, OmegaConf @@ -79,7 +80,10 @@ def main(cfg: DictConfig): logger.info("All map shards complete! Starting code metadata reduction computation.") reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) + reduced = ( + reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) + .with_columns(cs.is_numeric().shrink_dtype().keep_name()) + ) write_lazyframe(reduced, output_dir / "code_metadata.parquet") logger.info(f"Finished reduction in {datetime.now() - start}") From 117d95063c7cca1e1630e4cb3bb1b4e5c2b4b80d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 14:01:09 -0400 Subject: [PATCH 24/53] Added (yet untested) normalization script --- .../preprocessing/collect_code_metadata.py | 5 +- scripts/preprocessing/normalize.py | 64 ++++++++++++++++++ .../filter_measurements.py | 8 +-- src/MEDS_polars_functions/normalization.py | 66 +++++++++++++------ 4 files changed, 116 insertions(+), 27 deletions(-) create mode 100644 scripts/preprocessing/normalize.py diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 191894f..6791fc6 100755 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -80,9 +80,8 @@ def main(cfg: DictConfig): logger.info("All map shards complete! 
Starting code metadata reduction computation.") reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - reduced = ( - reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) - .with_columns(cs.is_numeric().shrink_dtype().keep_name()) + reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]).with_columns( + cs.is_numeric().shrink_dtype().keep_name() ) write_lazyframe(reduced, output_dir / "code_metadata.parquet") logger.info(f"Finished reduction in {datetime.now() - start}") diff --git a/scripts/preprocessing/normalize.py b/scripts/preprocessing/normalize.py new file mode 100644 index 0000000..eac01c0 --- /dev/null +++ b/scripts/preprocessing/normalize.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python + +import json +import random +from functools import partial +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.normalization import normalize +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe + + +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) + code_modifiers = cfg.get("code_modifier_columns", None) + compute_fn = partial(normalize, code_metadata=code_metadata, code_modifiers=code_modifiers) + + for sp in patient_splits: + in_fp = input_dir / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" + + logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") + + rwlock_wrap( + in_fp, + out_fp, + pl.scan_parquet, + write_lazyframe, + compute_fn, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py index 29cbee2..b3deeb2 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -198,8 +198,8 @@ def filter_outliers_fntr( stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 if "values/mean" not in code_metadata.columns: cols_to_select.append(mean_col.alias("values/mean")) - if "values/stddev" not in code_metadata.columns: - cols_to_select.append(stddev_col.alias("values/stddev")) + if "values/std" not in code_metadata.columns: + cols_to_select.append(stddev_col.alias("values/std")) code_metadata = code_metadata.lazy().select(cols_to_select) @@ -212,7 +212,7 @@ def filter_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: val = pl.col("numerical_value") mean = pl.col("values/mean") - stddev = pl.col("values/stddev") + stddev = pl.col("values/std") filter_expr = (val - mean).abs() <= stddev_cutoff * stddev return ( @@ -221,7 +221,7 @@ def 
filter_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: filter_expr.alias("numerical_value/is_inlier"), pl.when(filter_expr).then(pl.col("numerical_value")).alias("numerical_value"), ) - .drop("values/mean", "values/stddev") + .drop("values/mean", "values/std") ) return filter_outliers_fn diff --git a/src/MEDS_polars_functions/normalization.py b/src/MEDS_polars_functions/normalization.py index d1b9175..28a6ac8 100644 --- a/src/MEDS_polars_functions/normalization.py +++ b/src/MEDS_polars_functions/normalization.py @@ -4,7 +4,7 @@ def normalize( - df: pl.LazyFrame, code_metadata: pl.LazyFrame, extra_join_columns: list[str] | None = None + df: pl.LazyFrame, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> pl.LazyFrame: """Normalize a MEDS dataset across both categorical and continuous dimensions. @@ -15,26 +15,37 @@ def normalize( - `numerical_value` In addition, the `code_metadata` dataset should contain information about the codes in the MEDS dataset, - including: + including the mandatory columns: - `code` (`categorical`) - `code/vocab_index` (`int`) - - `value/mean` (`float`) - - `value/std` (`float`) + - Any `code_modifiers` columns, if specified - The `value/*` functions will be used to normalize the code numerical values to have a mean of 0 and a + Additionally, it must either have: + - Pre-computed means and standard deviations for the numerical values of the codes in the MEDS dataset, + via: + - `values/mean` (`float`) + - `values/std` (`float`) + - Or the necessary statistics to compute the per-occurrence mean and standard deviation of the numerical + values of the codes in the MEDS dataset, via: + - `values/n_occurrences` (`int`) + - `values/sum` (`float`) + - `values/sum_sqd` (`float`) + + + The `values/*` functions will be used to normalize the code numerical values to have a mean of 0 and a standard deviation of 1. The output dataframe will further be filtered to only contain rows where the `code` in the MEDS dataset appears in the `code_metadata` dataset, and the output `code` column will be converted to the `code/vocab_index` integral ID from the `code_metadata` dataset. This function can further be customized by specifying additional columns to join on, via the - `extra_join_columns` parameter, which must appear in both the MEDS dataset and the code metadata. These + `code_modifiers` parameter, which must appear in both the MEDS dataset and the code metadata. These columns will be discarded from the output dataframe, which will only contain the four expected input columns, though normalized. Args: df: The MEDS dataset to normalize. See above for the expected schema. code_metadata: Metadata about the codes in the MEDS dataset. See above for the expected schema. - extra_join_columns: Additional columns to join on, which will be discarded from the output dataframe. + code_modifiers: Additional columns to join on, which will be discarded from the output dataframe. Returns: The normalized MEDS dataset, with the schema described above. @@ -69,14 +80,14 @@ def normalize( ... { ... "code": ["lab//A", "lab//C", "dx//B", "dx//E", "lab//F"], ... "code/vocab_index": [0, 2, 3, 4, 5], - ... "value/mean": [2.0, None, None, None, 3], - ... "value/std": [0.5, None, None, None, 0.2], + ... "values/mean": [2.0, None, None, None, 3], + ... "values/std": [0.5, None, None, None, 0.2], ... }, ... schema = { ... "code": pl.Categorical(ordering='physical'), ... "code/vocab_index": pl.UInt32, - ... "value/mean": pl.Float64, - ... "value/std": pl.Float64, + ... 
"values/mean": pl.Float64, + ... "values/std": pl.Float64, ... }, ... ) >>> normalize(MEDS_df.lazy(), code_metadata.lazy()).collect() @@ -122,15 +133,15 @@ def normalize( ... "code": ["lab//A", "lab//A", "lab//C", "dx//B", "dx//E", "lab//F"], ... "unit": ["mg/dL", "g/dL", None, None, None, None], ... "code/vocab_index": [0, 1, 2, 3, 4, 5], - ... "value/mean": [2.0, 3.0, None, None, None, 3], - ... "value/std": [0.5, 2.0, None, None, None, 0.2], + ... "values/mean": [2.0, 3.0, None, None, None, 3], + ... "values/std": [0.5, 2.0, None, None, None, 0.2], ... }, ... schema = { ... "code": pl.Categorical(ordering='physical'), ... "unit": pl.Utf8, ... "code/vocab_index": pl.UInt32, - ... "value/mean": pl.Float64, - ... "value/std": pl.Float64, + ... "values/mean": pl.Float64, + ... "values/std": pl.Float64, ... }, ... ) >>> normalize(MEDS_df.lazy(), code_metadata.lazy(), ["unit"]).collect() @@ -149,17 +160,32 @@ def normalize( └────────────┴─────────────────────┴──────┴─────────────────┘ """ - if extra_join_columns is None: - extra_join_columns = [] + if code_modifiers is None: + code_modifiers = [] + + cols_to_select = ["code", "code/vocab_index"] + code_modifiers + + mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") + stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 + + if "values/mean" in code_metadata.columns: + cols_to_select.append("values/mean") + else: + cols_to_select.append(mean_col.alias("values/mean")) + + if "values/std" in code_metadata.columns: + cols_to_select.append("values/std") + else: + cols_to_select.append(stddev_col.alias("values/std")) return df.join( - code_metadata, - on=["code"] + extra_join_columns, + code_metadata.select(cols_to_select), + on=["code"] + code_modifiers, how="inner", join_nulls=True, ).select( "patient_id", "timestamp", pl.col("code/vocab_index").alias("code"), - ((pl.col("numerical_value") - pl.col("value/mean")) / pl.col("value/std")).alias("numerical_value"), + ((pl.col("numerical_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numerical_value"), ) From fae41158b7ea7049b530bfe570318b80262f11f9 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 14:24:54 -0400 Subject: [PATCH 25/53] Fit vocabulary works --- MIMIC-IV_Example/README.md | 6 ++++++ scripts/preprocessing/fit_vocabulary_indices.py | 6 +++++- scripts/preprocessing/normalize.py | 0 3 files changed, 11 insertions(+), 1 deletion(-) mode change 100644 => 100755 scripts/preprocessing/fit_vocabulary_indices.py mode change 100644 => 100755 scripts/preprocessing/normalize.py diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 6bdba6f..4689b74 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -250,6 +250,12 @@ mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps ❯ ./scripts/preprocessing/collect_code_metadata.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null stage=fit_normalization ``` +8. 
Fit vocabulary: +```bash +mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 2s +❯ ./scripts/preprocessing/fit_vocabulary_indices.py input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null +``` + ## Limitations / TO-DOs: Currently, some tables are ignored, including: diff --git a/scripts/preprocessing/fit_vocabulary_indices.py b/scripts/preprocessing/fit_vocabulary_indices.py old mode 100644 new mode 100755 index aa601a9..73c56e7 --- a/scripts/preprocessing/fit_vocabulary_indices.py +++ b/scripts/preprocessing/fit_vocabulary_indices.py @@ -42,7 +42,11 @@ def main(cfg: DictConfig): logger.info(f"Assigning code vocabulary indices via a {ordering_method} order.") ordering_fn = VOCABULARY_ORDERING_METHODS[ordering_method] - code_metadata = ordering_fn(code_metadata, cfg.code_modifier_columns) + code_modifiers = cfg.get("code_modifier_columns", None) + if code_modifiers is None: + code_modifiers = [] + + code_metadata = ordering_fn(code_metadata, code_modifiers) output_fp = output_dir / "code_metadata.parquet" logger.info(f"Indices assigned. Writing to {output_fp}") diff --git a/scripts/preprocessing/normalize.py b/scripts/preprocessing/normalize.py old mode 100644 new mode 100755 From bd035b9b395f38fdd95da1eaff46b3fd5529f3a2 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 14:30:08 -0400 Subject: [PATCH 26/53] Normalize works after some minor modification. --- MIMIC-IV_Example/README.md | 7 +++++++ configs/preprocess.yaml | 2 -- scripts/preprocessing/normalize.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 4689b74..56821de 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -256,6 +256,13 @@ mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_step ❯ ./scripts/preprocessing/fit_vocabulary_indices.py input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null ``` +9. 
Normalize: +```bash +mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s +❯ ./scripts/preprocessing/normalize.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifie +r_columns=null +``` + ## Limitations / TO-DOs: Currently, some tables are ignored, including: diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 9936a5b..3c8bb2d 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -66,5 +66,3 @@ stage_configs: is_metadata: true ordering_method: "lexicographic" output_dir: "${cohort_dir}" - - normalize: diff --git a/scripts/preprocessing/normalize.py b/scripts/preprocessing/normalize.py index eac01c0..a80522c 100755 --- a/scripts/preprocessing/normalize.py +++ b/scripts/preprocessing/normalize.py @@ -36,7 +36,7 @@ def main(cfg: DictConfig): patient_splits = list(shards.keys()) random.shuffle(patient_splits) - code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) + code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True).lazy() code_modifiers = cfg.get("code_modifier_columns", None) compute_fn = partial(normalize, code_metadata=code_metadata, code_modifiers=code_modifiers) From b48aefd9d4fba7477055a9e8c076de63ef1ab0f3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 15 Jun 2024 14:32:58 -0400 Subject: [PATCH 27/53] Fixed lint issues. --- MIMIC-IV_Example/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 56821de..0f9ffed 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -251,14 +251,16 @@ mbm47 in  compute-a-17-72 in MEDS_polars_functions on  preprocessing_steps ``` 8. Fit vocabulary: + ```bash -mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 2s +mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 2s ❯ ./scripts/preprocessing/fit_vocabulary_indices.py input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifier_columns=null ``` 9. Normalize: + ```bash -mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s +mbm47 in  compute-e-16-230 in MEDS_polars_functions on  preprocessing_steps [$] is 󰏗 v0.0.1 via  v3.12.3 via  MEDS_pipelines took 4s ❯ ./scripts/preprocessing/normalize.py --multirun worker="range(0,3)" hydra/launcher=joblib input_dir="$MIMICIV_MEDS_DIR/3workers_slurm" cohort_dir="$MIMICIV_MEDS_PROC_DIR/test" code_modifie r_columns=null ``` From 8820ce787c9dfe755cf9316dbde5d06ef5a62cb5 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 13:02:31 -0400 Subject: [PATCH 28/53] Corrected typos --- configs/pipeline.yaml | 2 +- configs/preprocess.yaml | 1 + scripts/preprocessing/collect_code_metadata.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/pipeline.yaml b/configs/pipeline.yaml index 229ea26..985866a 100644 --- a/configs/pipeline.yaml +++ b/configs/pipeline.yaml @@ -6,7 +6,7 @@ _default_description: |- This is a MEDS pipeline ETL. Please set a more detailed description at the top of your specific pipeline configuration file. 
-log_dir: "${cohort_dir}/.logs/${stage}" +log_dir: "${stage_cfg.output_dir}/.logs" # General pipeline variables do_overwrite: False diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 3c8bb2d..bf925a1 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -25,6 +25,7 @@ stage_configs: filter_patients: min_events_per_patient: null min_measurements_per_patient: null + data_input_dir: ${input_dir}/final_cohort add_time_derived_measurements: age: diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py index 6791fc6..d69bbc8 100755 --- a/scripts/preprocessing/collect_code_metadata.py +++ b/scripts/preprocessing/collect_code_metadata.py @@ -69,7 +69,7 @@ def main(cfg: DictConfig): logger.info(f"Finished mapping in {datetime.now() - start}") - if cfg.worker != 1: + if cfg.worker != 0: return while not all(fp.is_file() for fp in all_out_fps): @@ -81,7 +81,7 @@ def main(cfg: DictConfig): reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]).with_columns( - cs.is_numeric().shrink_dtype().keep_name() + cs.numeric().shrink_dtype().keep_name() ) write_lazyframe(reduced, output_dir / "code_metadata.parquet") logger.info(f"Finished reduction in {datetime.now() - start}") From 982b537396e9357523fdb9d36d278446266e8d3a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 17:33:54 -0400 Subject: [PATCH 29/53] Added some details on tokenization and tensorization. --- README.md | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/README.md b/README.md index 8258939..fbed6d3 100644 --- a/README.md +++ b/README.md @@ -327,6 +327,60 @@ running multiple copies of the same script on independent workers to process the steps again need to happen in a single-threaded manner, but these steps are generally very fast and should not be a bottleneck. +### Tokenization + +Tokenization is the process of producing dataframes that are arranged into the sequences that will eventually +be processed by deep-learning methods. Generally, these dataframes will be arranged such that each row +corresponds to a unique patient, with nested list-type columns corresponding either to _events_ (unique +timepoints), themselves with nested, list-type measurements, or to _measurements_ (unique measurements within +a timepoint) directly. Importantly, _tokenized files are generally not ideally suited to direct ingestion by +PyTorch datasets_. Instead, they should undergo a _tensorization_ process to be converted into a format that +permits fast, efficient, scalable retrieval for deep-learning training. + +### Tensorization + +Tensorization is the process of producing files of the tokenized, normalized sequences that permit efficient, +scalable deep-learning. Here, by _efficiency_, we mean that the file structure and arrangement should permit +the deep learning process to (1) begin smoothly after startup, without a long, data-ingestion phase, (2) be +organized such that individual items (e.g., in a `__getitem__` call) can be retrieved quickly in a manner that +does not inhibit rapid training, and (3) be organized such that CPU and GPU resources are used efficiently +during training. 
Similarly, by _scalability_, we mean that the three desiderata above should hold true even as +the dataset size grows much larger---while total training time can increase, time to begin training, to +process the data per-item, and CPU/GPU resources required should remain constant, or only grow negligibly, +such as the cost of maintaining a larger index of patient IDs to file offsets or paths (though disk space will +of course increase). + +Depending on one's performance needs and dataset sizes, there are 3 modes of deep learning training that can +be used that warrant different styles of tensorization: + +#### In-memory Training + +This mode of training does not scale to large datasets, and given the parallelizability of the data-loading +phase, may or may not actually be significantly faster than other modes. It is not currently supported in this +repository. **TODO** describe in more detail. + +#### Direct Retrieval + +This mode of training has the data needed for any given PyTorch Dataset `__getitem__` call retrieved from disk +on an as-needed basis. This mode is extremely scalable, because the entire dataset never need be +loaded or stored in memory in its entirety. When done properly, retrieving data from disk can be done in a +manner that is independent of the total dataset size as well, thereby rendering the load time similarly +unconstrained by total dataset size. This mode is also extremely flexible, because different cohorts can be +loaded from the same base dataset simply by changing which patients and what offsets within patient data are +read on any given cohort, all without changing the base files or underlying code. However, this mode does +require ragged dataset collation which can be more resource intensive than pre-batched iteration, so it is +slower than the "Fixed-batch retrieval" approach. This mode is what is currently supported by this repository. + +#### Fixed-batch Retrieval + +In this mode of training, batches are selected once (potentially over many epochs), the items making up those +batches are selected, then their contents are frozen and written to disk in a fully tensorized, padded format. +This enables one to merely load batched data from disk directly onto the GPU during training, which is the +fastest possible way to train a model. However, this mode is less flexible than the other modes, as the +batches are frozen during training and cannot be changed without re-tensorizing the dataset, meaning that +every new cohort for training requires a new tensorization step. This mode is not currently supported by this +repository. + ## Overview of configuration manipulation ### Pipeline configuration: Stages and OmegaConf Resolvers From 1e06a63c6857b4ab9fc0492d3e34574ab8184ae3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 19:00:02 -0400 Subject: [PATCH 30/53] Initial files -- not yet processed, tested, or verified. 
--- configs/pytorch_dataset.yaml | 9 + pyproject.toml | 2 +- src/MEDS_polars_functions/pytorch_batch.py | 752 +++++++++++++++++++ src/MEDS_polars_functions/pytorch_dataset.py | 577 ++++++++++++++ src/MEDS_polars_functions/tensorize.py | 45 ++ src/MEDS_polars_functions/tokenize.py | 128 ++++ 6 files changed, 1512 insertions(+), 1 deletion(-) create mode 100644 configs/pytorch_dataset.yaml create mode 100644 src/MEDS_polars_functions/pytorch_batch.py create mode 100644 src/MEDS_polars_functions/pytorch_dataset.py create mode 100644 src/MEDS_polars_functions/tensorize.py create mode 100644 src/MEDS_polars_functions/tokenize.py diff --git a/configs/pytorch_dataset.yaml b/configs/pytorch_dataset.yaml new file mode 100644 index 0000000..3208c93 --- /dev/null +++ b/configs/pytorch_dataset.yaml @@ -0,0 +1,9 @@ + +MEDS_cohort_dir: ??? +task_name: null + +code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet +split_shards_fp: ${MEDS_cohort_dir}/splits.json +schema_files_root: ${MEDS_cohort_dir}/tokenize/schemas +tasks_root: ${MEDS_cohort_dir}/tasks +tensorized_root: ${MEDS_cohort_dir}/tensorized diff --git a/pyproject.toml b/pyproject.toml index 25b9527..142779c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy"] +dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "ml-mixins"] [project.optional-dependencies] examples = ["rootutils"] diff --git a/src/MEDS_polars_functions/pytorch_batch.py b/src/MEDS_polars_functions/pytorch_batch.py new file mode 100644 index 0000000..45404c1 --- /dev/null +++ b/src/MEDS_polars_functions/pytorch_batch.py @@ -0,0 +1,752 @@ +"""A pytorch batch object for ease of working with tensorized data. Curently **IN PROGRESS**""" + +@dataclasses.dataclass +class PytorchBatch: + """A dataclass representing a batch of event flow data for a Pytorch model. + + This class defines the data-output interface for deep learning models built off Event Flow GPT datasets. + It stores the underlying data in the batch in a set of tensors, and also exposes some helpful methods and + properties to simplify interacting with data. + + Attributes: + event_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which events in the + batch are valid (i.e., which are not padding). + time_delta: A float tensor of shape (batch_size, sequence_length) indicating the time delta in minutes + between each event and the subsequent event in that subject's sequence in the batch. + time: A float tensor of shape (batch_size, sequence_length) indicating the time in minutes since the + start of the subject's sequence of each event in the batch. This is often left unset, as it is + generally redundant with `time_delta`. However, it is used in generation, when the batch is + truncated to use efficient caching so the raw time point can't be recovered from the time delta. + static_indices: A long tensor of shape (batch_size, n_static_data_elements) indicating the indices of + the static data elements observed for each subject in the batch. These are *unordered*; meaning + that the second dimension position of a given element in this tensor is not necessarily + meaningful. This is because the static data elements are sparsely encoded, so the indices are + sufficient to recover the original data even in an unordered form. 
Here, by "indices" we mean that + these are integer values indicating the index of the associated categorical vocabulary element + corresponding to this observation; e.g., if the static measurement records that the subject's eye + color is brown, then if the categorical measurement of ``eye_color/BROWN``` in the unified + vocabulary is at position 32, then the index for that observation would be 32. + static_measurement_indices: A long tensor of shape (batch_size, n_static_data_elements) indicating + which measurements the indices in `static_indices` correspond to. E.g., if there is a static data + element corresponding to race, then the value in `static_measurement_indices` at the associated + position would be an integer index corresponding to the race measurement overall, whereas the + index at the identical position in `static_indices` would be an integer index corresponding to the + specific race observed for the subject (e.g., "White", "Black", etc.). + dynamic_indices: A long tensor of shape (batch_size, sequence_length, n_data_elements) indicating the + indices of the dynamic data elements observed for each subject in the batch. These are + *unordered* in the last dimension, meaning that the third dimension position of a given element in + this tensor is not necessarily meaningful. This is because the dynamic data elements are sparsely + encoded, so the indices and values are sufficient to recover the original data even in an + unordered form. + dynamic_measurement_indices: A long tensor of shape (batch_size, sequence_length, n_data_elements) + indicating which measurements the indices in `dynamic_indices` correspond to, similar to the + `static_measurement_indices` attribute. + dynamic_values: A float tensor of shape (batch_size, sequence_length, n_data_elements) indicating the + numeric values associated with each dynamic data element in the `dynamic_indices` tensor. If no + value was recorded for a given dynamic data element, the value in this tensor will be zero. + dynamic_values_mask: A boolean tensor of shape (batch_size, sequence_length, n_data_elements) + indicating which values in the `dynamic_values` tensor were actually observed. + start_time: A float tensor of shape (batch_size,) indicating the start time in minutes since the epoch + of each subject's sequence in the batch. This is often unset, as it is only used in generation + when we may need to know the actual time of day of any generated event. + start_idx: A long tensor of shape (batch_size,) indicating the start index of the sampled sub-sequence + for each subject in the batch relative to their raw data. + end_idx: A long tensor of shape (batch_size,) indicating the end index of the sampled sub-sequence + for each subject in the batch relative to their raw data. + subject_id: A long tensor of shape (batch_size,) indicating the subject ID of each member of the + batch. + stream_labels: A dictionary mapping task names to label LongTensors of shape (batch_size,) providing + labels for the associated tasks for the sequences in the batch. Is only used during fine-tuning or + zero-shot evaluation runs. + """ + + event_mask: torch.BoolTensor | None = None + + # We track this instead of raw times as it is less likely to suffer from underflow errors. + time_delta: torch.FloatTensor | None = None + + # We don't often use this, but it is used in generation. 
+ time: torch.FloatTensor | None = None + + static_indices: torch.LongTensor | None = None + static_measurement_indices: torch.LongTensor | None = None + + dynamic_indices: torch.LongTensor | None = None + dynamic_measurement_indices: torch.LongTensor | None = None + dynamic_values: torch.FloatTensor | None = None + dynamic_values_mask: torch.BoolTensor | None = None + + start_time: torch.FloatTensor | None = None + start_idx: torch.LongTensor | None = None + end_idx: torch.LongTensor | None = None + subject_id: torch.LongTensor | None = None + + stream_labels: dict[str, torch.FloatTensor | torch.LongTensor] | None = None + + @staticmethod + def de_pad(L: list[int], *other_L) -> list[int] | tuple[list[int]]: + """Filters down all passed lists to only the indices where the first arg is non-zero. + + Args: + L: The list whose entries denote padding (0) or non-padding (non-zero). + *other_L: Any other lists that should be de-padded in the same way as L. + + Examples: + >>> de_pad([1, 3, 0, 4, 0, 0], [10, 0, 5, 8, 1, 0]) + ([1, 3, 4], [10, 0, 8]) + >>> de_pad([1, 3, 0, 4, 0, 0]) + [1, 3, 4] + """ + + out_L = [] + out_other = [None if x is None else [] for x in other_L] + + for i, v in enumerate(L): + if v != 0: + out_L.append(v) + for j, LL in enumerate(other_L): + if LL is not None: + out_other[j].append(LL[i]) + + if other_L: + return tuple([out_L] + out_other) + else: + return out_L + + @property + def device(self) -> torch.device: + """Returns the device storing the tensors in this batch. + + Assumes all elements of the batch are on the same device. + """ + return self.event_mask.device + + @property + def batch_size(self) -> int: + """Returns the batch size of this batch. + + Assumes the batch has not been sliced from its initial configuration. + """ + return self.event_mask.shape[0] + + @property + def sequence_length(self) -> int: + """Returns the maximum sequence length of the sequences in this batch. + + Assumes the batch has not been sliced from its initial configuration. + """ + return self.event_mask.shape[1] + + @property + def n_data_elements(self) -> int: + """Returns the maximum number of dynamic data elements of the events in this batch. + + Assumes the batch has not been sliced from its initial configuration. + """ + return self.dynamic_indices.shape[2] + + @property + def n_static_data_elements(self) -> int: + """Returns the maximum number of static data elements of the subjects in this batch. + + Assumes the batch has not been sliced from its initial configuration. + """ + return self.static_indices.shape[1] + + def get(self, item: str, default: Any) -> Any: + """A dictionary like get method for this batch, by attribute name.""" + return getattr(self, item) if item in self.keys() else default + + def _slice(self, index: tuple[int | slice] | int | slice) -> "PytorchBatch": + if not isinstance(index, tuple): + index = (index,) + if len(index) == 0 or len(index) > 3: + raise ValueError(f"Invalid index {index} for PytorchBatch! Must be of length 1, 2, or 3.") + if any(not isinstance(i, (int, slice)) for i in index): + raise ValueError(f"Invalid index {index} for PytorchBatch! 
Can only consist of ints and slices.") + + batch_index = index[0] + seq_index = slice(None) + meas_index = slice(None) + + if len(index) > 1: + seq_index = index[1] + if len(index) > 2: + meas_index = index[2] + + return PytorchBatch( + event_mask=self.event_mask[batch_index, seq_index], + time_delta=self.time_delta[batch_index, seq_index], + static_indices=None if self.static_indices is None else self.static_indices[batch_index], + static_measurement_indices=( + None + if self.static_measurement_indices is None + else self.static_measurement_indices[batch_index] + ), + dynamic_indices=self.dynamic_indices[batch_index, seq_index, meas_index], + dynamic_measurement_indices=self.dynamic_measurement_indices[batch_index, seq_index, meas_index], + dynamic_values=self.dynamic_values[batch_index, seq_index, meas_index], + dynamic_values_mask=self.dynamic_values_mask[batch_index, seq_index, meas_index], + start_time=None if self.start_time is None else self.start_time[batch_index], + start_idx=None if self.start_idx is None else self.start_idx[batch_index], + end_idx=None if self.end_idx is None else self.end_idx[batch_index], + subject_id=None if self.subject_id is None else self.subject_id[batch_index], + stream_labels=( + None + if self.stream_labels is None + else {k: v[batch_index] for k, v in self.stream_labels.items()} + ), + time=None if self.time is None else self.time[batch_index, seq_index], + ) + + def __getitem__(self, item: str | tuple[int | slice]) -> Union[torch.Tensor, "PytorchBatch"]: + match item: + case str(): + return dataclasses.asdict(self)[item] + case tuple() | int() | slice(): + return self._slice(item) + case _: + raise TypeError(f"Invalid type {type(item)} for {item} for indexing!") + + def __setitem__(self, item: str, val: torch.Tensor): + if not hasattr(self, item): + raise KeyError(f"Key {item} not found") + setattr(self, item, val) + + def __eq__(self, other: "PytorchBatch") -> bool: + """Checks for equality between self and other.""" + if self.keys() != other.keys(): + return False + + for k in self.keys(): + self_v = self[k] + other_v = other[k] + + if type(self_v) is not type(other_v): + return False + + match self_v: + case dict() if k == "stream_labels": + if self_v.keys() != other_v.keys(): + return False + for kk in self_v.keys(): + self_vv = self_v[kk] + other_vv = other_v[kk] + + if self_vv.shape != other_vv.shape: + return False + if (self_vv != other_vv).any(): + return False + + case torch.Tensor(): + if self_v.shape != other_v.shape: + return False + if (self_v != other_v).any(): + return False + case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): + if other_v is not None: + return False + case _: + raise ValueError(f"{k}: {type(self_v)} not supported in batch!") + return True + + def items(self): + """A dictionary like items` method for the elements of this batch, by attribute.""" + return dataclasses.asdict(self).items() + + def keys(self): + """A dictionary like keys method for the elements of this batch, by attribute.""" + return dataclasses.asdict(self).keys() + + def values(self): + """A dictionary like values method for the elements of this batch, by attribute.""" + return dataclasses.asdict(self).values() + + def last_sequence_element_unsqueezed(self) -> "PytorchBatch": + """Filters the batch down to just the last event, while retaining the same # of dims.""" + return self[:, -1:] + + def repeat_batch_elements(self, expand_size: int) -> "PytorchBatch": + """Repeats each batch element expand_size times in order. 
Used for generation. + + Args: + expand_size: The number of times each batch elements data should be repeated. + + Returns: A new PytorchBatch object with each batch element's data repeated expand_size times. + + Examples: + >>> import torch + >>> batch = PytorchBatch( + ... event_mask=torch.tensor([[True, True, True], [True, True, False]]), + ... time_delta=torch.tensor([[1.0, 2.0, 3.0], [1.0, 5.0, 0.0]]), + ... static_indices=torch.tensor([[0, 1], [1, 2]]), + ... static_measurement_indices=torch.tensor([[0, 1], [1, 1]]), + ... dynamic_indices=torch.tensor([[[0, 1], [1, 2], [2, 3]], [[0, 1], [1, 5], [0, 0]]]), + ... dynamic_measurement_indices=torch.tensor( + ... [[[0, 1], [1, 2], [2, 3]], [[0, 1], [1, 2], [0, 0]]] + ... ), + ... dynamic_values=torch.tensor( + ... [[[0.0, 1.0], [1.0, 2.0], [0, 0]], [[0.0, 1.0], [1.0, 0.0], [0, 0]]] + ... ), + ... dynamic_values_mask=torch.tensor([ + ... [[False, True], [True, True], [False, False]], + ... [[False, True], [True, False], [False, False]] + ... ]), + ... start_time=torch.tensor([0.0, 10.0]), + ... stream_labels={"a": torch.tensor([0, 1]), "b": torch.tensor([1, 2])}, + ... time=None, + ... ) + >>> repeated_batch = batch.repeat_batch_elements(2) + >>> for k, v in repeated_batch.items(): + ... print(k) + ... print(v) + event_mask + tensor([[ True, True, True], + [ True, True, True], + [ True, True, False], + [ True, True, False]]) + time_delta + tensor([[1., 2., 3.], + [1., 2., 3.], + [1., 5., 0.], + [1., 5., 0.]]) + time + None + static_indices + tensor([[0, 1], + [0, 1], + [1, 2], + [1, 2]]) + static_measurement_indices + tensor([[0, 1], + [0, 1], + [1, 1], + [1, 1]]) + dynamic_indices + tensor([[[0, 1], + [1, 2], + [2, 3]], + + [[0, 1], + [1, 2], + [2, 3]], + + [[0, 1], + [1, 5], + [0, 0]], + + [[0, 1], + [1, 5], + [0, 0]]]) + dynamic_measurement_indices + tensor([[[0, 1], + [1, 2], + [2, 3]], + + [[0, 1], + [1, 2], + [2, 3]], + + [[0, 1], + [1, 2], + [0, 0]], + + [[0, 1], + [1, 2], + [0, 0]]]) + dynamic_values + tensor([[[0., 1.], + [1., 2.], + [0., 0.]], + + [[0., 1.], + [1., 2.], + [0., 0.]], + + [[0., 1.], + [1., 0.], + [0., 0.]], + + [[0., 1.], + [1., 0.], + [0., 0.]]]) + dynamic_values_mask + tensor([[[False, True], + [ True, True], + [False, False]], + + [[False, True], + [ True, True], + [False, False]], + + [[False, True], + [ True, False], + [False, False]], + + [[False, True], + [ True, False], + [False, False]]]) + start_time + tensor([ 0., 0., 10., 10.]) + start_idx + None + end_idx + None + subject_id + None + stream_labels + {'a': tensor([0, 0, 1, 1]), 'b': tensor([1, 1, 2, 2])} + """ + + expanded_return_idx = ( + torch.arange(self.batch_size).view(-1, 1).repeat(1, expand_size).view(-1).to(self.device) + ) + + out_batch = {} + + for k, v in self.items(): + match v: + case dict(): + out_batch[k] = {kk: vv.index_select(0, expanded_return_idx) for kk, vv in v.items()} + case torch.Tensor(): + out_batch[k] = v.index_select(0, expanded_return_idx) + case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): + out_batch[k] = None + case _: + raise TypeError(f"{k}: {type(v)} not supported in batch for generation!") + + return PytorchBatch(**out_batch) + + def split_repeated_batch(self, n_splits: int) -> list["PytorchBatch"]: + """Split a batch into a list of batches by chunking batch elements into groups. + + This is the inverse of `PytorchBatch.repeat_batch_elements`. 
It is used for taking a generated batch + that has been expanded and splitting it into separate list elements with independent generations for + each batch element in the original batch. + + Args: + n_splits: The number of splits to make. + + Returns: A list of length `n_splits` of PytorchBatch objects, such that the list element i contains + batch elements [i, i+self.batch_size/n_splits). + + Raises: + ValueError: if `n_splits` is not a positive integer divisor of `self.batch_size`. + + Examples: + >>> import torch + >>> batch = PytorchBatch( + ... event_mask=torch.tensor([ + ... [True, True, True], + ... [True, True, False], + ... [True, False, False], + ... [False, False, False] + ... ]), + ... time_delta=torch.tensor([ + ... [1.0, 2.0, 3.0], + ... [1.0, 5.0, 0.0], + ... [2.3, 0.0, 0.0], + ... [0.0, 0.0, 0.0], + ... ]), + ... static_indices=torch.tensor([[0, 1], [1, 2], [1, 3], [0, 5]]), + ... static_measurement_indices=torch.tensor([[0, 1], [1, 1], [1, 1], [0, 2]]), + ... dynamic_indices=torch.tensor([ + ... [[0, 1], [1, 2], [2, 3]], + ... [[0, 1], [1, 5], [0, 0]], + ... [[0, 2], [0, 0], [0, 0]], + ... [[0, 0], [0, 0], [0, 0]], + ... ]), + ... dynamic_measurement_indices=torch.tensor([ + ... [[0, 1], [1, 2], [2, 3]], + ... [[0, 1], [1, 2], [0, 0]], + ... [[0, 2], [0, 0], [0, 0]], + ... [[0, 0], [0, 0], [0, 0]], + ... ]), + ... dynamic_values=torch.tensor([ + ... [[0.0, 1.0], [1.0, 2.0], [0.0, 0.0]], + ... [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], + ... [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], + ... [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ... ]), + ... dynamic_values_mask=torch.tensor([ + ... [[False, True], [True, True], [False, False]], + ... [[False, True], [True, False], [False, False]], + ... [[False, True], [False, False], [False, False]], + ... [[False, False], [False, False], [False, False]], + ... ]), + ... start_time=torch.tensor([0.0, 10.0, 3.0, 2.2]), + ... stream_labels={"a": torch.tensor([0, 1, 0, 1]), "b": torch.tensor([1, 2, 4, 3])}, + ... time=None, + ... ) + >>> batch.split_repeated_batch(3) + Traceback (most recent call last): + ... + ValueError: n_splits (3) must be a positive integer divisor of batch_size (4) + >>> for i, T in enumerate(batch.split_repeated_batch(2)): + ... print(f"Returned batch {i}:") + ... for k, v in T.items(): + ... print(k) + ... 
print(v) + Returned batch 0: + event_mask + tensor([[ True, True, True], + [ True, False, False]]) + time_delta + tensor([[1.0000, 2.0000, 3.0000], + [2.3000, 0.0000, 0.0000]]) + time + None + static_indices + tensor([[0, 1], + [1, 3]]) + static_measurement_indices + tensor([[0, 1], + [1, 1]]) + dynamic_indices + tensor([[[0, 1], + [1, 2], + [2, 3]], + + [[0, 2], + [0, 0], + [0, 0]]]) + dynamic_measurement_indices + tensor([[[0, 1], + [1, 2], + [2, 3]], + + [[0, 2], + [0, 0], + [0, 0]]]) + dynamic_values + tensor([[[0., 1.], + [1., 2.], + [0., 0.]], + + [[0., 1.], + [0., 0.], + [0., 0.]]]) + dynamic_values_mask + tensor([[[False, True], + [ True, True], + [False, False]], + + [[False, True], + [False, False], + [False, False]]]) + start_time + tensor([0., 3.]) + start_idx + None + end_idx + None + subject_id + None + stream_labels + {'a': tensor([0, 0]), 'b': tensor([1, 4])} + Returned batch 1: + event_mask + tensor([[ True, True, False], + [False, False, False]]) + time_delta + tensor([[1., 5., 0.], + [0., 0., 0.]]) + time + None + static_indices + tensor([[1, 2], + [0, 5]]) + static_measurement_indices + tensor([[1, 1], + [0, 2]]) + dynamic_indices + tensor([[[0, 1], + [1, 5], + [0, 0]], + + [[0, 0], + [0, 0], + [0, 0]]]) + dynamic_measurement_indices + tensor([[[0, 1], + [1, 2], + [0, 0]], + + [[0, 0], + [0, 0], + [0, 0]]]) + dynamic_values + tensor([[[0., 1.], + [1., 0.], + [0., 0.]], + + [[0., 0.], + [0., 0.], + [0., 0.]]]) + dynamic_values_mask + tensor([[[False, True], + [ True, False], + [False, False]], + + [[False, False], + [False, False], + [False, False]]]) + start_time + tensor([10.0000, 2.2000]) + start_idx + None + end_idx + None + subject_id + None + stream_labels + {'a': tensor([1, 1]), 'b': tensor([2, 3])} + >>> repeat_batch = batch.repeat_batch_elements(5) + >>> split_batches = repeat_batch.split_repeated_batch(5) + >>> for i, v in enumerate(split_batches): + ... assert v == batch, f"Batch {i} ({v}) not equal to original batch {batch}!" + """ + + if not isinstance(n_splits, int) or n_splits <= 0 or self.batch_size % n_splits != 0: + raise ValueError( + f"n_splits ({n_splits}) must be a positive integer divisor of batch_size ({self.batch_size})" + ) + + self.batch_size // n_splits + out_batches = [defaultdict(dict) for _ in range(n_splits)] + for k, v in self.items(): + match v: + case dict(): + for kk, vv in v.items(): + reshaped = vv.reshape(vv.shape[0] // n_splits, n_splits, *vv.shape[1:]) + for i in range(n_splits): + out_batches[i][k][kk] = reshaped[:, i, ...] + case torch.Tensor(): + reshaped = v.reshape(v.shape[0] // n_splits, n_splits, *v.shape[1:]) + for i in range(n_splits): + out_batches[i][k] = reshaped[:, i, ...] + case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): + pass + case _: + raise TypeError(f"{k}: {type(v)} not supported in batch for generation!") + + return [PytorchBatch(**B) for B in out_batches] + + def convert_to_DL_DF(self) -> pl.DataFrame: + """Converts the batch data into a sparse DataFrame representation. + + Examples: + >>> import torch + >>> batch = PytorchBatch( + ... event_mask=torch.tensor([ + ... [True, True, True], + ... [True, True, False], + ... [True, False, False], + ... [False, False, False] + ... ]), + ... time_delta=torch.tensor([ + ... [1.0, 2.0, 3.0], + ... [1.0, 5.0, 0.0], + ... [2.3, 0.0, 0.0], + ... [0.0, 0.0, 0.0], + ... ]), + ... static_indices=torch.tensor([[0, 1], [1, 2], [1, 3], [0, 5]]), + ... static_measurement_indices=torch.tensor([[0, 1], [1, 1], [1, 1], [0, 2]]), + ... 
dynamic_indices=torch.tensor([ + ... [[0, 1], [1, 2], [2, 3]], + ... [[0, 1], [1, 5], [0, 0]], + ... [[0, 2], [0, 0], [0, 0]], + ... [[0, 0], [0, 0], [0, 0]], + ... ]), + ... dynamic_measurement_indices=torch.tensor([ + ... [[0, 1], [1, 2], [2, 3]], + ... [[0, 1], [1, 2], [0, 0]], + ... [[0, 2], [0, 0], [0, 0]], + ... [[0, 0], [0, 0], [0, 0]], + ... ]), + ... dynamic_values=torch.tensor([ + ... [[0.0, 1.0], [1.0, 2.0], [0.0, 0.0]], + ... [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], + ... [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], + ... [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ... ]), + ... dynamic_values_mask=torch.tensor([ + ... [[False, True], [True, True], [False, False]], + ... [[False, True], [True, False], [False, False]], + ... [[False, True], [False, False], [False, False]], + ... [[False, False], [False, False], [False, False]], + ... ]), + ... start_time=torch.tensor([0.0, 10.0, 3.0, 2.2]), + ... stream_labels={"a": torch.tensor([0, 1, 0, 1]), "b": torch.tensor([1, 2, 4, 3])}, + ... time=None, + ... ) + >>> pl.Config.set_tbl_width_chars(80) + + >>> batch.convert_to_DL_DF() + shape: (4, 7) + ┌───────────┬───────────┬──────────┬──────────┬──────────┬──────────┬──────────┐ + │ time_delt ┆ static_in ┆ static_m ┆ dynamic_ ┆ dynamic_ ┆ dynamic_ ┆ start_ti │ + │ a ┆ dices ┆ easureme ┆ indices ┆ measurem ┆ values ┆ me │ + │ --- ┆ --- ┆ nt_indic ┆ --- ┆ ent_indi ┆ --- ┆ --- │ + │ list[f64] ┆ list[f64] ┆ es ┆ list[lis ┆ ces ┆ list[lis ┆ f64 │ + │ ┆ ┆ --- ┆ t[f64]] ┆ --- ┆ t[f64]] ┆ │ + │ ┆ ┆ list[f64 ┆ ┆ list[lis ┆ ┆ │ + │ ┆ ┆ ] ┆ ┆ t[f64]] ┆ ┆ │ + ╞═══════════╪═══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡ + │ [1.0, ┆ [1.0] ┆ [1.0] ┆ [[1.0], ┆ [[1.0], ┆ [[1.0], ┆ 0.0 │ + │ 2.0, 3.0] ┆ ┆ ┆ [1.0, ┆ [1.0, ┆ [1.0, ┆ │ + │ ┆ ┆ ┆ 2.0], ┆ 2.0], ┆ 2.0], ┆ │ + │ ┆ ┆ ┆ [2.0, ┆ [2.0, ┆ [null, ┆ │ + │ ┆ ┆ ┆ 3.0]… ┆ 3.0]… ┆ nul… ┆ │ + │ [1.0, ┆ [1.0, ┆ [1.0, ┆ [[1.0], ┆ [[1.0], ┆ [[1.0], ┆ 10.0 │ + │ 5.0] ┆ 2.0] ┆ 1.0] ┆ [1.0, ┆ [1.0, ┆ [1.0, ┆ │ + │ ┆ ┆ ┆ 5.0]] ┆ 2.0]] ┆ null]] ┆ │ + │ [2.3] ┆ [1.0, ┆ [1.0, ┆ [[2.0]] ┆ [[2.0]] ┆ [[1.0]] ┆ 3.0 │ + │ ┆ 3.0] ┆ 1.0] ┆ ┆ ┆ ┆ │ + │ [] ┆ [5.0] ┆ [2.0] ┆ [] ┆ [] ┆ [] ┆ 2.2 │ + └───────────┴───────────┴──────────┴──────────┴──────────┴──────────┴──────────┘ + """ + + df = { + k: [] + for k, v in self.items() + if k not in ("stream_labels", "event_mask", "dynamic_values_mask") and v is not None + } + + for k in ("start_time", "subject_id", "start_idx", "end_idx"): + if self[k] is not None: + df[k] = list(self[k]) + + for i in range(self.batch_size): + idx, measurement_idx = self.de_pad(self.static_indices[i], self.static_measurement_indices[i]) + df["static_indices"].append(idx) + df["static_measurement_indices"].append(measurement_idx) + + _, time_delta, time, idx, measurement_idx, vals, vals_mask = self.de_pad( + self.event_mask[i], + None if self.time_delta is None else self.time_delta[i], + None if self.time is None else self.time[i], + self.dynamic_indices[i], + self.dynamic_measurement_indices[i], + self.dynamic_values[i], + self.dynamic_values_mask[i], + ) + + if time_delta is not None: + df["time_delta"].append(time_delta) + if time is not None: + df["time"].append(time) + + names = ("dynamic_indices", "dynamic_measurement_indices", "dynamic_values") + for n in names: + df[n].append([]) + + for j in range(len(idx)): + de_padded_vals = self.de_pad(idx[j], measurement_idx[j], vals[j], vals_mask[j]) + # Now we add the indices and measurement indices + for n, v in zip(names[:-1], de_padded_vals[:-2]): + df[n][i].append(v) + + 
df["dynamic_values"][i].append([None if not m else v for v, m in zip(*de_padded_vals[-2:])]) + + return pl.DataFrame(df) diff --git a/src/MEDS_polars_functions/pytorch_dataset.py b/src/MEDS_polars_functions/pytorch_dataset.py new file mode 100644 index 0000000..60424f8 --- /dev/null +++ b/src/MEDS_polars_functions/pytorch_dataset.py @@ -0,0 +1,577 @@ + +import json +from collections import defaultdict +from pathlib import Path + +import numpy as np +import polars as pl +import torch +from loguru import logger +from mixins import SeedableMixin +from nested_ragged_tensors.ragged_numpy import ( + NP_FLOAT_TYPES, + NP_INT_TYPES, + NP_UINT_TYPES, + JointNestedRaggedTensorDict, +) +from tqdm.auto import tqdm + +from ..utils import count_or_proportion +from .config import PytorchDatasetConfig, SeqPaddingSide, SubsequenceSamplingStrategy + + +import dataclasses +import enum +from collections import defaultdict +from typing import Any, Union + +import polars as pl +import torch + +from omegaconf import DictConfig + + +def to_int_index(col: pl.Expr) -> pl.Expr: + """Returns an integer index of the unique elements seen in this column. + + The returned index is into a vocabulary sorted lexographically. + + Args: + col: The column containing the data to be converted into integer indices. + + Examples: + >>> import polars as pl + >>> X = pl.DataFrame({ + ... 'c': ['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba'], + ... 'd': [1, 2, 3, 4, 5, 6, 7, 8] + ... }) + >>> X.with_columns(to_int_index(pl.col('c'))) + shape: (8, 2) + ┌──────┬─────┐ + │ c ┆ d │ + │ --- ┆ --- │ + │ u32 ┆ i64 │ + ╞══════╪═════╡ + │ 4 ┆ 1 │ + │ 1 ┆ 2 │ + │ 4 ┆ 3 │ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ null ┆ 6 │ + │ 1 ┆ 7 │ + │ 0 ┆ 8 │ + └──────┴─────┘ + """ + + indices = col.unique(maintain_order=True).drop_nulls().search_sorted(col) + return pl.when(col.is_null()).then(pl.lit(None)).otherwise(indices).alias(col.meta.output_name()) + + +class PytorchDataset(SeedableMixin, torch.utils.data.Dataset): + """A PyTorch Dataset class. + + Args: + config: Configuration options for the dataset, in an `omegaconf.DictConfig` object. + split: The split of data which should be used in this dataset (e.g., ``'train'``, ``'tuning'``, + ``'held_out'``). This will dictate where the system looks for files. + """ + + TYPE_CHECKERS = { + "multi_class_classification": [ + ( + {pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Int8, pl.Int16, pl.Int32, pl.Int64}, + None, + ), + ({pl.Categorical(ordering="physical"), pl.Categorical(ordering="lexical")}, to_int_index), + ({pl.Utf8}, to_int_index), + ], + "binary_classification": [({pl.Boolean}, lambda Y: Y.cast(pl.Float32))], + "regression": [({pl.Float32, pl.Float64}, None)], + } + """Type checker and conversion parameters for labeled datasets.""" + + @classmethod + def normalize_task(cls, col: pl.Expr, dtype: pl.DataType) -> tuple[str, pl.Expr]: + """Normalizes the task labels in `col` of dtype `dtype` to a common format. + + Args: + col: The column containing the task labels, in polars expression format. + dtype: The polars data type of the task labels. + + Returns: + The task type (a string key into the `TYPE_CHECKERS` dictionary) and the normalized column + expression. + + Raises: + TypeError: If the task labels are not of a supported type. 
+ """ + for task_type, checkers in cls.TYPE_CHECKERS.items(): + for valid_dtypes, normalize_fn in checkers: + if dtype in valid_dtypes: + return task_type, (col if normalize_fn is None else normalize_fn(col)) + + raise TypeError(f"Can't process label of {dtype} type!") + + def __init__(self, cfg: DictConfig, split: str): + super().__init__() + + self.config = cfg + self.split = split + + logger.info("Scanning code metadata") + self.code_metadata = pl.scan_parquet(self.config.code_metadata_fp) + + logger.info("Reading splits & patient shards") + self.read_shards() + + logger.info("Reading patient descriptors") + self.read_patient_descriptors() + + if self.config.min_seq_len is not None and self.config.min_seq_len > 1: + logger.info(f"Restricting to subjects with at least {config.min_seq_len} events") + self.filter_to_min_seq_len() + + if self.config.train_subset_size not in (None, "FULL") and self.split == "train": + logger.info(f"Filtering training subset size to {self.config.train_subset_size}") + self.filter_to_subset() + + self.set_inter_event_time_stats() + + def read_shards(self): + """Reads the split-specific patient shards from the ESGPT or MEDS dataset.""" + shards_fp = self.config.split_shards_fp + all_shards = json.loads(shards_fp.read_text()) + self.shards = {sp: subjs for sp, subjs in all_shards.items() if sp.startswith(f"{self.split}/")} + self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs} + + @property + def measurement_configs(self): + """Grabs the measurement configs from the config.""" + return self.config.measurement_configs + + def read_patient_descriptors(self): + """Reads the patient descriptors from the ESGPT or MEDS dataset.""" + self.static_dfs = {} + self.subj_indices = {} + self.subj_seq_bounds = {} + + shards = tqdm(self.shards.keys(), total=len(self.shards), desc="Reading static shards", leave=False) + for shard in shards: + static_fp = self.config.schema_files_root / f"{shard}.parquet" + df = pl.read_parquet( + static_fp, + columns=[ + "patient_id", + "start_time", + pl.col("code").alias("static_indices"), + pl.col("numerical_value").alias("static_values"), + "timestamp", + "patient_offset", + ], + use_pyarrow=True, + ) + + self.static_dfs[shard] = df + patient_ids = df["patient_id"] + n_events = df.select(pl.col("timestamp").list.lengths().alias("n_events")).get_column("n_events") + for i, (subj, n_events) in enumerate(zip(patient_ids, n_events)): + if subj in self.subj_indices or subj in self.subj_seq_bounds: + raise ValueError(f"Duplicate subject {subj} in {shard}!") + + self.subj_indices[subj] = i + self.subj_seq_bounds[subj] = (0, n_events) + + if self.has_task: + self.index = [(subj, *bounds) for subj, bounds in self.subj_seq_bounds.items()] + self.labels = {} + self.tasks = None + self.task_types = None + self.task_vocabs = None + else: + task_df_fp = self.config.tasks_root / f"{self.config.task_name}.parquet" + task_info_fp = self.config.tasks_root / f"{self.config.task_name}_info.json" + + logger.info(f"Reading task constraints for {self.config.task_name} from {task_df_fp}") + task_df = pl.read_parquet(task_df_fp, use_pyarrow=True) + + task_info = self.get_task_info(task_df) + + if task_info_fp.is_file(): + loaded_task_info = json.loads(task_info_fp.read_text()) + if loaded_task_info != task_info: + raise ValueError( + f"Task info differs from on disk!\nDisk:\n{loaded_task_info}\n" + f"Local:\n{task_info}\nSplit: {self.split}" + ) + logger.info(f"Re-built existing {task_info_fp} and it matches.") + else: + 
task_info_fp.parent.mkdir(exist_ok=True, parents=True) + task_info_fp.write_text(json.dumps(task_info)) + + idx_col = "_row_index" + while idx_col in task_df.columns: + idx_col = f"_{idx_col}" + + raise NotImplementedError("Need to figure out task constraints still" + + task_df_joint = ( + task_df.select("patient_id", "start_time", "end_time") + .with_row_index(idx_col) + .group_by("patient_id") + .agg("start_time", "end_time", idx_col) + .join( + pl.concat(self.static_dfs.values()).select( + "patient_id", pl.col("start_time").alias("start_time_global"), "time_delta" + ), + on="patient_id", + how="left", + ) + .with_columns( + pl.col("timestamp").alias("min_since_start") + ) + ) + + min_at_task_start = ( + (pl.col("start_time") - pl.col("start_time_global")).dt.total_seconds() / 60 + ).alias("min_at_task_start") + min_at_task_end = ( + (pl.col("end_time") - pl.col("start_time_global")).dt.total_seconds() / 60 + ).alias("min_at_task_end") + + start_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_start"))).alias( + "start_idx" + ) + end_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_end"))).alias( + "end_idx" + ) + + task_df_joint = ( + task_df_joint.explode(idx_col, "start_time", "end_time") + .with_columns(min_at_task_start, min_at_task_end) + .explode("min_since_start") + .group_by("patient_id", idx_col, "min_at_task_start", "min_at_task_end", maintain_order=True) + .agg(start_idx_expr.first(), end_idx_expr.first()) + .sort(by=idx_col, descending=False) + ) + + patient_ids = task_df_joint["patient_id"] + start_indices = task_df_joint["start_idx"] + end_indices = task_df_joint["end_idx"] + + self.labels = {t: task_df.get_column(t).to_list() for t in self.tasks} + self.index = list(zip(patient_ids, start_indices, end_indices)) + + def get_task_info(self, task_df: pl.DataFrame): + """Gets the task information from the task dataframe.""" + self.tasks = sorted([c for c in task_df.columns if c not in ["patient_id", "start_time", "end_time"]]) + + self.task_types = {} + self.task_vocabs = {} + + normalized_cols = [] + for t in self.tasks: + task_type, normalized_vals = self.normalize_task(col=pl.col(t), dtype=task_df.schema[t]) + self.task_types[t] = task_type + normalized_cols.append(normalized_vals.alias(t)) + + task_df = task_df.with_columns(normalized_cols) + + for t in self.tasks: + match self.task_types[t]: + case "binary_classification": + self.task_vocabs[t] = [False, True] + case "multi_class_classification": + self.task_vocabs[t] = list(range(task_df.select(pl.col(t).max()).item() + 1)) + case _: + raise NotImplementedError(f"Task type {self.task_types[t]} not implemented!") + + return {"tasks": sorted(self.tasks), "vocabs": self.task_vocabs, "types": self.task_types} + + def filter_to_min_seq_len(self): + """Filters the dataset to only include subjects with at least `config.min_seq_len` events.""" + if self.has_task: + logger.warning( + f"Filtering task {self.config.task_name} to min_seq_len {self.config.min_seq_len}. " + "This may result in incomparable model results against runs with different constraints!" 
+ ) + + orig_len = len(self) + orig_n_subjects = len(set(self.patient_ids)) + valid_indices = [ + i for i, (subj, start, end) in enumerate(self.index) if end - start >= self.config.min_seq_len + ] + self.index = [self.index[i] for i in valid_indices] + self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} + new_len = len(self) + new_n_subjects = len(set(self.patient_ids)) + logger.info( + f"Filtered data due to sequence length constraint (>= {self.config.min_seq_len}) from " + f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." + ) + + def filter_to_subset(self): + """Filters the dataset to only include a subset of subjects.""" + + orig_len = len(self) + orig_n_subjects = len(set(self.patient_ids)) + rng = np.random.default_rng(self.config.train_subset_seed) + subset_subjects = rng.choice( + list(set(self.patient_ids)), + size=count_or_proportion(orig_n_subjects, self.config.train_subset_size), + replace=False, + ) + valid_indices = [i for i, (subj, start, end) in enumerate(self.index) if subj in subset_subjects] + self.index = [self.index[i] for i in valid_indices] + self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} + new_len = len(self) + new_n_subjects = len(set(self.patient_ids)) + logger.info( + f"Filtered data to subset of {self.config.train_subset_size} subjects from " + f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." + ) + + def set_inter_event_time_stats(self): + """Sets the inter-event time statistics for the dataset.""" + data_for_stats = pl.concat([x.lazy() for x in self.static_dfs.values()]) + stats = ( + data_for_stats.select( + pl.col("time_delta").explode().drop_nulls().drop_nans().alias("inter_event_time") + ) + .select( + pl.col("inter_event_time").min().alias("min"), + pl.col("inter_event_time").log().mean().alias("mean_log"), + pl.col("inter_event_time").log().std().alias("std_log"), + ) + .collect() + ) + + if stats["min"].item() <= 0: + bad_inter_event_times = data_for_stats.filter(pl.col("time_delta").list.min() <= 0).collect() + bad_patient_ids = set(bad_inter_event_times["patient_id"].to_list()) + warning_strs = [ + f"Observed inter-event times <= 0 for {len(bad_inter_event_times)} subjects!", + f"Bad Subject IDs: {', '.join(str(x) for x in bad_patient_ids)}", + f"Global min: {stats['min'].item()}", + ] + if self.config.save_dir is not None: + fp = self.config.save_dir / f"malformed_data_{self.split}.parquet" + bad_inter_event_times.write_parquet(fp) + warning_strs.append(f"Wrote malformed data records to {fp}") + warning_strs.append("Removing malformed subjects") + + logger.warning("\n".join(warning_strs)) + + self.index = [x for x in self.index if x[0] not in bad_patient_ids] + + self.mean_log_inter_event_time_min = stats["mean_log"].item() + self.std_log_inter_event_time_min = stats["std_log"].item() + + @property + def patient_ids(self) -> list[int]: + return [x[0] for x in self.index] + + def __len__(self): + return len(self.index) + + @property + def has_task(self) -> bool: + return self.config.task_name is not None + + @property + def seq_padding_side(self) -> SeqPaddingSide: + return self.config.seq_padding_side + + @property + def max_seq_len(self) -> int: + return self.config.max_seq_len + + @property + def is_subset_dataset(self) -> bool: + return self.config.train_subset_size != "FULL" + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + """Returns a Returns a dictionary corresponding to a 
single subject's data. + + The output of this will not be tensorized as that work will need to be re-done in the collate function + regardless. The output will have structure: + `` + { + 'time_delta': [seq_len], + 'dynamic_indices': [seq_len, n_data_per_event] (ragged), + 'dynamic_values': [seq_len, n_data_per_event] (ragged), + 'dynamic_measurement_indices': [seq_len, n_data_per_event] (ragged), + 'static_indices': [seq_len, n_data_per_event] (ragged), + 'static_measurement_indices': [seq_len, n_data_per_event] (ragged), + } + `` + + 1. ``time_delta`` captures the time between each event and the subsequent event. + 2. ``dynamic_indices`` captures the categorical metadata elements listed in `self.data_cols` in a + unified vocabulary space spanning all metadata vocabularies. + 3. ``dynamic_values`` captures the numerical metadata elements listed in `self.data_cols`. If no + numerical elements are listed in `self.data_cols` for a given categorical column, the according + index in this output will be `np.NaN`. + 4. ``dynamic_measurement_indices`` captures which measurement vocabulary was used to source a given + data element. + 5. ``static_indices`` captures the categorical metadata elements listed in `self.static_cols` in a + unified vocabulary. + 6. ``static_measurement_indices`` captures which measurement vocabulary was used to source a given + data element. + """ + return self._seeded_getitem(idx) + + @SeedableMixin.WithSeed + def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: + """Returns a Returns a dictionary corresponding to a single subject's data. + + This function is a seedable version of `__getitem__`. + """ + + patient_id, st, end = self.index[idx] + + shard = self.subj_map[patient_id] + patient_idx = self.subj_indices[patient_id] + static_row = self.static_dfs[shard][patient_idx].to_dict() + + out = { + "static_indices": static_row["static_indices"].item().to_list(), + "static_measurement_indices": static_row["static_measurement_indices"].item().to_list(), + } + + if self.config.do_include_patient_id: + out["patient_id"] = patient_id + + seq_len = end - st + if seq_len > self.max_seq_len: + match self.config.subsequence_sampling_strategy: + case SubsequenceSamplingStrategy.RANDOM: + start_offset = np.random.choice(seq_len - self.max_seq_len) + case SubsequenceSamplingStrategy.TO_END: + start_offset = seq_len - self.max_seq_len + case SubsequenceSamplingStrategy.FROM_START: + start_offset = 0 + case _: + raise ValueError( + f"Invalid subsequence sampling strategy {self.config.subsequence_sampling_strategy}!" 
+ ) + + st += start_offset + end = min(end, st + self.max_seq_len) + + if self.config.do_include_subsequence_indices: + out["start_idx"] = st + out["end_idx"] = end + + out["dynamic"] = ( + JointNestedRaggedTensorDict.load_slice( + self.config.tensorized_root / f"{shard}.pt", patient_idx + )[st:end] + ) + + if self.config.do_include_start_time_min: + out["start_time"] = static_row["start_time"] = static_row[ + "start_time" + ].item().timestamp() / 60.0 + sum(static_row["time_delta"].item().to_list()[:st]) + + for t, t_labels in self.labels.items(): + out[t] = t_labels[idx] + + return out + + def __dynamic_only_collate(self, batch: list[dict[str, list[float]]]) -> dict: + """An internal collate function for only dynamic data.""" + keys = batch[0].keys() + dense_keys = {k for k in keys if k not in ("dynamic", "static_indices", "static_measurement_indices")} + + if dense_keys: + dense_collated = torch.utils.data.default_collate([{k: x[k] for k in dense_keys} for x in batch]) + else: + dense_collated = {} + + dynamic = JointNestedRaggedTensorDict.vstack([x["dynamic"] for x in batch]).to_dense( + padding_side=self.seq_padding_side + ) + dynamic["event_mask"] = dynamic.pop("dim1/mask") + dynamic["dynamic_values_mask"] = dynamic.pop("dim2/mask") & ~np.isnan(dynamic["dynamic_values"]) + + dynamic_collated = {} + for k, v in dynamic.items(): + if k.endswith("mask"): + dynamic_collated[k] = torch.from_numpy(v) + elif v.dtype in NP_UINT_TYPES + NP_INT_TYPES: + dynamic_collated[k] = torch.from_numpy(v.astype(int)).long() + elif v.dtype in NP_FLOAT_TYPES: + dynamic_collated[k] = torch.from_numpy(v.astype(float)).float() + else: + raise TypeError(f"Don't know how to tensorify {k} of type {v.dtype}!") + + collated = {**dense_collated, **dynamic_collated} + + out_batch = {} + out_batch["event_mask"] = collated["event_mask"] + out_batch["dynamic_values_mask"] = collated["dynamic_values_mask"] + out_batch["time_delta"] = torch.nan_to_num(collated["time_delta"].float(), nan=0) + out_batch["dynamic_indices"] = collated["dynamic_indices"].long() + out_batch["dynamic_measurement_indices"] = collated["dynamic_measurement_indices"].long() + out_batch["dynamic_values"] = torch.nan_to_num(collated["dynamic_values"].float(), nan=0) + + if self.config.do_include_start_time_min: + out_batch["start_time"] = collated["start_time"].float() + if self.config.do_include_subsequence_indices: + out_batch["start_idx"] = collated["start_idx"].long() + out_batch["end_idx"] = collated["end_idx"].long() + if self.config.do_include_patient_id: + out_batch["patient_id"] = collated["patient_id"].long() + + if not self.has_task: + return out_batch + + out_labels = {} + for task in self.tasks: + match self.task_types[task]: + case "multi_class_classification": + out_labels[task] = collated[task].long() + case "binary_classification": + out_labels[task] = collated[task].float() + case "regression": + out_labels[task] = collated[task].float() + case _: + raise TypeError(f"Don't know how to tensorify task of type {self.task_types[task]}!") + out_batch["supervised_labels"] = out_labels + + return out_batch + + def collate(self, batch: list[dict]) -> dict: + """Combines the ragged dictionaries produced by `__getitem__` into a tensorized batch. + + This function handles conversion of arrays to tensors and padding of elements within the batch across + static data elements, sequence events, and dynamic data elements. + + Args: + batch: A list of `__getitem__` format output dictionaries. 
+ + Returns: + A fully collated, tensorized, and padded batch. + """ + + out_batch = self.__dynamic_only_collate(batch) + + max_n_static = max(len(x["static_indices"]) for x in batch) + static_padded_fields = defaultdict(list) + for e in batch: + n_static = len(e["static_indices"]) + static_delta = max_n_static - n_static + for k in ("static_indices", "static_measurement_indices"): + if static_delta > 0: + static_padded_fields[k].append( + torch.nn.functional.pad( + torch.tensor(e[k], dtype=torch.long), (0, static_delta), value=0 + ) + ) + else: + static_padded_fields[k].append(torch.tensor(e[k], dtype=torch.long)) + + for k, v in static_padded_fields.items(): + out_batch[k] = torch.cat([T.unsqueeze(0) for T in v], dim=0) + + return out_batch diff --git a/src/MEDS_polars_functions/tensorize.py b/src/MEDS_polars_functions/tensorize.py new file mode 100644 index 0000000..a1b8961 --- /dev/null +++ b/src/MEDS_polars_functions/tensorize.py @@ -0,0 +1,45 @@ +"""Functions for tensorizing MEDS datasets. + +TODO +""" + +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict +import polars as pl + +def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: + """This converts a tokenized dataframe into a nested ragged tensor. + + Most of the work for this function is actually done in `tokenize` -- this function is just a wrapper + to convert the output into a nested ragged tensor using polars' built-in `to_dict` method. + + Args: + tokenized_df: The tokenized dataframe. + + Returns: + A `JointNestedRaggedTensorDict` object representing the tokenized dataframe, accounting for however + many levels of ragged nesting are present among the codes and numerical values. + + Raises: + ValueError: If there are no time delta columns or if there are multiple time delta columns. + + Examples: + >>> raise NotImplementedError + """ + + # There should only be one time delta column, but this ensures we catch it regardless of the unit of time + # used to convert the time deltas, and that we verify there is only one such column. + time_delta_cols = [c for c in tokenized_df.columns if c.startswith("time_delta/")] + + if len(time_delta_cols) == 0: + raise ValueError("Expected at least one time delta column, found none") + elif len(time_delta_cols) > 1: + raise ValueError(f"Expected exactly one time delta column, found columns: {time_delta_cols}") + + time_delta_col = time_delta_cols[0] + + return JointNestedRaggedTensorDict( + tokenized_df.select( + time_delta_col, + "code", + "numerical_value").collect().to_dict(as_series=False) + ) diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py new file mode 100644 index 0000000..17bfab7 --- /dev/null +++ b/src/MEDS_polars_functions/tokenize.py @@ -0,0 +1,128 @@ +"""Functions for tokenizing MEDS datasets. + +Here, _tokenization_ refers specifically to the process of converting a longitudinal, irregularly sampled, +continuous time sequence into a temporal sequence at the level that will be consumed by deep-learning models. + +All these functions take in _normalized_ data -- meaning data where there are _no longer_ any code modifiers, +as those have been normalized alongside codes into integer indices (in the output code column). The only +columns of concern here thus are `patient_id`, `timestamp`, `code`, `numerical_value`. 
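+
+As a rough, illustrative sketch (the helper functions defined below are the source of truth), a normalized
+shard is intended to flow through this module as:
+
+    schema_df = extract_statics_and_schema(df)       # per-patient static codes/values + unique timestamps
+    event_seqs = extract_seq_of_patient_events(df)   # per-patient ragged time_delta / code / value lists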
+""" + +import polars as pl + +SECONDS_PER_MINUTE = 60.0 +SECONDS_PER_HOUR = SECONDS_PER_MINUTE * 60.0 +SECONDS_PER_DAY = SECONDS_PER_HOUR * 24.0 + +def fill_to_nans(col: str | pl.Expr) -> pl.Expr: + """This function fills infinite and null values with NaN. + + This enables the downstream functions to naturally tensorize data into numpy or Torch tensors. + + Args: + col: The input column. + + Returns: + A `pl.Expr` object that fills infinite and null values with NaN. + + Examples: + >>> raise NotImplementedError + """ + + if isinstance(col, str): + col = pl.col(col) + + return ( + pl.when(col.is_infinite() | col.is_null()) + .then(float('nan')) + .otherwise(col) + .keep_name() + ) + +def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFrame]: + """This function splits the input data into static and dynamic data. + + Static data is data that has a null timestamp, and dynamic data is everything else. + + Args: + df: The input data. + + Returns: + A tuple of two `pl.LazyFrame` objects, the first being the static data and the second being the + dynamic data. + + Examples: + >>> raise NotImplementedError + """ + + static = df.filter(pl.col("timestamp").is_null()) + dynamic = df.filter(pl.col("timestamp").is_not_null()) + return static, dynamic + +def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: + """This function extracts static data and schema information (sequence of patient unique timestamps). + + Args: + df: The input data. + + Returns: + A `pl.LazyFrame` object containing the static data and the unique timestamps of the patient, grouped + by patient as lists, in the same order as the patient IDs occurred in the original file. + """ + + static, dynamic = split_static_and_dynamic(df) + + # This collects static data by patient ID and stores only (as a list) the codes and numerical values. + static_by_patient = static.group_by("patient_id", maintain_order=True).agg("code", "numerical_value") + + # This collects the unique timestamps for each patient. + schema_by_patient = ( + dynamic + .group_by("patient_id", maintain_order=True) + .agg( + pl.col("timestamp").min().alias("start_time"), + pl.col("timestamp").unique(maintain_order=True) + ) + ) + + return ( + static_by_patient + .join(schema_by_patient, on="patient_id", how="inner") + .with_row_index("patient_offset") + ) + + +def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: + """This function extracts sequences of patient events, which are sequences of measurements. + + The result of this can be naturally tensorized into a `JointNestedRaggedTensorDict` object. + + Args: + df: The input data. + + Returns: + A `pl.LazyFrame` object containing the sequences of patient events, with the following columns: + - `patient_id`: The patient ID. + - `time_delta/days`: The time delta in days, as a list of floats (ragged). + - `code`: The code, as a list of lists of ints (ragged in both levels). + - `numerical_value`: The numerical value as a list of lists of floats (ragged in both levels). 
+ + Examples: + >>> raise NotImplementedError + """ + + _, dynamic = split_static_and_dynamic(df) + + time_delta_days_expr = (pl.col("timestamp").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float64) + + return ( + dynamic + .group_by("patient_id", "timestamp", maintain_order=True) + .agg(fill_to_nans("code"), fill_to_nans("numerical_value")) + .group_by("patient_id", maintain_order=True) + .agg( + fill_to_nans(time_delta_days_expr).alias("time_delta/days"), + "code", + "numerical_value", + ) + ) From 188936b051fb8d89d2558d5efec4963b184d1269 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 22:53:19 -0400 Subject: [PATCH 31/53] Fixed lint errors. --- configs/pytorch_dataset.yaml | 14 +- src/MEDS_polars_functions/pytorch_batch.py | 25 +++- src/MEDS_polars_functions/pytorch_dataset.py | 141 ++++++++++++++----- src/MEDS_polars_functions/tensorize.py | 8 +- src/MEDS_polars_functions/tokenize.py | 28 ++-- 5 files changed, 146 insertions(+), 70 deletions(-) diff --git a/configs/pytorch_dataset.yaml b/configs/pytorch_dataset.yaml index 3208c93..da76807 100644 --- a/configs/pytorch_dataset.yaml +++ b/configs/pytorch_dataset.yaml @@ -1,7 +1,19 @@ - MEDS_cohort_dir: ??? task_name: null +subsequence_sampling_strategy: random +seq_padding_side: right + +train_subset_size: null +train_subset_seed: 1 + +min_seq_len: null +max_seq_len: ??? + +do_include_patient_id: false +do_include_subsequence_indices: false +do_include_start_time_min: false + code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet split_shards_fp: ${MEDS_cohort_dir}/splits.json schema_files_root: ${MEDS_cohort_dir}/tokenize/schemas diff --git a/src/MEDS_polars_functions/pytorch_batch.py b/src/MEDS_polars_functions/pytorch_batch.py index 45404c1..727ea8d 100644 --- a/src/MEDS_polars_functions/pytorch_batch.py +++ b/src/MEDS_polars_functions/pytorch_batch.py @@ -1,4 +1,15 @@ -"""A pytorch batch object for ease of working with tensorized data. Curently **IN PROGRESS**""" +"""A pytorch batch object for ease of working with tensorized data. + +Currently **IN PROGRESS** +""" + +import dataclasses +from collections import defaultdict +from typing import Any + +import polars as pl +import torch + @dataclasses.dataclass class PytorchBatch: @@ -86,28 +97,28 @@ class PytorchBatch: @staticmethod def de_pad(L: list[int], *other_L) -> list[int] | tuple[list[int]]: """Filters down all passed lists to only the indices where the first arg is non-zero. - + Args: L: The list whose entries denote padding (0) or non-padding (non-zero). *other_L: Any other lists that should be de-padded in the same way as L. 
- + Examples: >>> de_pad([1, 3, 0, 4, 0, 0], [10, 0, 5, 8, 1, 0]) ([1, 3, 4], [10, 0, 8]) >>> de_pad([1, 3, 0, 4, 0, 0]) [1, 3, 4] """ - + out_L = [] out_other = [None if x is None else [] for x in other_L] - + for i, v in enumerate(L): if v != 0: out_L.append(v) for j, LL in enumerate(other_L): if LL is not None: out_other[j].append(LL[i]) - + if other_L: return tuple([out_L] + out_other) else: @@ -199,7 +210,7 @@ def _slice(self, index: tuple[int | slice] | int | slice) -> "PytorchBatch": time=None if self.time is None else self.time[batch_index, seq_index], ) - def __getitem__(self, item: str | tuple[int | slice]) -> Union[torch.Tensor, "PytorchBatch"]: + def __getitem__(self, item: str | tuple[int | slice]) -> torch.Tensor | "PytorchBatch": match item: case str(): return dataclasses.asdict(self)[item] diff --git a/src/MEDS_polars_functions/pytorch_dataset.py b/src/MEDS_polars_functions/pytorch_dataset.py index 60424f8..c730e26 100644 --- a/src/MEDS_polars_functions/pytorch_dataset.py +++ b/src/MEDS_polars_functions/pytorch_dataset.py @@ -1,7 +1,6 @@ - import json from collections import defaultdict -from pathlib import Path +from enum import StrEnum import numpy as np import polars as pl @@ -14,21 +13,92 @@ NP_UINT_TYPES, JointNestedRaggedTensorDict, ) +from omegaconf import DictConfig from tqdm.auto import tqdm -from ..utils import count_or_proportion -from .config import PytorchDatasetConfig, SeqPaddingSide, SubsequenceSamplingStrategy +PROPORTION = float +COUNT_OR_PROPORTION = int | PROPORTION +WHOLE = int | pl.Expr -import dataclasses -import enum -from collections import defaultdict -from typing import Any, Union +def count_or_proportion(N: WHOLE | None, cnt_or_prop: COUNT_OR_PROPORTION) -> int: + """Returns `cnt_or_prop` if it is an integer or `int(N*cnt_or_prop)` if it is a float. -import polars as pl -import torch + Resolves cutoff variables that can either be passed as integer counts or fractions of a whole. E.g., the + vocabulary should contain only elements that occur with count or proportion at least X, where X might be + 20 times, or 1%. -from omegaconf import DictConfig + Arguments: + N: The total number of elements in the whole. Only used if `cnt_or_prop` is a proportion (float). + cnt_or_prop: The cutoff value, either as an integer count or a proportion of the whole. + + Returns: + The cutoff value as an integer count of the whole. + + Raises: + TypeError: If `cnt_or_prop` is not an integer or a float or if `N` is needed and is not an integer or + a polars Expression. + ValueError: If `cnt_or_prop` is not a positive integer or a float between 0 and 1. + + Examples: + >>> count_or_proportion(100, 0.1) + 10 + >>> count_or_proportion(None, 11) + 11 + >>> count_or_proportion(100, 0.116) + 12 + >>> count_or_proportion(None, 0) + Traceback (most recent call last): + ... + ValueError: 0 must be positive if it is an integer + >>> count_or_proportion(None, 1.3) + Traceback (most recent call last): + ... + ValueError: 1.3 must be between 0 and 1 if it is a float + >>> count_or_proportion(None, "a") + Traceback (most recent call last): + ... + TypeError: a must be a positive integer or a float between 0 or 1 + >>> count_or_proportion("a", 0.2) + Traceback (most recent call last): + ... + TypeError: a must be an integer or a polars.Expr when cnt_or_prop is a float! 
+ """ + + match cnt_or_prop: + case int() if 0 < cnt_or_prop: + return cnt_or_prop + case int(): + raise ValueError(f"{cnt_or_prop} must be positive if it is an integer") + case float() if 0 < cnt_or_prop < 1: + pass + case float(): + raise ValueError(f"{cnt_or_prop} must be between 0 and 1 if it is a float") + case _: + raise TypeError(f"{cnt_or_prop} must be a positive integer or a float between 0 or 1") + + match N: + case int(): + return int(round(cnt_or_prop * N)) + case pl.Expr(): + return (N * cnt_or_prop).round(0).cast(int) + case _: + raise TypeError(f"{N} must be an integer or a polars.Expr when cnt_or_prop is a float!") + + +class SubsequenceSamplingStrategy(StrEnum): + """An enumeration of the possible subsequence sampling strategies for the dataset.""" + + RANDOM = "random" + TO_END = "to_end" + FROM_START = "from_start" + + +class SeqPaddingSide(StrEnum): + """An enumeration of the possible sequence padding sides for the dataset.""" + + LEFT = "left" + RIGHT = "right" def to_int_index(col: pl.Expr) -> pl.Expr: @@ -115,6 +185,18 @@ def normalize_task(cls, col: pl.Expr, dtype: pl.DataType) -> tuple[str, pl.Expr] def __init__(self, cfg: DictConfig, split: str): super().__init__() + if cfg.subsequence_sampling_strategy not in SubsequenceSamplingStrategy: + raise ValueError( + f"Invalid subsequence sampling strategy {cfg.subsequence_sampling_strategy}! " + f"Valid options are {', '.join(SubsequenceSamplingStrategy.__members__)}" + ) + + if cfg.seq_padding_side not in SeqPaddingSide: + raise ValueError( + f"Invalid sequence padding side {cfg.seq_padding_side}! " + f"Valid options are {', '.join(SeqPaddingSide.__members__)}" + ) + self.config = cfg self.split = split @@ -128,7 +210,7 @@ def __init__(self, cfg: DictConfig, split: str): self.read_patient_descriptors() if self.config.min_seq_len is not None and self.config.min_seq_len > 1: - logger.info(f"Restricting to subjects with at least {config.min_seq_len} events") + logger.info(f"Restricting to subjects with at least {self.config.min_seq_len} events") self.filter_to_min_seq_len() if self.config.train_subset_size not in (None, "FULL") and self.split == "train": @@ -144,11 +226,6 @@ def read_shards(self): self.shards = {sp: subjs for sp, subjs in all_shards.items() if sp.startswith(f"{self.split}/")} self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs} - @property - def measurement_configs(self): - """Grabs the measurement configs from the config.""" - return self.config.measurement_configs - def read_patient_descriptors(self): """Reads the patient descriptors from the ESGPT or MEDS dataset.""" self.static_dfs = {} @@ -212,7 +289,7 @@ def read_patient_descriptors(self): while idx_col in task_df.columns: idx_col = f"_{idx_col}" - raise NotImplementedError("Need to figure out task constraints still" + raise NotImplementedError("Need to figure out task constraints still") task_df_joint = ( task_df.select("patient_id", "start_time", "end_time") @@ -226,9 +303,7 @@ def read_patient_descriptors(self): on="patient_id", how="left", ) - .with_columns( - pl.col("timestamp").alias("min_since_start") - ) + .with_columns(pl.col("timestamp").alias("min_since_start")) ) min_at_task_start = ( @@ -288,7 +363,7 @@ def get_task_info(self, task_df: pl.DataFrame): return {"tasks": sorted(self.tasks), "vocabs": self.task_vocabs, "types": self.task_types} def filter_to_min_seq_len(self): - """Filters the dataset to only include subjects with at least `config.min_seq_len` events.""" + """Filters the dataset to only include 
subjects with at least `self.config.min_seq_len` events.""" if self.has_task: logger.warning( f"Filtering task {self.config.task_name} to min_seq_len {self.config.min_seq_len}. " @@ -353,8 +428,8 @@ def set_inter_event_time_stats(self): f"Bad Subject IDs: {', '.join(str(x) for x in bad_patient_ids)}", f"Global min: {stats['min'].item()}", ] - if self.config.save_dir is not None: - fp = self.config.save_dir / f"malformed_data_{self.split}.parquet" + if self.config.MEDS_cohort_dir is not None: + fp = self.config.MEDS_cohort_dir / f"malformed_data_{self.split}.parquet" bad_inter_event_times.write_parquet(fp) warning_strs.append(f"Wrote malformed data records to {fp}") warning_strs.append("Removing malformed subjects") @@ -377,18 +452,10 @@ def __len__(self): def has_task(self) -> bool: return self.config.task_name is not None - @property - def seq_padding_side(self) -> SeqPaddingSide: - return self.config.seq_padding_side - @property def max_seq_len(self) -> int: return self.config.max_seq_len - @property - def is_subset_dataset(self) -> bool: - return self.config.train_subset_size != "FULL" - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Returns a Returns a dictionary corresponding to a single subject's data. @@ -462,11 +529,9 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: out["start_idx"] = st out["end_idx"] = end - out["dynamic"] = ( - JointNestedRaggedTensorDict.load_slice( - self.config.tensorized_root / f"{shard}.pt", patient_idx - )[st:end] - ) + out["dynamic"] = JointNestedRaggedTensorDict.load_slice( + self.config.tensorized_root / f"{shard}.pt", patient_idx + )[st:end] if self.config.do_include_start_time_min: out["start_time"] = static_row["start_time"] = static_row[ @@ -489,7 +554,7 @@ def __dynamic_only_collate(self, batch: list[dict[str, list[float]]]) -> dict: dense_collated = {} dynamic = JointNestedRaggedTensorDict.vstack([x["dynamic"] for x in batch]).to_dense( - padding_side=self.seq_padding_side + padding_side=self.config.seq_padding_side ) dynamic["event_mask"] = dynamic.pop("dim1/mask") dynamic["dynamic_values_mask"] = dynamic.pop("dim2/mask") & ~np.isnan(dynamic["dynamic_values"]) diff --git a/src/MEDS_polars_functions/tensorize.py b/src/MEDS_polars_functions/tensorize.py index a1b8961..f658924 100644 --- a/src/MEDS_polars_functions/tensorize.py +++ b/src/MEDS_polars_functions/tensorize.py @@ -3,8 +3,9 @@ TODO """ -from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict import polars as pl +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict + def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: """This converts a tokenized dataframe into a nested ragged tensor. @@ -38,8 +39,5 @@ def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: time_delta_col = time_delta_cols[0] return JointNestedRaggedTensorDict( - tokenized_df.select( - time_delta_col, - "code", - "numerical_value").collect().to_dict(as_series=False) + tokenized_df.select(time_delta_col, "code", "numerical_value").collect().to_dict(as_series=False) ) diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py index 17bfab7..1ab40e3 100644 --- a/src/MEDS_polars_functions/tokenize.py +++ b/src/MEDS_polars_functions/tokenize.py @@ -14,6 +14,7 @@ SECONDS_PER_HOUR = SECONDS_PER_MINUTE * 60.0 SECONDS_PER_DAY = SECONDS_PER_HOUR * 24.0 + def fill_to_nans(col: str | pl.Expr) -> pl.Expr: """This function fills infinite and null values with NaN. 
@@ -32,12 +33,8 @@ def fill_to_nans(col: str | pl.Expr) -> pl.Expr: if isinstance(col, str): col = pl.col(col) - return ( - pl.when(col.is_infinite() | col.is_null()) - .then(float('nan')) - .otherwise(col) - .keep_name() - ) + return pl.when(col.is_infinite() | col.is_null()).then(float("nan")).otherwise(col).keep_name() + def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFrame]: """This function splits the input data into static and dynamic data. @@ -59,6 +56,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra dynamic = df.filter(pl.col("timestamp").is_not_null()) return static, dynamic + def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: """This function extracts static data and schema information (sequence of patient unique timestamps). @@ -76,19 +74,12 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: static_by_patient = static.group_by("patient_id", maintain_order=True).agg("code", "numerical_value") # This collects the unique timestamps for each patient. - schema_by_patient = ( - dynamic - .group_by("patient_id", maintain_order=True) - .agg( - pl.col("timestamp").min().alias("start_time"), - pl.col("timestamp").unique(maintain_order=True) - ) + schema_by_patient = dynamic.group_by("patient_id", maintain_order=True).agg( + pl.col("timestamp").min().alias("start_time"), pl.col("timestamp").unique(maintain_order=True) ) - return ( - static_by_patient - .join(schema_by_patient, on="patient_id", how="inner") - .with_row_index("patient_offset") + return static_by_patient.join(schema_by_patient, on="patient_id", how="inner").with_row_index( + "patient_offset" ) @@ -116,8 +107,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: time_delta_days_expr = (pl.col("timestamp").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float64) return ( - dynamic - .group_by("patient_id", "timestamp", maintain_order=True) + dynamic.group_by("patient_id", "timestamp", maintain_order=True) .agg(fill_to_nans("code"), fill_to_nans("numerical_value")) .group_by("patient_id", maintain_order=True) .agg( From 1286864128fdd30059a4e0f08a6983f4b4324045 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 22:55:32 -0400 Subject: [PATCH 32/53] Moved pytorch dataset code. --- .../pytorch_batch.py | 0 .../pytorch_dataset.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename src/{MEDS_polars_functions => MEDS_pytorch_dataset}/pytorch_batch.py (100%) rename src/{MEDS_polars_functions => MEDS_pytorch_dataset}/pytorch_dataset.py (100%) diff --git a/src/MEDS_polars_functions/pytorch_batch.py b/src/MEDS_pytorch_dataset/pytorch_batch.py similarity index 100% rename from src/MEDS_polars_functions/pytorch_batch.py rename to src/MEDS_pytorch_dataset/pytorch_batch.py diff --git a/src/MEDS_polars_functions/pytorch_dataset.py b/src/MEDS_pytorch_dataset/pytorch_dataset.py similarity index 100% rename from src/MEDS_polars_functions/pytorch_dataset.py rename to src/MEDS_pytorch_dataset/pytorch_dataset.py From a4741feaa1960762d3d92944713e55abf8d3d461 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 23:01:24 -0400 Subject: [PATCH 33/53] Added first test to tokenize. 
--- src/MEDS_polars_functions/tokenize.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py index 1ab40e3..9829fe7 100644 --- a/src/MEDS_polars_functions/tokenize.py +++ b/src/MEDS_polars_functions/tokenize.py @@ -27,7 +27,15 @@ def fill_to_nans(col: str | pl.Expr) -> pl.Expr: A `pl.Expr` object that fills infinite and null values with NaN. Examples: - >>> raise NotImplementedError + >>> print(fill_to_nans("value")) # doctest: +NORMALIZE_WHITESPACE + .when([(col("value").is_infinite()) | + (col("value").is_null())]).then(dyn float: NaN).otherwise(col("value")).name.keep() + >>> print(fill_to_nans(pl.col("time_delta"))) # doctest: +NORMALIZE_WHITESPACE + .when([(col("time_delta").is_infinite()) | + (col("time_delta").is_null())]).then(dyn float: NaN).otherwise(col("time_delta")).name.keep() + >>> df = pl.DataFrame({"value": [1.0, float("inf"), None, -float("inf"), 2.0]}) + >>> df.select(fill_to_nans("value"))["value"].to_list() + [1.0, nan, nan, nan, 2.0] """ if isinstance(col, str): From c67b50c683fdeec3551981e8386d67348dca9562 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 23:04:22 -0400 Subject: [PATCH 34/53] Added a second test. --- src/MEDS_polars_functions/tokenize.py | 31 +++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py index 9829fe7..eaa3dd9 100644 --- a/src/MEDS_polars_functions/tokenize.py +++ b/src/MEDS_polars_functions/tokenize.py @@ -57,10 +57,37 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra dynamic data. Examples: - >>> raise NotImplementedError + >>> from datetime import datetime + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "timestamp": [None, datetime(2021, 1, 1), None, datetime(2021, 1, 2)], + ... "code": [100, 101, 200, 201], + ... "numerical_value": [1.0, 2.0, 3.0, 4.0] + ... }).lazy() + >>> static, dynamic = split_static_and_dynamic(df) + >>> static.collect() + shape: (2, 3) + ┌────────────┬──────┬─────────────────┐ + │ patient_id ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ f64 │ + ╞════════════╪══════╪═════════════════╡ + │ 1 ┆ 100 ┆ 1.0 │ + │ 2 ┆ 200 ┆ 3.0 │ + └────────────┴──────┴─────────────────┘ + >>> dynamic.collect() + shape: (2, 4) + ┌────────────┬─────────────────────┬──────┬─────────────────┐ + │ patient_id ┆ timestamp ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ datetime[μs] ┆ i64 ┆ f64 │ + ╞════════════╪═════════════════════╪══════╪═════════════════╡ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 101 ┆ 2.0 │ + │ 2 ┆ 2021-01-02 00:00:00 ┆ 201 ┆ 4.0 │ + └────────────┴─────────────────────┴──────┴─────────────────┘ """ - static = df.filter(pl.col("timestamp").is_null()) + static = df.filter(pl.col("timestamp").is_null()).drop("timestamp") dynamic = df.filter(pl.col("timestamp").is_not_null()) return static, dynamic From 501adbaa76ece4953d4ab39525ef587127e67bef Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 23:25:26 -0400 Subject: [PATCH 35/53] Added the last of the tests. 
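With these, each of the tokenization helpers carries a doctest. Taken together with the tensorization
helpers, the intended per-shard flow is roughly the following (an illustrative sketch only -- the shard
paths and the use of `write_parquet` here are assumptions, not part of this commit):

    import polars as pl

    from MEDS_polars_functions.tensorize import convert_to_NRT
    from MEDS_polars_functions.tokenize import extract_seq_of_patient_events, extract_statics_and_schema

    # Hypothetical path to one normalized shard of the MEDS cohort.
    df = pl.scan_parquet("normalized/train/0.parquet")

    # Per-patient static codes/values plus the unique-timestamp schema, consumed by the PyTorch dataset.
    extract_statics_and_schema(df).collect().write_parquet("tokenize/schemas/train/0.parquet")

    # Per-patient ragged event sequences, converted into a nested ragged tensor; the resulting
    # `JointNestedRaggedTensorDict` is what gets serialized for `load_slice` at training time.
    event_seqs = extract_seq_of_patient_events(df)
    nrt = convert_to_NRT(event_seqs)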
--- src/MEDS_polars_functions/tokenize.py | 69 +++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py index eaa3dd9..2c9be6a 100644 --- a/src/MEDS_polars_functions/tokenize.py +++ b/src/MEDS_polars_functions/tokenize.py @@ -29,19 +29,19 @@ def fill_to_nans(col: str | pl.Expr) -> pl.Expr: Examples: >>> print(fill_to_nans("value")) # doctest: +NORMALIZE_WHITESPACE .when([(col("value").is_infinite()) | - (col("value").is_null())]).then(dyn float: NaN).otherwise(col("value")).name.keep() + (col("value").is_null())]).then(dyn float: NaN).otherwise(col("value")) >>> print(fill_to_nans(pl.col("time_delta"))) # doctest: +NORMALIZE_WHITESPACE .when([(col("time_delta").is_infinite()) | - (col("time_delta").is_null())]).then(dyn float: NaN).otherwise(col("time_delta")).name.keep() + (col("time_delta").is_null())]).then(dyn float: NaN).otherwise(col("time_delta")) >>> df = pl.DataFrame({"value": [1.0, float("inf"), None, -float("inf"), 2.0]}) - >>> df.select(fill_to_nans("value"))["value"].to_list() + >>> df.select(fill_to_nans("value").alias("value"))["value"].to_list() [1.0, nan, nan, nan, 2.0] """ if isinstance(col, str): col = pl.col(col) - return pl.when(col.is_infinite() | col.is_null()).then(float("nan")).otherwise(col).keep_name() + return pl.when(col.is_infinite() | col.is_null()).then(float("nan")).otherwise(col) def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFrame]: @@ -101,6 +101,39 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: Returns: A `pl.LazyFrame` object containing the static data and the unique timestamps of the patient, grouped by patient as lists, in the same order as the patient IDs occurred in the original file. + + Examples: + >>> from datetime import datetime + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "timestamp": [ + ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), + ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], + ... "code": [100, 101, 102, 103, 200, 201, 202], + ... "numerical_value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] + ... }).lazy() + >>> df = extract_statics_and_schema(df).collect() + >>> df.drop("timestamp") + shape: (2, 4) + ┌────────────┬───────────┬─────────────────┬─────────────────────┐ + │ patient_id ┆ code ┆ numerical_value ┆ start_time │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ list[i64] ┆ list[f64] ┆ datetime[μs] │ + ╞════════════╪═══════════╪═════════════════╪═════════════════════╡ + │ 1 ┆ [100] ┆ [1.0] ┆ 2021-01-01 00:00:00 │ + │ 2 ┆ [200] ┆ [5.0] ┆ 2021-01-02 00:00:00 │ + └────────────┴───────────┴─────────────────┴─────────────────────┘ + >>> df.select("patient_id", "timestamp").explode("timestamp") + shape: (3, 2) + ┌────────────┬─────────────────────┐ + │ patient_id ┆ timestamp │ + │ --- ┆ --- │ + │ i64 ┆ datetime[μs] │ + ╞════════════╪═════════════════════╡ + │ 1 ┆ 2021-01-01 00:00:00 │ + │ 1 ┆ 2021-01-13 00:00:00 │ + │ 2 ┆ 2021-01-02 00:00:00 │ + └────────────┴─────────────────────┘ """ static, dynamic = split_static_and_dynamic(df) @@ -113,9 +146,9 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: pl.col("timestamp").min().alias("start_time"), pl.col("timestamp").unique(maintain_order=True) ) - return static_by_patient.join(schema_by_patient, on="patient_id", how="inner").with_row_index( - "patient_offset" - ) + # TODO(mmd): Consider tracking patient offset explicitly here. 
+ + return static_by_patient.join(schema_by_patient, on="patient_id", how="inner") def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: @@ -134,7 +167,25 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: - `numerical_value`: The numerical value as a list of lists of floats (ragged in both levels). Examples: - >>> raise NotImplementedError + >>> from datetime import datetime + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "timestamp": [ + ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), + ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], + ... "code": [100, 101, 102, 103, 200, 201, 202], + ... "numerical_value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] + ... }).lazy() + >>> extract_seq_of_patient_events(df).collect() + shape: (2, 4) + ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ + │ patient_id ┆ time_delta/days ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ + ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ + │ 1 ┆ [NaN, 12.0] ┆ [[101.0, 102.0], [103.0]] ┆ [[2.0, 3.0], [4.0]] │ + │ 2 ┆ [NaN] ┆ [[201.0, 202.0]] ┆ [[6.0, 7.0]] │ + └────────────┴─────────────────┴───────────────────────────┴─────────────────────┘ """ _, dynamic = split_static_and_dynamic(df) @@ -143,7 +194,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: return ( dynamic.group_by("patient_id", "timestamp", maintain_order=True) - .agg(fill_to_nans("code"), fill_to_nans("numerical_value")) + .agg(fill_to_nans("code").keep_name(), fill_to_nans("numerical_value").keep_name()) .group_by("patient_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta/days"), From 833adbf6f8245745a6ae96a68522ecbb2db1a069 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 16 Jun 2024 23:32:27 -0400 Subject: [PATCH 36/53] Tested tensorize as well --- src/MEDS_polars_functions/tensorize.py | 47 ++++++++++++++++++++++++-- src/MEDS_polars_functions/tokenize.py | 6 ++-- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/MEDS_polars_functions/tensorize.py b/src/MEDS_polars_functions/tensorize.py index f658924..c0f3304 100644 --- a/src/MEDS_polars_functions/tensorize.py +++ b/src/MEDS_polars_functions/tensorize.py @@ -24,12 +24,55 @@ def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: ValueError: If there are no time delta columns or if there are multiple time delta columns. Examples: - >>> raise NotImplementedError + >>> df = pl.DataFrame({ + ... "patient_id": [1, 2], + ... "time_delta_days": [[float("nan"), 12.0], [float("nan")]], + ... "code": [[[101.0, 102.0], [103.0]], [[201.0, 202.0]]], + ... "numerical_value": [[[2.0, 3.0], [4.0]], [[6.0, 7.0]]] + ... }) + >>> df + shape: (2, 4) + ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ + │ patient_id ┆ time_delta_days ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ + ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ + │ 1 ┆ [NaN, 12.0] ┆ [[101.0, 102.0], [103.0]] ┆ [[2.0, 3.0], [4.0]] │ + │ 2 ┆ [NaN] ┆ [[201.0, 202.0]] ┆ [[6.0, 7.0]] │ + └────────────┴─────────────────┴───────────────────────────┴─────────────────────┘ + >>> nrt = convert_to_NRT(df.lazy()) + >>> for k, v in nrt.to_dense().items(): + ... print(k) + ... 
print(v) + dim1/mask + [[ True True] + [ True False]] + time_delta_days + [[nan 12.] + [nan 0.]] + dim2/mask + [[[ True True] + [ True False]] + + [[ True True] + [False False]]] + numerical_value + [[[2. 3.] + [4. 0.]] + + [[6. 7.] + [0. 0.]]] + code + [[[101. 102.] + [103. 0.]] + + [[201. 202.] + [ 0. 0.]]] """ # There should only be one time delta column, but this ensures we catch it regardless of the unit of time # used to convert the time deltas, and that we verify there is only one such column. - time_delta_cols = [c for c in tokenized_df.columns if c.startswith("time_delta/")] + time_delta_cols = [c for c in tokenized_df.columns if c.startswith("time_delta_")] if len(time_delta_cols) == 0: raise ValueError("Expected at least one time delta column, found none") diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/tokenize.py index 2c9be6a..7233ede 100644 --- a/src/MEDS_polars_functions/tokenize.py +++ b/src/MEDS_polars_functions/tokenize.py @@ -162,7 +162,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: Returns: A `pl.LazyFrame` object containing the sequences of patient events, with the following columns: - `patient_id`: The patient ID. - - `time_delta/days`: The time delta in days, as a list of floats (ragged). + - `time_delta_days`: The time delta in days, as a list of floats (ragged). - `code`: The code, as a list of lists of ints (ragged in both levels). - `numerical_value`: The numerical value as a list of lists of floats (ragged in both levels). @@ -179,7 +179,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: >>> extract_seq_of_patient_events(df).collect() shape: (2, 4) ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta/days ┆ code ┆ numerical_value │ + │ patient_id ┆ time_delta_days ┆ code ┆ numerical_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ @@ -197,7 +197,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: .agg(fill_to_nans("code").keep_name(), fill_to_nans("numerical_value").keep_name()) .group_by("patient_id", maintain_order=True) .agg( - fill_to_nans(time_delta_days_expr).alias("time_delta/days"), + fill_to_nans(time_delta_days_expr).alias("time_delta_days"), "code", "numerical_value", ) From 0509faafa71692eaa5db3089d1298ce1934e34de Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 09:09:54 -0400 Subject: [PATCH 37/53] Some minor corrections for the pytorch dataset. --- src/MEDS_pytorch_dataset/pytorch_dataset.py | 42 +++++++-------------- 1 file changed, 13 insertions(+), 29 deletions(-) diff --git a/src/MEDS_pytorch_dataset/pytorch_dataset.py b/src/MEDS_pytorch_dataset/pytorch_dataset.py index c730e26..12652d6 100644 --- a/src/MEDS_pytorch_dataset/pytorch_dataset.py +++ b/src/MEDS_pytorch_dataset/pytorch_dataset.py @@ -16,12 +16,8 @@ from omegaconf import DictConfig from tqdm.auto import tqdm -PROPORTION = float -COUNT_OR_PROPORTION = int | PROPORTION -WHOLE = int | pl.Expr - -def count_or_proportion(N: WHOLE | None, cnt_or_prop: COUNT_OR_PROPORTION) -> int: +def count_or_proportion(N: int | pl.Expr | None, cnt_or_prop: int | float) -> int: """Returns `cnt_or_prop` if it is an integer or `int(N*cnt_or_prop)` if it is a float. Resolves cutoff variables that can either be passed as integer counts or fractions of a whole. 
E.g., the @@ -227,7 +223,7 @@ def read_shards(self): self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs} def read_patient_descriptors(self): - """Reads the patient descriptors from the ESGPT or MEDS dataset.""" + """Reads the patient schemas and static data.""" self.static_dfs = {} self.subj_indices = {} self.subj_seq_bounds = {} @@ -259,12 +255,6 @@ def read_patient_descriptors(self): self.subj_seq_bounds[subj] = (0, n_events) if self.has_task: - self.index = [(subj, *bounds) for subj, bounds in self.subj_seq_bounds.items()] - self.labels = {} - self.tasks = None - self.task_types = None - self.task_vocabs = None - else: task_df_fp = self.config.tasks_root / f"{self.config.task_name}.parquet" task_info_fp = self.config.tasks_root / f"{self.config.task_name}_info.json" @@ -303,28 +293,16 @@ def read_patient_descriptors(self): on="patient_id", how="left", ) - .with_columns(pl.col("timestamp").alias("min_since_start")) + .with_columns(pl.col("timestamp")) ) - min_at_task_start = ( - (pl.col("start_time") - pl.col("start_time_global")).dt.total_seconds() / 60 - ).alias("min_at_task_start") - min_at_task_end = ( - (pl.col("end_time") - pl.col("start_time_global")).dt.total_seconds() / 60 - ).alias("min_at_task_end") - - start_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_start"))).alias( - "start_idx" - ) - end_idx_expr = (pl.col("min_since_start").search_sorted(pl.col("min_at_task_end"))).alias( - "end_idx" - ) + start_idx_expr = (pl.col("start_time").search_sorted(pl.col("timestamp"))).alias("start_idx") + end_idx_expr = (pl.col("end_time").search_sorted(pl.col("timestamp"))).alias("end_idx") task_df_joint = ( task_df_joint.explode(idx_col, "start_time", "end_time") - .with_columns(min_at_task_start, min_at_task_end) - .explode("min_since_start") - .group_by("patient_id", idx_col, "min_at_task_start", "min_at_task_end", maintain_order=True) + .explode("timestamp") + .group_by("patient_id", idx_col, "start_time", "end_time", maintain_order=True) .agg(start_idx_expr.first(), end_idx_expr.first()) .sort(by=idx_col, descending=False) ) @@ -335,6 +313,12 @@ def read_patient_descriptors(self): self.labels = {t: task_df.get_column(t).to_list() for t in self.tasks} self.index = list(zip(patient_ids, start_indices, end_indices)) + else: + self.index = [(subj, *bounds) for subj, bounds in self.subj_seq_bounds.items()] + self.labels = {} + self.tasks = None + self.task_types = None + self.task_vocabs = None def get_task_info(self, task_df: pl.DataFrame): """Gets the task information from the task dataframe.""" From 8fac9b8ed0524cb4065296612616ee287cbc9bba Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 09:28:56 -0400 Subject: [PATCH 38/53] Removing pytorch dataset files as they are being moved to https://github.com/mmcdermott/MEDS_pytorch_dataset --- src/MEDS_pytorch_dataset/pytorch_batch.py | 763 -------------------- src/MEDS_pytorch_dataset/pytorch_dataset.py | 626 ---------------- 2 files changed, 1389 deletions(-) delete mode 100644 src/MEDS_pytorch_dataset/pytorch_batch.py delete mode 100644 src/MEDS_pytorch_dataset/pytorch_dataset.py diff --git a/src/MEDS_pytorch_dataset/pytorch_batch.py b/src/MEDS_pytorch_dataset/pytorch_batch.py deleted file mode 100644 index 727ea8d..0000000 --- a/src/MEDS_pytorch_dataset/pytorch_batch.py +++ /dev/null @@ -1,763 +0,0 @@ -"""A pytorch batch object for ease of working with tensorized data. 
- -Currently **IN PROGRESS** -""" - -import dataclasses -from collections import defaultdict -from typing import Any - -import polars as pl -import torch - - -@dataclasses.dataclass -class PytorchBatch: - """A dataclass representing a batch of event flow data for a Pytorch model. - - This class defines the data-output interface for deep learning models built off Event Flow GPT datasets. - It stores the underlying data in the batch in a set of tensors, and also exposes some helpful methods and - properties to simplify interacting with data. - - Attributes: - event_mask: A boolean tensor of shape (batch_size, sequence_length) indicating which events in the - batch are valid (i.e., which are not padding). - time_delta: A float tensor of shape (batch_size, sequence_length) indicating the time delta in minutes - between each event and the subsequent event in that subject's sequence in the batch. - time: A float tensor of shape (batch_size, sequence_length) indicating the time in minutes since the - start of the subject's sequence of each event in the batch. This is often left unset, as it is - generally redundant with `time_delta`. However, it is used in generation, when the batch is - truncated to use efficient caching so the raw time point can't be recovered from the time delta. - static_indices: A long tensor of shape (batch_size, n_static_data_elements) indicating the indices of - the static data elements observed for each subject in the batch. These are *unordered*; meaning - that the second dimension position of a given element in this tensor is not necessarily - meaningful. This is because the static data elements are sparsely encoded, so the indices are - sufficient to recover the original data even in an unordered form. Here, by "indices" we mean that - these are integer values indicating the index of the associated categorical vocabulary element - corresponding to this observation; e.g., if the static measurement records that the subject's eye - color is brown, then if the categorical measurement of ``eye_color/BROWN``` in the unified - vocabulary is at position 32, then the index for that observation would be 32. - static_measurement_indices: A long tensor of shape (batch_size, n_static_data_elements) indicating - which measurements the indices in `static_indices` correspond to. E.g., if there is a static data - element corresponding to race, then the value in `static_measurement_indices` at the associated - position would be an integer index corresponding to the race measurement overall, whereas the - index at the identical position in `static_indices` would be an integer index corresponding to the - specific race observed for the subject (e.g., "White", "Black", etc.). - dynamic_indices: A long tensor of shape (batch_size, sequence_length, n_data_elements) indicating the - indices of the dynamic data elements observed for each subject in the batch. These are - *unordered* in the last dimension, meaning that the third dimension position of a given element in - this tensor is not necessarily meaningful. This is because the dynamic data elements are sparsely - encoded, so the indices and values are sufficient to recover the original data even in an - unordered form. - dynamic_measurement_indices: A long tensor of shape (batch_size, sequence_length, n_data_elements) - indicating which measurements the indices in `dynamic_indices` correspond to, similar to the - `static_measurement_indices` attribute. 
- dynamic_values: A float tensor of shape (batch_size, sequence_length, n_data_elements) indicating the - numeric values associated with each dynamic data element in the `dynamic_indices` tensor. If no - value was recorded for a given dynamic data element, the value in this tensor will be zero. - dynamic_values_mask: A boolean tensor of shape (batch_size, sequence_length, n_data_elements) - indicating which values in the `dynamic_values` tensor were actually observed. - start_time: A float tensor of shape (batch_size,) indicating the start time in minutes since the epoch - of each subject's sequence in the batch. This is often unset, as it is only used in generation - when we may need to know the actual time of day of any generated event. - start_idx: A long tensor of shape (batch_size,) indicating the start index of the sampled sub-sequence - for each subject in the batch relative to their raw data. - end_idx: A long tensor of shape (batch_size,) indicating the end index of the sampled sub-sequence - for each subject in the batch relative to their raw data. - subject_id: A long tensor of shape (batch_size,) indicating the subject ID of each member of the - batch. - stream_labels: A dictionary mapping task names to label LongTensors of shape (batch_size,) providing - labels for the associated tasks for the sequences in the batch. Is only used during fine-tuning or - zero-shot evaluation runs. - """ - - event_mask: torch.BoolTensor | None = None - - # We track this instead of raw times as it is less likely to suffer from underflow errors. - time_delta: torch.FloatTensor | None = None - - # We don't often use this, but it is used in generation. - time: torch.FloatTensor | None = None - - static_indices: torch.LongTensor | None = None - static_measurement_indices: torch.LongTensor | None = None - - dynamic_indices: torch.LongTensor | None = None - dynamic_measurement_indices: torch.LongTensor | None = None - dynamic_values: torch.FloatTensor | None = None - dynamic_values_mask: torch.BoolTensor | None = None - - start_time: torch.FloatTensor | None = None - start_idx: torch.LongTensor | None = None - end_idx: torch.LongTensor | None = None - subject_id: torch.LongTensor | None = None - - stream_labels: dict[str, torch.FloatTensor | torch.LongTensor] | None = None - - @staticmethod - def de_pad(L: list[int], *other_L) -> list[int] | tuple[list[int]]: - """Filters down all passed lists to only the indices where the first arg is non-zero. - - Args: - L: The list whose entries denote padding (0) or non-padding (non-zero). - *other_L: Any other lists that should be de-padded in the same way as L. - - Examples: - >>> de_pad([1, 3, 0, 4, 0, 0], [10, 0, 5, 8, 1, 0]) - ([1, 3, 4], [10, 0, 8]) - >>> de_pad([1, 3, 0, 4, 0, 0]) - [1, 3, 4] - """ - - out_L = [] - out_other = [None if x is None else [] for x in other_L] - - for i, v in enumerate(L): - if v != 0: - out_L.append(v) - for j, LL in enumerate(other_L): - if LL is not None: - out_other[j].append(LL[i]) - - if other_L: - return tuple([out_L] + out_other) - else: - return out_L - - @property - def device(self) -> torch.device: - """Returns the device storing the tensors in this batch. - - Assumes all elements of the batch are on the same device. - """ - return self.event_mask.device - - @property - def batch_size(self) -> int: - """Returns the batch size of this batch. - - Assumes the batch has not been sliced from its initial configuration. 
- """ - return self.event_mask.shape[0] - - @property - def sequence_length(self) -> int: - """Returns the maximum sequence length of the sequences in this batch. - - Assumes the batch has not been sliced from its initial configuration. - """ - return self.event_mask.shape[1] - - @property - def n_data_elements(self) -> int: - """Returns the maximum number of dynamic data elements of the events in this batch. - - Assumes the batch has not been sliced from its initial configuration. - """ - return self.dynamic_indices.shape[2] - - @property - def n_static_data_elements(self) -> int: - """Returns the maximum number of static data elements of the subjects in this batch. - - Assumes the batch has not been sliced from its initial configuration. - """ - return self.static_indices.shape[1] - - def get(self, item: str, default: Any) -> Any: - """A dictionary like get method for this batch, by attribute name.""" - return getattr(self, item) if item in self.keys() else default - - def _slice(self, index: tuple[int | slice] | int | slice) -> "PytorchBatch": - if not isinstance(index, tuple): - index = (index,) - if len(index) == 0 or len(index) > 3: - raise ValueError(f"Invalid index {index} for PytorchBatch! Must be of length 1, 2, or 3.") - if any(not isinstance(i, (int, slice)) for i in index): - raise ValueError(f"Invalid index {index} for PytorchBatch! Can only consist of ints and slices.") - - batch_index = index[0] - seq_index = slice(None) - meas_index = slice(None) - - if len(index) > 1: - seq_index = index[1] - if len(index) > 2: - meas_index = index[2] - - return PytorchBatch( - event_mask=self.event_mask[batch_index, seq_index], - time_delta=self.time_delta[batch_index, seq_index], - static_indices=None if self.static_indices is None else self.static_indices[batch_index], - static_measurement_indices=( - None - if self.static_measurement_indices is None - else self.static_measurement_indices[batch_index] - ), - dynamic_indices=self.dynamic_indices[batch_index, seq_index, meas_index], - dynamic_measurement_indices=self.dynamic_measurement_indices[batch_index, seq_index, meas_index], - dynamic_values=self.dynamic_values[batch_index, seq_index, meas_index], - dynamic_values_mask=self.dynamic_values_mask[batch_index, seq_index, meas_index], - start_time=None if self.start_time is None else self.start_time[batch_index], - start_idx=None if self.start_idx is None else self.start_idx[batch_index], - end_idx=None if self.end_idx is None else self.end_idx[batch_index], - subject_id=None if self.subject_id is None else self.subject_id[batch_index], - stream_labels=( - None - if self.stream_labels is None - else {k: v[batch_index] for k, v in self.stream_labels.items()} - ), - time=None if self.time is None else self.time[batch_index, seq_index], - ) - - def __getitem__(self, item: str | tuple[int | slice]) -> torch.Tensor | "PytorchBatch": - match item: - case str(): - return dataclasses.asdict(self)[item] - case tuple() | int() | slice(): - return self._slice(item) - case _: - raise TypeError(f"Invalid type {type(item)} for {item} for indexing!") - - def __setitem__(self, item: str, val: torch.Tensor): - if not hasattr(self, item): - raise KeyError(f"Key {item} not found") - setattr(self, item, val) - - def __eq__(self, other: "PytorchBatch") -> bool: - """Checks for equality between self and other.""" - if self.keys() != other.keys(): - return False - - for k in self.keys(): - self_v = self[k] - other_v = other[k] - - if type(self_v) is not type(other_v): - return False - - match self_v: - case 
dict() if k == "stream_labels": - if self_v.keys() != other_v.keys(): - return False - for kk in self_v.keys(): - self_vv = self_v[kk] - other_vv = other_v[kk] - - if self_vv.shape != other_vv.shape: - return False - if (self_vv != other_vv).any(): - return False - - case torch.Tensor(): - if self_v.shape != other_v.shape: - return False - if (self_v != other_v).any(): - return False - case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): - if other_v is not None: - return False - case _: - raise ValueError(f"{k}: {type(self_v)} not supported in batch!") - return True - - def items(self): - """A dictionary like items` method for the elements of this batch, by attribute.""" - return dataclasses.asdict(self).items() - - def keys(self): - """A dictionary like keys method for the elements of this batch, by attribute.""" - return dataclasses.asdict(self).keys() - - def values(self): - """A dictionary like values method for the elements of this batch, by attribute.""" - return dataclasses.asdict(self).values() - - def last_sequence_element_unsqueezed(self) -> "PytorchBatch": - """Filters the batch down to just the last event, while retaining the same # of dims.""" - return self[:, -1:] - - def repeat_batch_elements(self, expand_size: int) -> "PytorchBatch": - """Repeats each batch element expand_size times in order. Used for generation. - - Args: - expand_size: The number of times each batch elements data should be repeated. - - Returns: A new PytorchBatch object with each batch element's data repeated expand_size times. - - Examples: - >>> import torch - >>> batch = PytorchBatch( - ... event_mask=torch.tensor([[True, True, True], [True, True, False]]), - ... time_delta=torch.tensor([[1.0, 2.0, 3.0], [1.0, 5.0, 0.0]]), - ... static_indices=torch.tensor([[0, 1], [1, 2]]), - ... static_measurement_indices=torch.tensor([[0, 1], [1, 1]]), - ... dynamic_indices=torch.tensor([[[0, 1], [1, 2], [2, 3]], [[0, 1], [1, 5], [0, 0]]]), - ... dynamic_measurement_indices=torch.tensor( - ... [[[0, 1], [1, 2], [2, 3]], [[0, 1], [1, 2], [0, 0]]] - ... ), - ... dynamic_values=torch.tensor( - ... [[[0.0, 1.0], [1.0, 2.0], [0, 0]], [[0.0, 1.0], [1.0, 0.0], [0, 0]]] - ... ), - ... dynamic_values_mask=torch.tensor([ - ... [[False, True], [True, True], [False, False]], - ... [[False, True], [True, False], [False, False]] - ... ]), - ... start_time=torch.tensor([0.0, 10.0]), - ... stream_labels={"a": torch.tensor([0, 1]), "b": torch.tensor([1, 2])}, - ... time=None, - ... ) - >>> repeated_batch = batch.repeat_batch_elements(2) - >>> for k, v in repeated_batch.items(): - ... print(k) - ... 
print(v) - event_mask - tensor([[ True, True, True], - [ True, True, True], - [ True, True, False], - [ True, True, False]]) - time_delta - tensor([[1., 2., 3.], - [1., 2., 3.], - [1., 5., 0.], - [1., 5., 0.]]) - time - None - static_indices - tensor([[0, 1], - [0, 1], - [1, 2], - [1, 2]]) - static_measurement_indices - tensor([[0, 1], - [0, 1], - [1, 1], - [1, 1]]) - dynamic_indices - tensor([[[0, 1], - [1, 2], - [2, 3]], - - [[0, 1], - [1, 2], - [2, 3]], - - [[0, 1], - [1, 5], - [0, 0]], - - [[0, 1], - [1, 5], - [0, 0]]]) - dynamic_measurement_indices - tensor([[[0, 1], - [1, 2], - [2, 3]], - - [[0, 1], - [1, 2], - [2, 3]], - - [[0, 1], - [1, 2], - [0, 0]], - - [[0, 1], - [1, 2], - [0, 0]]]) - dynamic_values - tensor([[[0., 1.], - [1., 2.], - [0., 0.]], - - [[0., 1.], - [1., 2.], - [0., 0.]], - - [[0., 1.], - [1., 0.], - [0., 0.]], - - [[0., 1.], - [1., 0.], - [0., 0.]]]) - dynamic_values_mask - tensor([[[False, True], - [ True, True], - [False, False]], - - [[False, True], - [ True, True], - [False, False]], - - [[False, True], - [ True, False], - [False, False]], - - [[False, True], - [ True, False], - [False, False]]]) - start_time - tensor([ 0., 0., 10., 10.]) - start_idx - None - end_idx - None - subject_id - None - stream_labels - {'a': tensor([0, 0, 1, 1]), 'b': tensor([1, 1, 2, 2])} - """ - - expanded_return_idx = ( - torch.arange(self.batch_size).view(-1, 1).repeat(1, expand_size).view(-1).to(self.device) - ) - - out_batch = {} - - for k, v in self.items(): - match v: - case dict(): - out_batch[k] = {kk: vv.index_select(0, expanded_return_idx) for kk, vv in v.items()} - case torch.Tensor(): - out_batch[k] = v.index_select(0, expanded_return_idx) - case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): - out_batch[k] = None - case _: - raise TypeError(f"{k}: {type(v)} not supported in batch for generation!") - - return PytorchBatch(**out_batch) - - def split_repeated_batch(self, n_splits: int) -> list["PytorchBatch"]: - """Split a batch into a list of batches by chunking batch elements into groups. - - This is the inverse of `PytorchBatch.repeat_batch_elements`. It is used for taking a generated batch - that has been expanded and splitting it into separate list elements with independent generations for - each batch element in the original batch. - - Args: - n_splits: The number of splits to make. - - Returns: A list of length `n_splits` of PytorchBatch objects, such that the list element i contains - batch elements [i, i+self.batch_size/n_splits). - - Raises: - ValueError: if `n_splits` is not a positive integer divisor of `self.batch_size`. - - Examples: - >>> import torch - >>> batch = PytorchBatch( - ... event_mask=torch.tensor([ - ... [True, True, True], - ... [True, True, False], - ... [True, False, False], - ... [False, False, False] - ... ]), - ... time_delta=torch.tensor([ - ... [1.0, 2.0, 3.0], - ... [1.0, 5.0, 0.0], - ... [2.3, 0.0, 0.0], - ... [0.0, 0.0, 0.0], - ... ]), - ... static_indices=torch.tensor([[0, 1], [1, 2], [1, 3], [0, 5]]), - ... static_measurement_indices=torch.tensor([[0, 1], [1, 1], [1, 1], [0, 2]]), - ... dynamic_indices=torch.tensor([ - ... [[0, 1], [1, 2], [2, 3]], - ... [[0, 1], [1, 5], [0, 0]], - ... [[0, 2], [0, 0], [0, 0]], - ... [[0, 0], [0, 0], [0, 0]], - ... ]), - ... dynamic_measurement_indices=torch.tensor([ - ... [[0, 1], [1, 2], [2, 3]], - ... [[0, 1], [1, 2], [0, 0]], - ... [[0, 2], [0, 0], [0, 0]], - ... [[0, 0], [0, 0], [0, 0]], - ... ]), - ... dynamic_values=torch.tensor([ - ... 
[[0.0, 1.0], [1.0, 2.0], [0.0, 0.0]], - ... [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], - ... [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], - ... [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - ... ]), - ... dynamic_values_mask=torch.tensor([ - ... [[False, True], [True, True], [False, False]], - ... [[False, True], [True, False], [False, False]], - ... [[False, True], [False, False], [False, False]], - ... [[False, False], [False, False], [False, False]], - ... ]), - ... start_time=torch.tensor([0.0, 10.0, 3.0, 2.2]), - ... stream_labels={"a": torch.tensor([0, 1, 0, 1]), "b": torch.tensor([1, 2, 4, 3])}, - ... time=None, - ... ) - >>> batch.split_repeated_batch(3) - Traceback (most recent call last): - ... - ValueError: n_splits (3) must be a positive integer divisor of batch_size (4) - >>> for i, T in enumerate(batch.split_repeated_batch(2)): - ... print(f"Returned batch {i}:") - ... for k, v in T.items(): - ... print(k) - ... print(v) - Returned batch 0: - event_mask - tensor([[ True, True, True], - [ True, False, False]]) - time_delta - tensor([[1.0000, 2.0000, 3.0000], - [2.3000, 0.0000, 0.0000]]) - time - None - static_indices - tensor([[0, 1], - [1, 3]]) - static_measurement_indices - tensor([[0, 1], - [1, 1]]) - dynamic_indices - tensor([[[0, 1], - [1, 2], - [2, 3]], - - [[0, 2], - [0, 0], - [0, 0]]]) - dynamic_measurement_indices - tensor([[[0, 1], - [1, 2], - [2, 3]], - - [[0, 2], - [0, 0], - [0, 0]]]) - dynamic_values - tensor([[[0., 1.], - [1., 2.], - [0., 0.]], - - [[0., 1.], - [0., 0.], - [0., 0.]]]) - dynamic_values_mask - tensor([[[False, True], - [ True, True], - [False, False]], - - [[False, True], - [False, False], - [False, False]]]) - start_time - tensor([0., 3.]) - start_idx - None - end_idx - None - subject_id - None - stream_labels - {'a': tensor([0, 0]), 'b': tensor([1, 4])} - Returned batch 1: - event_mask - tensor([[ True, True, False], - [False, False, False]]) - time_delta - tensor([[1., 5., 0.], - [0., 0., 0.]]) - time - None - static_indices - tensor([[1, 2], - [0, 5]]) - static_measurement_indices - tensor([[1, 1], - [0, 2]]) - dynamic_indices - tensor([[[0, 1], - [1, 5], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]]]) - dynamic_measurement_indices - tensor([[[0, 1], - [1, 2], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]]]) - dynamic_values - tensor([[[0., 1.], - [1., 0.], - [0., 0.]], - - [[0., 0.], - [0., 0.], - [0., 0.]]]) - dynamic_values_mask - tensor([[[False, True], - [ True, False], - [False, False]], - - [[False, False], - [False, False], - [False, False]]]) - start_time - tensor([10.0000, 2.2000]) - start_idx - None - end_idx - None - subject_id - None - stream_labels - {'a': tensor([1, 1]), 'b': tensor([2, 3])} - >>> repeat_batch = batch.repeat_batch_elements(5) - >>> split_batches = repeat_batch.split_repeated_batch(5) - >>> for i, v in enumerate(split_batches): - ... assert v == batch, f"Batch {i} ({v}) not equal to original batch {batch}!" - """ - - if not isinstance(n_splits, int) or n_splits <= 0 or self.batch_size % n_splits != 0: - raise ValueError( - f"n_splits ({n_splits}) must be a positive integer divisor of batch_size ({self.batch_size})" - ) - - self.batch_size // n_splits - out_batches = [defaultdict(dict) for _ in range(n_splits)] - for k, v in self.items(): - match v: - case dict(): - for kk, vv in v.items(): - reshaped = vv.reshape(vv.shape[0] // n_splits, n_splits, *vv.shape[1:]) - for i in range(n_splits): - out_batches[i][k][kk] = reshaped[:, i, ...] 
- case torch.Tensor(): - reshaped = v.reshape(v.shape[0] // n_splits, n_splits, *v.shape[1:]) - for i in range(n_splits): - out_batches[i][k] = reshaped[:, i, ...] - case None if k in ("time", "stream_labels", "start_idx", "end_idx", "subject_id"): - pass - case _: - raise TypeError(f"{k}: {type(v)} not supported in batch for generation!") - - return [PytorchBatch(**B) for B in out_batches] - - def convert_to_DL_DF(self) -> pl.DataFrame: - """Converts the batch data into a sparse DataFrame representation. - - Examples: - >>> import torch - >>> batch = PytorchBatch( - ... event_mask=torch.tensor([ - ... [True, True, True], - ... [True, True, False], - ... [True, False, False], - ... [False, False, False] - ... ]), - ... time_delta=torch.tensor([ - ... [1.0, 2.0, 3.0], - ... [1.0, 5.0, 0.0], - ... [2.3, 0.0, 0.0], - ... [0.0, 0.0, 0.0], - ... ]), - ... static_indices=torch.tensor([[0, 1], [1, 2], [1, 3], [0, 5]]), - ... static_measurement_indices=torch.tensor([[0, 1], [1, 1], [1, 1], [0, 2]]), - ... dynamic_indices=torch.tensor([ - ... [[0, 1], [1, 2], [2, 3]], - ... [[0, 1], [1, 5], [0, 0]], - ... [[0, 2], [0, 0], [0, 0]], - ... [[0, 0], [0, 0], [0, 0]], - ... ]), - ... dynamic_measurement_indices=torch.tensor([ - ... [[0, 1], [1, 2], [2, 3]], - ... [[0, 1], [1, 2], [0, 0]], - ... [[0, 2], [0, 0], [0, 0]], - ... [[0, 0], [0, 0], [0, 0]], - ... ]), - ... dynamic_values=torch.tensor([ - ... [[0.0, 1.0], [1.0, 2.0], [0.0, 0.0]], - ... [[0.0, 1.0], [1.0, 0.0], [0.0, 0.0]], - ... [[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]], - ... [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], - ... ]), - ... dynamic_values_mask=torch.tensor([ - ... [[False, True], [True, True], [False, False]], - ... [[False, True], [True, False], [False, False]], - ... [[False, True], [False, False], [False, False]], - ... [[False, False], [False, False], [False, False]], - ... ]), - ... start_time=torch.tensor([0.0, 10.0, 3.0, 2.2]), - ... stream_labels={"a": torch.tensor([0, 1, 0, 1]), "b": torch.tensor([1, 2, 4, 3])}, - ... time=None, - ... 
) - >>> pl.Config.set_tbl_width_chars(80) - - >>> batch.convert_to_DL_DF() - shape: (4, 7) - ┌───────────┬───────────┬──────────┬──────────┬──────────┬──────────┬──────────┐ - │ time_delt ┆ static_in ┆ static_m ┆ dynamic_ ┆ dynamic_ ┆ dynamic_ ┆ start_ti │ - │ a ┆ dices ┆ easureme ┆ indices ┆ measurem ┆ values ┆ me │ - │ --- ┆ --- ┆ nt_indic ┆ --- ┆ ent_indi ┆ --- ┆ --- │ - │ list[f64] ┆ list[f64] ┆ es ┆ list[lis ┆ ces ┆ list[lis ┆ f64 │ - │ ┆ ┆ --- ┆ t[f64]] ┆ --- ┆ t[f64]] ┆ │ - │ ┆ ┆ list[f64 ┆ ┆ list[lis ┆ ┆ │ - │ ┆ ┆ ] ┆ ┆ t[f64]] ┆ ┆ │ - ╞═══════════╪═══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡ - │ [1.0, ┆ [1.0] ┆ [1.0] ┆ [[1.0], ┆ [[1.0], ┆ [[1.0], ┆ 0.0 │ - │ 2.0, 3.0] ┆ ┆ ┆ [1.0, ┆ [1.0, ┆ [1.0, ┆ │ - │ ┆ ┆ ┆ 2.0], ┆ 2.0], ┆ 2.0], ┆ │ - │ ┆ ┆ ┆ [2.0, ┆ [2.0, ┆ [null, ┆ │ - │ ┆ ┆ ┆ 3.0]… ┆ 3.0]… ┆ nul… ┆ │ - │ [1.0, ┆ [1.0, ┆ [1.0, ┆ [[1.0], ┆ [[1.0], ┆ [[1.0], ┆ 10.0 │ - │ 5.0] ┆ 2.0] ┆ 1.0] ┆ [1.0, ┆ [1.0, ┆ [1.0, ┆ │ - │ ┆ ┆ ┆ 5.0]] ┆ 2.0]] ┆ null]] ┆ │ - │ [2.3] ┆ [1.0, ┆ [1.0, ┆ [[2.0]] ┆ [[2.0]] ┆ [[1.0]] ┆ 3.0 │ - │ ┆ 3.0] ┆ 1.0] ┆ ┆ ┆ ┆ │ - │ [] ┆ [5.0] ┆ [2.0] ┆ [] ┆ [] ┆ [] ┆ 2.2 │ - └───────────┴───────────┴──────────┴──────────┴──────────┴──────────┴──────────┘ - """ - - df = { - k: [] - for k, v in self.items() - if k not in ("stream_labels", "event_mask", "dynamic_values_mask") and v is not None - } - - for k in ("start_time", "subject_id", "start_idx", "end_idx"): - if self[k] is not None: - df[k] = list(self[k]) - - for i in range(self.batch_size): - idx, measurement_idx = self.de_pad(self.static_indices[i], self.static_measurement_indices[i]) - df["static_indices"].append(idx) - df["static_measurement_indices"].append(measurement_idx) - - _, time_delta, time, idx, measurement_idx, vals, vals_mask = self.de_pad( - self.event_mask[i], - None if self.time_delta is None else self.time_delta[i], - None if self.time is None else self.time[i], - self.dynamic_indices[i], - self.dynamic_measurement_indices[i], - self.dynamic_values[i], - self.dynamic_values_mask[i], - ) - - if time_delta is not None: - df["time_delta"].append(time_delta) - if time is not None: - df["time"].append(time) - - names = ("dynamic_indices", "dynamic_measurement_indices", "dynamic_values") - for n in names: - df[n].append([]) - - for j in range(len(idx)): - de_padded_vals = self.de_pad(idx[j], measurement_idx[j], vals[j], vals_mask[j]) - # Now we add the indices and measurement indices - for n, v in zip(names[:-1], de_padded_vals[:-2]): - df[n][i].append(v) - - df["dynamic_values"][i].append([None if not m else v for v, m in zip(*de_padded_vals[-2:])]) - - return pl.DataFrame(df) diff --git a/src/MEDS_pytorch_dataset/pytorch_dataset.py b/src/MEDS_pytorch_dataset/pytorch_dataset.py deleted file mode 100644 index 12652d6..0000000 --- a/src/MEDS_pytorch_dataset/pytorch_dataset.py +++ /dev/null @@ -1,626 +0,0 @@ -import json -from collections import defaultdict -from enum import StrEnum - -import numpy as np -import polars as pl -import torch -from loguru import logger -from mixins import SeedableMixin -from nested_ragged_tensors.ragged_numpy import ( - NP_FLOAT_TYPES, - NP_INT_TYPES, - NP_UINT_TYPES, - JointNestedRaggedTensorDict, -) -from omegaconf import DictConfig -from tqdm.auto import tqdm - - -def count_or_proportion(N: int | pl.Expr | None, cnt_or_prop: int | float) -> int: - """Returns `cnt_or_prop` if it is an integer or `int(N*cnt_or_prop)` if it is a float. - - Resolves cutoff variables that can either be passed as integer counts or fractions of a whole. 
E.g., the - vocabulary should contain only elements that occur with count or proportion at least X, where X might be - 20 times, or 1%. - - Arguments: - N: The total number of elements in the whole. Only used if `cnt_or_prop` is a proportion (float). - cnt_or_prop: The cutoff value, either as an integer count or a proportion of the whole. - - Returns: - The cutoff value as an integer count of the whole. - - Raises: - TypeError: If `cnt_or_prop` is not an integer or a float or if `N` is needed and is not an integer or - a polars Expression. - ValueError: If `cnt_or_prop` is not a positive integer or a float between 0 and 1. - - Examples: - >>> count_or_proportion(100, 0.1) - 10 - >>> count_or_proportion(None, 11) - 11 - >>> count_or_proportion(100, 0.116) - 12 - >>> count_or_proportion(None, 0) - Traceback (most recent call last): - ... - ValueError: 0 must be positive if it is an integer - >>> count_or_proportion(None, 1.3) - Traceback (most recent call last): - ... - ValueError: 1.3 must be between 0 and 1 if it is a float - >>> count_or_proportion(None, "a") - Traceback (most recent call last): - ... - TypeError: a must be a positive integer or a float between 0 or 1 - >>> count_or_proportion("a", 0.2) - Traceback (most recent call last): - ... - TypeError: a must be an integer or a polars.Expr when cnt_or_prop is a float! - """ - - match cnt_or_prop: - case int() if 0 < cnt_or_prop: - return cnt_or_prop - case int(): - raise ValueError(f"{cnt_or_prop} must be positive if it is an integer") - case float() if 0 < cnt_or_prop < 1: - pass - case float(): - raise ValueError(f"{cnt_or_prop} must be between 0 and 1 if it is a float") - case _: - raise TypeError(f"{cnt_or_prop} must be a positive integer or a float between 0 or 1") - - match N: - case int(): - return int(round(cnt_or_prop * N)) - case pl.Expr(): - return (N * cnt_or_prop).round(0).cast(int) - case _: - raise TypeError(f"{N} must be an integer or a polars.Expr when cnt_or_prop is a float!") - - -class SubsequenceSamplingStrategy(StrEnum): - """An enumeration of the possible subsequence sampling strategies for the dataset.""" - - RANDOM = "random" - TO_END = "to_end" - FROM_START = "from_start" - - -class SeqPaddingSide(StrEnum): - """An enumeration of the possible sequence padding sides for the dataset.""" - - LEFT = "left" - RIGHT = "right" - - -def to_int_index(col: pl.Expr) -> pl.Expr: - """Returns an integer index of the unique elements seen in this column. - - The returned index is into a vocabulary sorted lexographically. - - Args: - col: The column containing the data to be converted into integer indices. - - Examples: - >>> import polars as pl - >>> X = pl.DataFrame({ - ... 'c': ['foo', 'bar', 'foo', 'bar', 'baz', None, 'bar', 'aba'], - ... 'd': [1, 2, 3, 4, 5, 6, 7, 8] - ... }) - >>> X.with_columns(to_int_index(pl.col('c'))) - shape: (8, 2) - ┌──────┬─────┐ - │ c ┆ d │ - │ --- ┆ --- │ - │ u32 ┆ i64 │ - ╞══════╪═════╡ - │ 4 ┆ 1 │ - │ 1 ┆ 2 │ - │ 4 ┆ 3 │ - │ 1 ┆ 4 │ - │ 2 ┆ 5 │ - │ null ┆ 6 │ - │ 1 ┆ 7 │ - │ 0 ┆ 8 │ - └──────┴─────┘ - """ - - indices = col.unique(maintain_order=True).drop_nulls().search_sorted(col) - return pl.when(col.is_null()).then(pl.lit(None)).otherwise(indices).alias(col.meta.output_name()) - - -class PytorchDataset(SeedableMixin, torch.utils.data.Dataset): - """A PyTorch Dataset class. - - Args: - config: Configuration options for the dataset, in an `omegaconf.DictConfig` object. - split: The split of data which should be used in this dataset (e.g., ``'train'``, ``'tuning'``, - ``'held_out'``). 
This will dictate where the system looks for files. - """ - - TYPE_CHECKERS = { - "multi_class_classification": [ - ( - {pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64, pl.Int8, pl.Int16, pl.Int32, pl.Int64}, - None, - ), - ({pl.Categorical(ordering="physical"), pl.Categorical(ordering="lexical")}, to_int_index), - ({pl.Utf8}, to_int_index), - ], - "binary_classification": [({pl.Boolean}, lambda Y: Y.cast(pl.Float32))], - "regression": [({pl.Float32, pl.Float64}, None)], - } - """Type checker and conversion parameters for labeled datasets.""" - - @classmethod - def normalize_task(cls, col: pl.Expr, dtype: pl.DataType) -> tuple[str, pl.Expr]: - """Normalizes the task labels in `col` of dtype `dtype` to a common format. - - Args: - col: The column containing the task labels, in polars expression format. - dtype: The polars data type of the task labels. - - Returns: - The task type (a string key into the `TYPE_CHECKERS` dictionary) and the normalized column - expression. - - Raises: - TypeError: If the task labels are not of a supported type. - """ - for task_type, checkers in cls.TYPE_CHECKERS.items(): - for valid_dtypes, normalize_fn in checkers: - if dtype in valid_dtypes: - return task_type, (col if normalize_fn is None else normalize_fn(col)) - - raise TypeError(f"Can't process label of {dtype} type!") - - def __init__(self, cfg: DictConfig, split: str): - super().__init__() - - if cfg.subsequence_sampling_strategy not in SubsequenceSamplingStrategy: - raise ValueError( - f"Invalid subsequence sampling strategy {cfg.subsequence_sampling_strategy}! " - f"Valid options are {', '.join(SubsequenceSamplingStrategy.__members__)}" - ) - - if cfg.seq_padding_side not in SeqPaddingSide: - raise ValueError( - f"Invalid sequence padding side {cfg.seq_padding_side}! 
" - f"Valid options are {', '.join(SeqPaddingSide.__members__)}" - ) - - self.config = cfg - self.split = split - - logger.info("Scanning code metadata") - self.code_metadata = pl.scan_parquet(self.config.code_metadata_fp) - - logger.info("Reading splits & patient shards") - self.read_shards() - - logger.info("Reading patient descriptors") - self.read_patient_descriptors() - - if self.config.min_seq_len is not None and self.config.min_seq_len > 1: - logger.info(f"Restricting to subjects with at least {self.config.min_seq_len} events") - self.filter_to_min_seq_len() - - if self.config.train_subset_size not in (None, "FULL") and self.split == "train": - logger.info(f"Filtering training subset size to {self.config.train_subset_size}") - self.filter_to_subset() - - self.set_inter_event_time_stats() - - def read_shards(self): - """Reads the split-specific patient shards from the ESGPT or MEDS dataset.""" - shards_fp = self.config.split_shards_fp - all_shards = json.loads(shards_fp.read_text()) - self.shards = {sp: subjs for sp, subjs in all_shards.items() if sp.startswith(f"{self.split}/")} - self.subj_map = {subj: sp for sp, subjs in self.shards.items() for subj in subjs} - - def read_patient_descriptors(self): - """Reads the patient schemas and static data.""" - self.static_dfs = {} - self.subj_indices = {} - self.subj_seq_bounds = {} - - shards = tqdm(self.shards.keys(), total=len(self.shards), desc="Reading static shards", leave=False) - for shard in shards: - static_fp = self.config.schema_files_root / f"{shard}.parquet" - df = pl.read_parquet( - static_fp, - columns=[ - "patient_id", - "start_time", - pl.col("code").alias("static_indices"), - pl.col("numerical_value").alias("static_values"), - "timestamp", - "patient_offset", - ], - use_pyarrow=True, - ) - - self.static_dfs[shard] = df - patient_ids = df["patient_id"] - n_events = df.select(pl.col("timestamp").list.lengths().alias("n_events")).get_column("n_events") - for i, (subj, n_events) in enumerate(zip(patient_ids, n_events)): - if subj in self.subj_indices or subj in self.subj_seq_bounds: - raise ValueError(f"Duplicate subject {subj} in {shard}!") - - self.subj_indices[subj] = i - self.subj_seq_bounds[subj] = (0, n_events) - - if self.has_task: - task_df_fp = self.config.tasks_root / f"{self.config.task_name}.parquet" - task_info_fp = self.config.tasks_root / f"{self.config.task_name}_info.json" - - logger.info(f"Reading task constraints for {self.config.task_name} from {task_df_fp}") - task_df = pl.read_parquet(task_df_fp, use_pyarrow=True) - - task_info = self.get_task_info(task_df) - - if task_info_fp.is_file(): - loaded_task_info = json.loads(task_info_fp.read_text()) - if loaded_task_info != task_info: - raise ValueError( - f"Task info differs from on disk!\nDisk:\n{loaded_task_info}\n" - f"Local:\n{task_info}\nSplit: {self.split}" - ) - logger.info(f"Re-built existing {task_info_fp} and it matches.") - else: - task_info_fp.parent.mkdir(exist_ok=True, parents=True) - task_info_fp.write_text(json.dumps(task_info)) - - idx_col = "_row_index" - while idx_col in task_df.columns: - idx_col = f"_{idx_col}" - - raise NotImplementedError("Need to figure out task constraints still") - - task_df_joint = ( - task_df.select("patient_id", "start_time", "end_time") - .with_row_index(idx_col) - .group_by("patient_id") - .agg("start_time", "end_time", idx_col) - .join( - pl.concat(self.static_dfs.values()).select( - "patient_id", pl.col("start_time").alias("start_time_global"), "time_delta" - ), - on="patient_id", - how="left", - ) - 
.with_columns(pl.col("timestamp")) - ) - - start_idx_expr = (pl.col("start_time").search_sorted(pl.col("timestamp"))).alias("start_idx") - end_idx_expr = (pl.col("end_time").search_sorted(pl.col("timestamp"))).alias("end_idx") - - task_df_joint = ( - task_df_joint.explode(idx_col, "start_time", "end_time") - .explode("timestamp") - .group_by("patient_id", idx_col, "start_time", "end_time", maintain_order=True) - .agg(start_idx_expr.first(), end_idx_expr.first()) - .sort(by=idx_col, descending=False) - ) - - patient_ids = task_df_joint["patient_id"] - start_indices = task_df_joint["start_idx"] - end_indices = task_df_joint["end_idx"] - - self.labels = {t: task_df.get_column(t).to_list() for t in self.tasks} - self.index = list(zip(patient_ids, start_indices, end_indices)) - else: - self.index = [(subj, *bounds) for subj, bounds in self.subj_seq_bounds.items()] - self.labels = {} - self.tasks = None - self.task_types = None - self.task_vocabs = None - - def get_task_info(self, task_df: pl.DataFrame): - """Gets the task information from the task dataframe.""" - self.tasks = sorted([c for c in task_df.columns if c not in ["patient_id", "start_time", "end_time"]]) - - self.task_types = {} - self.task_vocabs = {} - - normalized_cols = [] - for t in self.tasks: - task_type, normalized_vals = self.normalize_task(col=pl.col(t), dtype=task_df.schema[t]) - self.task_types[t] = task_type - normalized_cols.append(normalized_vals.alias(t)) - - task_df = task_df.with_columns(normalized_cols) - - for t in self.tasks: - match self.task_types[t]: - case "binary_classification": - self.task_vocabs[t] = [False, True] - case "multi_class_classification": - self.task_vocabs[t] = list(range(task_df.select(pl.col(t).max()).item() + 1)) - case _: - raise NotImplementedError(f"Task type {self.task_types[t]} not implemented!") - - return {"tasks": sorted(self.tasks), "vocabs": self.task_vocabs, "types": self.task_types} - - def filter_to_min_seq_len(self): - """Filters the dataset to only include subjects with at least `self.config.min_seq_len` events.""" - if self.has_task: - logger.warning( - f"Filtering task {self.config.task_name} to min_seq_len {self.config.min_seq_len}. " - "This may result in incomparable model results against runs with different constraints!" - ) - - orig_len = len(self) - orig_n_subjects = len(set(self.patient_ids)) - valid_indices = [ - i for i, (subj, start, end) in enumerate(self.index) if end - start >= self.config.min_seq_len - ] - self.index = [self.index[i] for i in valid_indices] - self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} - new_len = len(self) - new_n_subjects = len(set(self.patient_ids)) - logger.info( - f"Filtered data due to sequence length constraint (>= {self.config.min_seq_len}) from " - f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." 
- ) - - def filter_to_subset(self): - """Filters the dataset to only include a subset of subjects.""" - - orig_len = len(self) - orig_n_subjects = len(set(self.patient_ids)) - rng = np.random.default_rng(self.config.train_subset_seed) - subset_subjects = rng.choice( - list(set(self.patient_ids)), - size=count_or_proportion(orig_n_subjects, self.config.train_subset_size), - replace=False, - ) - valid_indices = [i for i, (subj, start, end) in enumerate(self.index) if subj in subset_subjects] - self.index = [self.index[i] for i in valid_indices] - self.labels = {t: [t_labels[i] for i in valid_indices] for t, t_labels in self.labels.items()} - new_len = len(self) - new_n_subjects = len(set(self.patient_ids)) - logger.info( - f"Filtered data to subset of {self.config.train_subset_size} subjects from " - f"{orig_len} to {new_len} rows and {orig_n_subjects} to {new_n_subjects} subjects." - ) - - def set_inter_event_time_stats(self): - """Sets the inter-event time statistics for the dataset.""" - data_for_stats = pl.concat([x.lazy() for x in self.static_dfs.values()]) - stats = ( - data_for_stats.select( - pl.col("time_delta").explode().drop_nulls().drop_nans().alias("inter_event_time") - ) - .select( - pl.col("inter_event_time").min().alias("min"), - pl.col("inter_event_time").log().mean().alias("mean_log"), - pl.col("inter_event_time").log().std().alias("std_log"), - ) - .collect() - ) - - if stats["min"].item() <= 0: - bad_inter_event_times = data_for_stats.filter(pl.col("time_delta").list.min() <= 0).collect() - bad_patient_ids = set(bad_inter_event_times["patient_id"].to_list()) - warning_strs = [ - f"Observed inter-event times <= 0 for {len(bad_inter_event_times)} subjects!", - f"Bad Subject IDs: {', '.join(str(x) for x in bad_patient_ids)}", - f"Global min: {stats['min'].item()}", - ] - if self.config.MEDS_cohort_dir is not None: - fp = self.config.MEDS_cohort_dir / f"malformed_data_{self.split}.parquet" - bad_inter_event_times.write_parquet(fp) - warning_strs.append(f"Wrote malformed data records to {fp}") - warning_strs.append("Removing malformed subjects") - - logger.warning("\n".join(warning_strs)) - - self.index = [x for x in self.index if x[0] not in bad_patient_ids] - - self.mean_log_inter_event_time_min = stats["mean_log"].item() - self.std_log_inter_event_time_min = stats["std_log"].item() - - @property - def patient_ids(self) -> list[int]: - return [x[0] for x in self.index] - - def __len__(self): - return len(self.index) - - @property - def has_task(self) -> bool: - return self.config.task_name is not None - - @property - def max_seq_len(self) -> int: - return self.config.max_seq_len - - def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - """Returns a Returns a dictionary corresponding to a single subject's data. - - The output of this will not be tensorized as that work will need to be re-done in the collate function - regardless. The output will have structure: - `` - { - 'time_delta': [seq_len], - 'dynamic_indices': [seq_len, n_data_per_event] (ragged), - 'dynamic_values': [seq_len, n_data_per_event] (ragged), - 'dynamic_measurement_indices': [seq_len, n_data_per_event] (ragged), - 'static_indices': [seq_len, n_data_per_event] (ragged), - 'static_measurement_indices': [seq_len, n_data_per_event] (ragged), - } - `` - - 1. ``time_delta`` captures the time between each event and the subsequent event. - 2. ``dynamic_indices`` captures the categorical metadata elements listed in `self.data_cols` in a - unified vocabulary space spanning all metadata vocabularies. - 3. 
``dynamic_values`` captures the numerical metadata elements listed in `self.data_cols`. If no - numerical elements are listed in `self.data_cols` for a given categorical column, the according - index in this output will be `np.NaN`. - 4. ``dynamic_measurement_indices`` captures which measurement vocabulary was used to source a given - data element. - 5. ``static_indices`` captures the categorical metadata elements listed in `self.static_cols` in a - unified vocabulary. - 6. ``static_measurement_indices`` captures which measurement vocabulary was used to source a given - data element. - """ - return self._seeded_getitem(idx) - - @SeedableMixin.WithSeed - def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: - """Returns a Returns a dictionary corresponding to a single subject's data. - - This function is a seedable version of `__getitem__`. - """ - - patient_id, st, end = self.index[idx] - - shard = self.subj_map[patient_id] - patient_idx = self.subj_indices[patient_id] - static_row = self.static_dfs[shard][patient_idx].to_dict() - - out = { - "static_indices": static_row["static_indices"].item().to_list(), - "static_measurement_indices": static_row["static_measurement_indices"].item().to_list(), - } - - if self.config.do_include_patient_id: - out["patient_id"] = patient_id - - seq_len = end - st - if seq_len > self.max_seq_len: - match self.config.subsequence_sampling_strategy: - case SubsequenceSamplingStrategy.RANDOM: - start_offset = np.random.choice(seq_len - self.max_seq_len) - case SubsequenceSamplingStrategy.TO_END: - start_offset = seq_len - self.max_seq_len - case SubsequenceSamplingStrategy.FROM_START: - start_offset = 0 - case _: - raise ValueError( - f"Invalid subsequence sampling strategy {self.config.subsequence_sampling_strategy}!" 
- ) - - st += start_offset - end = min(end, st + self.max_seq_len) - - if self.config.do_include_subsequence_indices: - out["start_idx"] = st - out["end_idx"] = end - - out["dynamic"] = JointNestedRaggedTensorDict.load_slice( - self.config.tensorized_root / f"{shard}.pt", patient_idx - )[st:end] - - if self.config.do_include_start_time_min: - out["start_time"] = static_row["start_time"] = static_row[ - "start_time" - ].item().timestamp() / 60.0 + sum(static_row["time_delta"].item().to_list()[:st]) - - for t, t_labels in self.labels.items(): - out[t] = t_labels[idx] - - return out - - def __dynamic_only_collate(self, batch: list[dict[str, list[float]]]) -> dict: - """An internal collate function for only dynamic data.""" - keys = batch[0].keys() - dense_keys = {k for k in keys if k not in ("dynamic", "static_indices", "static_measurement_indices")} - - if dense_keys: - dense_collated = torch.utils.data.default_collate([{k: x[k] for k in dense_keys} for x in batch]) - else: - dense_collated = {} - - dynamic = JointNestedRaggedTensorDict.vstack([x["dynamic"] for x in batch]).to_dense( - padding_side=self.config.seq_padding_side - ) - dynamic["event_mask"] = dynamic.pop("dim1/mask") - dynamic["dynamic_values_mask"] = dynamic.pop("dim2/mask") & ~np.isnan(dynamic["dynamic_values"]) - - dynamic_collated = {} - for k, v in dynamic.items(): - if k.endswith("mask"): - dynamic_collated[k] = torch.from_numpy(v) - elif v.dtype in NP_UINT_TYPES + NP_INT_TYPES: - dynamic_collated[k] = torch.from_numpy(v.astype(int)).long() - elif v.dtype in NP_FLOAT_TYPES: - dynamic_collated[k] = torch.from_numpy(v.astype(float)).float() - else: - raise TypeError(f"Don't know how to tensorify {k} of type {v.dtype}!") - - collated = {**dense_collated, **dynamic_collated} - - out_batch = {} - out_batch["event_mask"] = collated["event_mask"] - out_batch["dynamic_values_mask"] = collated["dynamic_values_mask"] - out_batch["time_delta"] = torch.nan_to_num(collated["time_delta"].float(), nan=0) - out_batch["dynamic_indices"] = collated["dynamic_indices"].long() - out_batch["dynamic_measurement_indices"] = collated["dynamic_measurement_indices"].long() - out_batch["dynamic_values"] = torch.nan_to_num(collated["dynamic_values"].float(), nan=0) - - if self.config.do_include_start_time_min: - out_batch["start_time"] = collated["start_time"].float() - if self.config.do_include_subsequence_indices: - out_batch["start_idx"] = collated["start_idx"].long() - out_batch["end_idx"] = collated["end_idx"].long() - if self.config.do_include_patient_id: - out_batch["patient_id"] = collated["patient_id"].long() - - if not self.has_task: - return out_batch - - out_labels = {} - for task in self.tasks: - match self.task_types[task]: - case "multi_class_classification": - out_labels[task] = collated[task].long() - case "binary_classification": - out_labels[task] = collated[task].float() - case "regression": - out_labels[task] = collated[task].float() - case _: - raise TypeError(f"Don't know how to tensorify task of type {self.task_types[task]}!") - out_batch["supervised_labels"] = out_labels - - return out_batch - - def collate(self, batch: list[dict]) -> dict: - """Combines the ragged dictionaries produced by `__getitem__` into a tensorized batch. - - This function handles conversion of arrays to tensors and padding of elements within the batch across - static data elements, sequence events, and dynamic data elements. - - Args: - batch: A list of `__getitem__` format output dictionaries. 
- - Returns: - A fully collated, tensorized, and padded batch. - """ - - out_batch = self.__dynamic_only_collate(batch) - - max_n_static = max(len(x["static_indices"]) for x in batch) - static_padded_fields = defaultdict(list) - for e in batch: - n_static = len(e["static_indices"]) - static_delta = max_n_static - n_static - for k in ("static_indices", "static_measurement_indices"): - if static_delta > 0: - static_padded_fields[k].append( - torch.nn.functional.pad( - torch.tensor(e[k], dtype=torch.long), (0, static_delta), value=0 - ) - ) - else: - static_padded_fields[k].append(torch.tensor(e[k], dtype=torch.long)) - - for k, v in static_padded_fields.items(): - out_batch[k] = torch.cat([T.unsqueeze(0) for T in v], dim=0) - - return out_batch From 9c6d7c408265a526a8e5681fdcd45b301e5b5c9f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 09:35:18 -0400 Subject: [PATCH 39/53] Resolved test stochasticity --- src/MEDS_polars_functions/tensorize.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/MEDS_polars_functions/tensorize.py b/src/MEDS_polars_functions/tensorize.py index c0f3304..ad96c55 100644 --- a/src/MEDS_polars_functions/tensorize.py +++ b/src/MEDS_polars_functions/tensorize.py @@ -41,15 +41,18 @@ def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: │ 2 ┆ [NaN] ┆ [[201.0, 202.0]] ┆ [[6.0, 7.0]] │ └────────────┴─────────────────┴───────────────────────────┴─────────────────────┘ >>> nrt = convert_to_NRT(df.lazy()) - >>> for k, v in nrt.to_dense().items(): + >>> for k, v in sorted(list(nrt.to_dense().items())): ... print(k) ... print(v) + code + [[[101. 102.] + [103. 0.]] + + [[201. 202.] + [ 0. 0.]]] dim1/mask [[ True True] [ True False]] - time_delta_days - [[nan 12.] - [nan 0.]] dim2/mask [[[ True True] [ True False]] @@ -62,12 +65,9 @@ def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict: [[6. 7.] [0. 0.]]] - code - [[[101. 102.] - [103. 0.]] - - [[201. 202.] - [ 0. 0.]]] + time_delta_days + [[nan 12.] + [nan 0.]] """ # There should only be one time delta column, but this ensures we catch it regardless of the unit of time From 605c651e961bbc8bb7cfcb0a4eb35fab18a9552b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 09:44:18 -0400 Subject: [PATCH 40/53] removing unneeded dependency. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 142779c..25b9527 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "ml-mixins"] +dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy"] [project.optional-dependencies] examples = ["rootutils"] From eb9a176ff943fff8eb52129176187e368411216c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 10:57:58 -0400 Subject: [PATCH 41/53] removing unused config --- configs/pytorch_dataset.yaml | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 configs/pytorch_dataset.yaml diff --git a/configs/pytorch_dataset.yaml b/configs/pytorch_dataset.yaml deleted file mode 100644 index da76807..0000000 --- a/configs/pytorch_dataset.yaml +++ /dev/null @@ -1,21 +0,0 @@ -MEDS_cohort_dir: ??? 
-task_name: null - -subsequence_sampling_strategy: random -seq_padding_side: right - -train_subset_size: null -train_subset_seed: 1 - -min_seq_len: null -max_seq_len: ??? - -do_include_patient_id: false -do_include_subsequence_indices: false -do_include_start_time_min: false - -code_metadata_fp: ${MEDS_cohort_dir}/code_metadata.parquet -split_shards_fp: ${MEDS_cohort_dir}/splits.json -schema_files_root: ${MEDS_cohort_dir}/tokenize/schemas -tasks_root: ${MEDS_cohort_dir}/tasks -tensorized_root: ${MEDS_cohort_dir}/tensorized From 13e99dda139cf709e1974c44c228cc02875988ac Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 11:57:22 -0400 Subject: [PATCH 42/53] Added (yet untested) tokenization and tensorization scripts --- scripts/preprocessing/tensorize.py | 59 +++++++++++++++++++++++ scripts/preprocessing/tokenize.py | 75 ++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100755 scripts/preprocessing/tensorize.py create mode 100755 scripts/preprocessing/tokenize.py diff --git a/scripts/preprocessing/tensorize.py b/scripts/preprocessing/tensorize.py new file mode 100755 index 0000000..66d2005 --- /dev/null +++ b/scripts/preprocessing/tensorize.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python + +import json +import random +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.tensorize import tensorize +from MEDS_polars_functions.utils import hydra_loguru_init + + +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + for sp in patient_splits: + in_fp = input_dir / "event_seqs" / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.parquet" + + logger.info(f"Tensorizing {str(in_fp.resolve())} into {str(out_fp.resolve())}") + + rwlock_wrap( + in_fp, + out_fp, + pl.scan_parquet, + JointNestedRaggedTensorDict.save, + tensorize, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocessing/tokenize.py b/scripts/preprocessing/tokenize.py new file mode 100755 index 0000000..0a7cf9a --- /dev/null +++ b/scripts/preprocessing/tokenize.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python + +import json +import random +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.tokenize import ( + extract_seq_of_patient_events, + extract_statics_and_schema, +) +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe + + +@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with 
config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + for sp in patient_splits: + in_fp = input_dir / f"{sp}.parquet" + schema_out_fp = output_dir / "schemas" / f"{sp}.parquet" + event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet" + + logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + schema_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_statics_and_schema, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Tokenizing {str(in_fp.resolve())} into event_seqs at {str(event_seq_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + event_seq_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_seq_of_patient_events, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() From 1f616b38b1da6b41efde53a038cfe53948aa2006 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 12:04:25 -0400 Subject: [PATCH 43/53] Had to rename to avoid an import issue with hydra. --- configs/preprocess.yaml | 2 +- scripts/preprocessing/{tokenize.py => tokenization.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename scripts/preprocessing/{tokenize.py => tokenization.py} (100%) diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index bf925a1..b42a221 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -17,7 +17,7 @@ stages: - fit_normalization - fit_vocabulary_indices - normalize - - tokenize + - tokenization - tensorize # Pipeline Structure diff --git a/scripts/preprocessing/tokenize.py b/scripts/preprocessing/tokenization.py similarity index 100% rename from scripts/preprocessing/tokenize.py rename to scripts/preprocessing/tokenization.py From 6f703060c363a782c15f76aeaa457908f823192c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 17 Jun 2024 12:10:05 -0400 Subject: [PATCH 44/53] Change file extension for NRT files --- scripts/preprocessing/tensorize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/preprocessing/tensorize.py b/scripts/preprocessing/tensorize.py index 66d2005..5ebd2f7 100755 --- a/scripts/preprocessing/tensorize.py +++ b/scripts/preprocessing/tensorize.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig, OmegaConf from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.tensorize import tensorize +from MEDS_polars_functions.tensorize import convert_to_NRT from MEDS_polars_functions.utils import hydra_loguru_init @@ -37,7 +37,7 @@ def main(cfg: DictConfig): for sp in patient_splits: in_fp = input_dir / "event_seqs" / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" + out_fp = output_dir / f"{sp}.nrt" logger.info(f"Tensorizing {str(in_fp.resolve())} into {str(out_fp.resolve())}") @@ -46,7 +46,7 @@ def main(cfg: DictConfig): out_fp, pl.scan_parquet, JointNestedRaggedTensorDict.save, - tensorize, + convert_to_NRT, do_return=False, cache_intermediate=False, do_overwrite=cfg.do_overwrite, From dd78d1fe97df4e51bb861f88f8d24239facd7bde Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 
2024 15:06:08 -0400 Subject: [PATCH 45/53] Fixed a test that fails on numpy 2.0 --- src/MEDS_polars_functions/code_metadata.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/MEDS_polars_functions/code_metadata.py b/src/MEDS_polars_functions/code_metadata.py index 6e7f99f..19a079c 100644 --- a/src/MEDS_polars_functions/code_metadata.py +++ b/src/MEDS_polars_functions/code_metadata.py @@ -221,10 +221,10 @@ def mapper_fntr( >>> import numpy as np >>> df = pl.DataFrame({ ... "code": pl.Series(["A", "B", "A", "B", "C", "A", "C", "B", "D"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], - ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], - ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], - ... "numerical_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, np.NaN, None], + ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], + ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], + ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], + ... "numerical_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, float('nan'), None], ... }) >>> df shape: (9, 5) From 42eb4a613814d1ebe54083d338b606d88bc5bf5f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 15:20:46 -0400 Subject: [PATCH 46/53] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 25b9527..c50f1ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-c [project.optional-dependencies] examples = ["rootutils"] dev = ["pre-commit"] -tests = ["pytest", "pytest-cov[toml]", "rootutils"] +tests = ["pytest", "pytest-cov", "rootutils"] local_parallelism = ["hydra-joblib-launcher"] slurm_parallelism = ["hydra-submitit-launcher"] From f3a9edb805219f3ced28f553a793f6a933686c26 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 15:23:01 -0400 Subject: [PATCH 47/53] Try to correct github lint issue. --- MIMIC-IV_Example/README.md | 64 +++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 0f9ffed..ee83eb8 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -85,9 +85,9 @@ This is a step in 4 parts: ```bash ./scripts/extraction/shard_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. @@ -96,9 +96,9 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro ```bash ./scripts/extraction/split_and_shard_patients.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. 
@@ -107,9 +107,9 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ```bash ./scripts/extraction/convert_to_sharded_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to @@ -122,9 +122,9 @@ and performance is not necessary; however, for larger datasets, it can be. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` ### Running Locally, in Parallel. @@ -143,17 +143,17 @@ to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the n ```bash ./scripts/extraction/shard_events.py \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.name="${hydra.job.name}_${worker}" \ - hydra.launcher.partition="short" \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.name="${hydra.job.name}_${worker}" \ + hydra.launcher.partition="short" \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. @@ -162,9 +162,9 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro ```bash ./scripts/extraction/split_and_shard_patients.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. @@ -173,9 +173,9 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ```bash ./scripts/extraction/convert_to_sharded_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to @@ -188,9 +188,9 @@ and performance is not necessary; however, for larger datasets, it can be. 
```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` 5. (Optional) Generate preliminary code statistics and merge to external metadata. From 38c1d78937da6e8f94876052937949e89513d9ba Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 15:32:56 -0400 Subject: [PATCH 48/53] Updated the MIMIC README and removed the troublesome portions. --- MIMIC-IV_Example/README.md | 137 +++++-------------------- MIMIC-IV_Example/joint_script_slurm.sh | 98 +++++++++--------- 2 files changed, 73 insertions(+), 162 deletions(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index ee83eb8..4d626da 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -45,13 +45,20 @@ that page. You will need the raw `.csv.gz` files for this example. We will use ` the root directory of where the resulting _core data files_ are stored -- e.g., there should be a `hosp` and `icu` subdirectory of `$MIMICIV_RAW_DIR`. -## Step 2: Get the data ready for base MEDS extraction +## Step 2: Run the basic MEDS ETL + +This step contains several sub-steps; luckily, all these substeps can be run via a single script, with the +`joint_script.sh` script. This script entails several steps: + +### Step 2.1: Get the data ready for base MEDS extraction This is a step in a few parts: 1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In particular, we need to join: - - TODO + - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each + `hadm_id`. + - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. 2. Convert the patient's static data to a more parseable form. This entails: - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. @@ -61,7 +68,8 @@ After these steps, modified files or symlinks to the original files will be writ will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this directory. -To run this step, you can use the following script (assumed to be run **not** from this directory but from the +This step is run in the `joint_script.sh` script or the `joint_script_slurm.sh` script, but in either case the +base command that is run is as follows (assumed to be run **not** from this directory but from the root directory of this repository): ```bash @@ -70,9 +78,7 @@ root directory of this repository): In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -## Step 3: Run the MEDS extraction ETL - -### Running locally, serially +### Step 2.2: Run the MEDS extraction ETL We will assume you want to output the final MEDS dataset into a directory we'll denote as `$MIMICIV_MEDS_DIR`. Note this is a different directory than the pre-MEDS directory (though, of course, they can both be @@ -83,117 +89,22 @@ This is a step in 4 parts: 1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. 
-```bash -./scripts/extraction/shard_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. - -2. Extract and form the patient splits and sub-shards. - -```bash -./scripts/extraction/split_and_shard_patients.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. - -3. Extract patient sub-shards and convert to MEDS events. - -```bash -./scripts/extraction/convert_to_sharded_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to -cut the time down by a factor of the number of workers processing the data by simply running the command -multiple times (though this will, of course, consume more resources). If your filesystem is distributed, these -commands can also be launched as separate slurm jobs, for example. For MIMIC-IV, this level of parallelization -and performance is not necessary; however, for larger datasets, it can be. - -4. Merge the MEDS events into a single file per patient sub-shard. - -```bash -./scripts/extraction/merge_to_MEDS_cohort.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -### Running Locally, in Parallel. - -This step is the exact same commands as above, but leverages Hydra's multirun capabilities with the `joblib` -launcher. Install this package with the optional `local_parallelism` option (e.g., `pip install -e .[local_parallelism]` and run `./MIMIC-IV_Example/joint_script.sh`. See that script for expected args. - -### Running Each Step over Slurm + This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected + format of the command. -To use slurm, run each command with the number of workers desired using Hydra's multirun capabilities with the -`submitit_slurm` launcher. Install this package with the optional `slurm_parallelism` option. See below for -modified commands. Note these can't be chained in a single script as the jobs will not wait for all slurm jobs -to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the number of desired workers +2. Extract and form the patient splits and sub-shards. The `./scripts/extraction/split_and_shard_patients.py` + script is used for this step. See `joint_script*.sh` for the expected format of the command. -1. Sub-shard the raw files. - -```bash -./scripts/extraction/shard_events.py \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=submitit_slurm \ - hydra.launcher.timeout_min=60 \ - hydra.launcher.cpus_per_task=10 \ - hydra.launcher.mem_gb=50 \ - hydra.launcher.name="${hydra.job.name}_${worker}" \ - hydra.launcher.partition="short" \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. - -2. 
Extract and form the patient splits and sub-shards. - -```bash -./scripts/extraction/split_and_shard_patients.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` +3. Extract patient sub-shards and convert to MEDS events. The + `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for + the expected format of the command. -In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. - -3. Extract patient sub-shards and convert to MEDS events. - -```bash -./scripts/extraction/convert_to_sharded_events.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` - -In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to -cut the time down by a factor of the number of workers processing the data by simply running the command -multiple times (though this will, of course, consume more resources). If your filesystem is distributed, these -commands can also be launched as separate slurm jobs, for example. For MIMIC-IV, this level of parallelization -and performance is not necessary; however, for larger datasets, it can be. - -4. Merge the MEDS events into a single file per patient sub-shard. - -```bash -./scripts/extraction/merge_to_MEDS_cohort.py \ - input_dir=$MIMICIV_PREMEDS_DIR \ - cohort_dir=$MIMICIV_MEDS_DIR \ - event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml -``` +4. Merge the MEDS events into a single file per patient sub-shard. The + `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the + expected format of the command. -5. (Optional) Generate preliminary code statistics and merge to external metadata. +5. (Optional) Generate preliminary code statistics and merge to external metadata. This is not performed + currently in the `joint_script*.sh` scripts. ## Pre-processing for a model diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh index feb7fd3..f088dfd 100755 --- a/MIMIC-IV_Example/joint_script_slurm.sh +++ b/MIMIC-IV_Example/joint_script_slurm.sh @@ -44,17 +44,17 @@ shift 4 # this doesn't fall back on running anything locally in a setting where only slurm worker nodes have # sufficient computational resources to run the actual jobs. -# echo "Running pre-MEDS conversion on one worker." -# ./MIMIC-IV_Example/pre_MEDS.py \ -# --multirun \ -# worker="range(0,1)" \ -# hydra/launcher=submitit_slurm \ -# hydra.launcher.timeout_min=60 \ -# hydra.launcher.cpus_per_task=10 \ -# hydra.launcher.mem_gb=50 \ -# hydra.launcher.partition="short" \ -# raw_cohort_dir="$MIMICIV_RAW_DIR" \ -# output_dir="$MIMICIV_PREMEDS_DIR" +echo "Running pre-MEDS conversion on one worker." +./MIMIC-IV_Example/pre_MEDS.py \ + --multirun \ + worker="range(0,1)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + raw_cohort_dir="$MIMICIV_RAW_DIR" \ + output_dir="$MIMICIV_PREMEDS_DIR" echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." @@ -72,41 +72,41 @@ echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ stage=shard_events -# echo "Splitting patients on one worker" -# ./scripts/extraction/split_and_shard_patients.py \ -# --multirun \ -# worker="range(0,1)" \ -# hydra/launcher=submitit_slurm \ -# hydra.launcher.timeout_min=60 \ -# hydra.launcher.cpus_per_task=10 \ -# hydra.launcher.mem_gb=50 \ -# hydra.launcher.partition="short" \ -# input_dir="$MIMICIV_PREMEDS_DIR" \ -# cohort_dir="$MIMICIV_MEDS_DIR" \ -# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" -# -# echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" -# ./scripts/extraction/convert_to_sharded_events.py \ -# --multirun \ -# worker="range(0,$N_PARALLEL_WORKERS)" \ -# hydra/launcher=submitit_slurm \ -# hydra.launcher.timeout_min=60 \ -# hydra.launcher.cpus_per_task=10 \ -# hydra.launcher.mem_gb=50 \ -# hydra.launcher.partition="short" \ -# input_dir="$MIMICIV_PREMEDS_DIR" \ -# cohort_dir="$MIMICIV_MEDS_DIR" \ -# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" -# -# echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" -# ./scripts/extraction/merge_to_MEDS_cohort.py \ -# --multirun \ -# worker="range(0,$N_PARALLEL_WORKERS)" \ -# hydra/launcher=submitit_slurm \ -# hydra.launcher.timeout_min=60 \ -# hydra.launcher.cpus_per_task=10 \ -# hydra.launcher.mem_gb=50 \ -# hydra.launcher.partition="short" \ -# input_dir="$MIMICIV_PREMEDS_DIR" \ -# cohort_dir="$MIMICIV_MEDS_DIR" \ -# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" +echo "Splitting patients on one worker" +./scripts/extraction/split_and_shard_patients.py \ + --multirun \ + worker="range(0,1)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" + +echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/convert_to_sharded_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" + +echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/merge_to_MEDS_cohort.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" From f368bbefeb56113f15ca5d7bda5258c1d2711f53 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 15:34:19 -0400 Subject: [PATCH 49/53] Updated the MIMIC README and removed the troublesome portions. 
--- MIMIC-IV_Example/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 4d626da..66a6005 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -48,7 +48,11 @@ the root directory of where the resulting _core data files_ are stored -- e.g., ## Step 2: Run the basic MEDS ETL This step contains several sub-steps; luckily, all these substeps can be run via a single script, with the -`joint_script.sh` script. This script entails several steps: +`joint_script.sh` script which uses the Hydra `joblib` launcher to run things with local parallelism (make +sure you enable this feature by including the `[local_parallelism]` option during installation) or via +`joint_script_slurm.sh` which uses the Hydra `submitit` launcher to run things through slurm (make sure you +enable this feature by including the `[slurm_parallelism]` option during installation). This script entails +several steps: ### Step 2.1: Get the data ready for base MEDS extraction From e70d26f0235255ee810104f1014ff512385bd934 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 16:59:22 -0400 Subject: [PATCH 50/53] Checked the other shards for #23 --- tests/test_extraction.py | 77 +++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 12 deletions(-) diff --git a/tests/test_extraction.py b/tests/test_extraction.py index d221a30..b5bb85e 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -147,31 +147,87 @@ def get_expected_output(df: str) -> pl.DataFrame: 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ +MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ +patient_id,timestamp,code,numerical_value +68729,,EYE_COLOR//HAZEL, +68729,,HEIGHT,160.3953106166676 +68729,"03/09/1978, 00:00:00",DOB, +814703,,EYE_COLOR//HAZEL, +814703,,HEIGHT,156.48559093209357 +814703,"03/28/1976, 00:00:00",DOB, +""" + +MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ +patient_id,timestamp,code,numerical_value +68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, +68729,"05/26/2010, 02:30:56",HR,86.0 +68729,"05/26/2010, 02:30:56",TEMP,97.8 +68729,"05/26/2010, 04:51:52",DISCHARGE, +814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, +814703,"02/05/2010, 05:55:39",HR,170.2 +814703,"02/05/2010, 05:55:39",TEMP,100.1 +814703,"02/05/2010, 07:02:30",DISCHARGE, +""" + +MEDS_OUTPUT_TUNING_0_SUBJECTS = """ +patient_id,timestamp,code,numerical_value +754281,,EYE_COLOR//BROWN, +754281,,HEIGHT,166.22261567137025 +754281,"12/19/1988, 00:00:00",DOB, +""" + +MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ +patient_id,timestamp,code,numerical_value +754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, +754281,"01/03/2010, 06:27:59",HR,142.0 +754281,"01/03/2010, 06:27:59",TEMP,99.8 +754281,"01/03/2010, 08:22:13",DISCHARGE, +""" + +MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ +patient_id,timestamp,code,numerical_value +1500733,,EYE_COLOR//BROWN, +1500733,,HEIGHT,158.60131573580904 +1500733,"07/20/1986, 00:00:00",DOB, +""" + +MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ +patient_id,timestamp,code,numerical_value +1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, +1500733,"06/03/2010, 14:54:38",HR,91.4 +1500733,"06/03/2010, 14:54:38",TEMP,100.0 +1500733,"06/03/2010, 15:39:49",HR,84.4 +1500733,"06/03/2010, 15:39:49",TEMP,100.3 +1500733,"06/03/2010, 16:20:49",HR,90.1 +1500733,"06/03/2010, 16:20:49",TEMP,100.1 +1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" + SUB_SHARDED_OUTPUTS = { "train/0": { "subjects": MEDS_OUTPUT_TRAIN_0_SUBJECTS, "admit_vitals": 
MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS, }, "train/1": { - "subjects": None, - "admit_vitals": None, + "subjects": MEDS_OUTPUT_TRAIN_1_SUBJECTS, + "admit_vitals": MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS, }, "tuning/0": { - "subjects": None, - "admit_vitals": None, + "subjects": MEDS_OUTPUT_TUNING_0_SUBJECTS, + "admit_vitals": MEDS_OUTPUT_TUNING_0_ADMIT_VITALS, }, "held_out/0": { - "subjects": None, - "admit_vitals": None, + "subjects": MEDS_OUTPUT_HELD_OUT_0_SUBJECTS, + "admit_vitals": MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS, }, } MEDS_OUTPUTS = { "train/0": [MEDS_OUTPUT_TRAIN_0_SUBJECTS, MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS], - "train/1": None, - "tuning/0": None, - "held_out/0": None, + "train/1": [MEDS_OUTPUT_TRAIN_1_SUBJECTS, MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS], + "tuning/0": [MEDS_OUTPUT_TUNING_0_SUBJECTS, MEDS_OUTPUT_TUNING_0_ADMIT_VITALS], + "held_out/0": [MEDS_OUTPUT_HELD_OUT_0_SUBJECTS, MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS], } @@ -358,9 +414,6 @@ def test_extraction(): for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): - if expected_df_L is None: - continue - if not isinstance(expected_df_L, list): expected_df_L = [expected_df_L] From 7150d5ea85760fd23d96988a649e2048a9295bc3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 19 Jun 2024 22:03:42 -0400 Subject: [PATCH 51/53] Added code metadata checking to the test. --- tests/test_extraction.py | 41 +++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/tests/test_extraction.py b/tests/test_extraction.py index b5bb85e..75a7404 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -11,7 +11,6 @@ from pathlib import Path import polars as pl -from loguru import logger from polars.testing import assert_frame_equal pl.enable_string_cache() @@ -203,6 +202,22 @@ def get_expected_output(df: str) -> pl.DataFrame: 1500733,"06/03/2010, 16:44:26",DISCHARGE, """ +MEDS_OUTPUT_CODE_METADATA_FILE = """ +code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd +,44,4,28,3198.8389005974336,382968.28937288234 +ADMISSION//CARDIAC,2,2,0,, +ADMISSION//ORTHOPEDIC,1,1,0,, +ADMISSION//PULMONARY,1,1,0,, +DISCHARGE,4,4,0,, +DOB,4,4,0,, +EYE_COLOR//BLUE,1,1,0,, +EYE_COLOR//BROWN,1,1,0,, +EYE_COLOR//HAZEL,2,2,0,, +HEIGHT,4,4,4,656.8389005974336,108056.12937288235 +HR,12,4,12,1360.5000000000002,158538.77 +TEMP,12,4,12,1181.4999999999998,116373.38999999998 +""" + SUB_SHARDED_OUTPUTS = { "train/0": { "subjects": MEDS_OUTPUT_TRAIN_0_SUBJECTS, @@ -454,9 +469,6 @@ def test_extraction(): output_folder = MEDS_cohort_dir / "final_cohort" try: for split, expected_df_L in MEDS_OUTPUTS.items(): - if expected_df_L is None: - continue - if not isinstance(expected_df_L, list): expected_df_L = [expected_df_L] @@ -487,8 +499,6 @@ def test_extraction(): print(f"stdout:\n{full_stdout}") raise e - logger.warning("Only checked the train/0 split for now. 
TODO: add the rest of the splits.") - # Step 4: Merge to the final output stderr, stdout = run_command( extraction_root / "collect_code_metadata.py", @@ -504,4 +514,21 @@ def test_extraction(): output_file = MEDS_cohort_dir / "code_metadata.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" - logger.warning("Didn't check contents of code metadata!") + got_df = pl.read_parquet(output_file, glob=False) + + want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( + pl.col("code").cast(pl.Categorical), + pl.col("code/n_occurrences").cast(pl.UInt32), + pl.col("code/n_patients").cast(pl.UInt32), + pl.col("values/n_occurrences").cast(pl.UInt32), + pl.col("values/sum").cast(pl.Float64).fill_null(0), + pl.col("values/sum_sqd").cast(pl.Float64).fill_null(0), + ) + + assert_df_equal( + want=want_df, + got=got_df, + msg="Code metadata differs!", + check_column_order=False, + check_row_order=False, + ) From 1ec3934e5215e41a8e638fe572ed7c71406e6a85 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 21 Jun 2024 07:12:38 -0400 Subject: [PATCH 52/53] Re-arranged import statements --- src/MEDS_polars_functions/filter_measurements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filter_measurements.py index b3deeb2..a6abc18 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filter_measurements.py @@ -3,9 +3,9 @@ from collections.abc import Callable import polars as pl +from omegaconf import DictConfig pl.enable_string_cache() -from omegaconf import DictConfig def filter_codes_fntr( From 410e6ce4d4024e081fcd0987e28b3f191bdbf27f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 21 Jun 2024 07:15:15 -0400 Subject: [PATCH 53/53] Updated workflow --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 5041078..908adc5 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,7 +33,7 @@ jobs: #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov + pytest -v --doctest-modules --cov=src -s - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1
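
A closing note on the code-metadata check added in PATCH 51 above: the expected `code_metadata.parquet` carries, per code, the columns `code/n_occurrences`, `code/n_patients`, `values/n_occurrences`, `values/sum`, and `values/sum_sqd`. The snippet below is a minimal, hypothetical Polars sketch of what those columns mean — it is not the repository's `code_metadata.py` implementation. The frame `df`, its toy values, and the choice to treat NaN like null are assumptions made purely for illustration, and the empty-code row present in the expected CSV is not reproduced here.

```python
# Hypothetical sketch only -- not the repository's code_metadata.py implementation.
# Illustrates the meaning of the aggregation columns checked by
# MEDS_OUTPUT_CODE_METADATA_FILE in tests/test_extraction.py.
import polars as pl

# A toy MEDS-style measurements frame (values are illustrative, not from the test data).
df = pl.DataFrame(
    {
        "patient_id": [1, 1, 2, 2, 3],
        "code": ["HR", "HR", "HR", "TEMP", "TEMP"],
        "numerical_value": [86.0, 170.2, 142.0, 97.8, None],
    }
)

# Assumption: NaNs are treated like nulls here, mirroring the float('nan') handling
# exercised in the code_metadata.py doctest fixed in PATCH 45.
val = pl.col("numerical_value").fill_nan(None)

code_metadata = (
    df.group_by("code")
    .agg(
        pl.len().alias("code/n_occurrences"),                       # rows per code
        pl.col("patient_id").n_unique().alias("code/n_patients"),   # distinct patients per code
        val.is_not_null().sum().alias("values/n_occurrences"),      # rows carrying a numerical value
        val.sum().alias("values/sum"),
        (val**2).sum().alias("values/sum_sqd"),
    )
    .sort("code")
)
print(code_metadata)
```

Read this way, `values/sum` and `values/sum_sqd` are running sufficient statistics from which per-code means and variances can later be derived (e.g., for the `fit_normalization` stage listed in `configs/preprocess.yaml`); the test itself casts these columns to `Float64` and fills nulls with 0 before comparing against the expected frame.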