diff --git a/README.md b/README.md index fbed6d3..6680f37 100644 --- a/README.md +++ b/README.md @@ -420,3 +420,14 @@ To use either of these, you need to install additional optional dependencies: files. Similar to task configuration files, but for models. 2. Figure out how to ensure that each pre-processing step reads from the right prior files. Likely need some kind of a "prior stage name" config variable. + +## Notes: + +You can overwrite the `stages` parameter on the command line to run a dynamic pipeline with just a subset of +options (the `--cfg job --resolve` is just to make hydra show the induced, resolved config instead of trying +to run anything): + +```bash +MEDS_polars_functions on  reusable_interface [$⇡] is 󰏗 v0.0.1 via  v3.12.4 via  MEDS_fns +❯ ./src/MEDS_polars_functions/scripts/preprocessing/normalize.py input_dir=foo cohort_dir=bar 'stages=["normalize", "tensorize"]' --cfg job --resolve +``` diff --git a/configs/extraction.yaml b/configs/extraction.yaml deleted file mode 100644 index a0939e8..0000000 --- a/configs/extraction.yaml +++ /dev/null @@ -1,89 +0,0 @@ -defaults: - - pipeline - - _self_ - -description: |- - This pipeline extracts raw MEDS events in longitudinal, sparse form from an input dataset meeting select - criteria and converts them to the flattened, MEDS format. It can be run in its entirety, with controllable - levels of parallelism, or in stages. Arguments: - - `event_conversion_config_fp`: The path to the event conversion configuration file. This file defines - the events to extract from the various rows of the various input files encountered in the global input - directory. - - `input_dir`: The path to the directory containing the raw input files. - - `cohort_dir`: The path to the directory where the output cohort will be written. It will be written in - various subfolders of this dir depending on the stage, as intermediate stages cache their output during - computation for efficiency of re-running and distributing. - -# The event conversion configuration file is used throughout the pipeline to define the events to extract. -event_conversion_config_fp: ??? -# The code modifier columns are in this pipeline only used in the collect_code_metadata stage. -code_modifiers: null - -stages: - - shard_events - - split_and_shard_patients - - convert_to_sharded_events - - merge_to_MEDS_cohort - - collect_code_metadata - -stage_configs: - shard_events: - description: |- - This stage shards the raw input events into smaller files for easier processing. Arguments: - - `row_chunksize`: The number of rows to read in at a time. - - `infer_schema_length`: The number of rows to read in to infer the schema (only used if the source - files are csvs) - row_chunksize: 200000000 - infer_schema_length: 10000 - - split_and_shard_patients: - description: |- - This stage splits the patients into training, tuning, and held-out sets, and further splits those sets - into shards. Arguments: - - `n_patients_per_shard`: The number of patients to include in a shard. - - `external_splits_json_fp`: The path to a json file containing any pre-defined splits for specially - held-out test sets beyond the IID held out set that will be produced (e.g., for prospective - datasets, etc.). - - `split_fracs`: The fraction of patients to include in the IID training, tuning, and held-out sets. - Split fractions can be changed for the default names by adding a hydra-syntax command line argument - for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1 split_fracs.held_out=0.2`. 
- A split can be removed with the `~` override Hydra syntax. Similarly, a new split name can be added - with the standard Hydra `+` override option. E.g., `~split_fracs.held_out +split_fracs.test=0.1`. It - is the user's responsibility to ensure that split fractions sum to 1. - is_metadata: True - output_dir: ${cohort_dir} - n_patients_per_shard: 50000 - external_splits_json_fp: null - split_fracs: - train: 0.8 - tuning: 0.1 - held_out: 0.1 - - merge_to_MEDS_cohort: - description: |- - This stage splits the patients into training, tuning, and held-out sets, and further splits those sets - into shards. Arguments: - - `n_patients_per_shard`: The number of patients to include in a shard. - - `external_splits_json_fp`: The path to a json file containing any pre-defined splits for specially - held-out test sets beyond the IID held out set that will be produced (e.g., for prospective - datasets, etc.). - - `split_fracs`: The fraction of patients to include in the IID training, tuning, and held-out sets. - output_dir: ${cohort_dir}/final_cohort - unique_by: "*" - - collect_code_metadata: - description: |- - This stage collects some descriptive metadata about the codes in the cohort. Arguments: - - `aggregations`: The aggregations to compute over the codes. Defaults to counts of code occurrences, - counts of patients with the code, and counts of value occurrences per code, as well as the sum and - sum of squares of values (for use in computing means and variances). - aggregations: - - "code/n_occurrences" - - "code/n_patients" - - "values/n_occurrences" - - "values/sum" - - "values/sum_sqd" - do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts - is_metadata: True - mapper_output_dir: "${cohort_dir}/code_metadata" - output_dir: "${cohort_dir}" diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml deleted file mode 100644 index b42a221..0000000 --- a/configs/preprocess.yaml +++ /dev/null @@ -1,69 +0,0 @@ -defaults: - - pipeline - - _self_ - -# Global pipeline parameters: -# 1. Code modifiers will be used as adjoining parts of the `code` columns during group-bys and eventual -# tokenization. -code_modifier_columns: ??? - -stages: - - filter_patients - - add_time_derived_measurements - - preliminary_counts - - filter_codes - - fit_outlier_detection - - filter_outliers - - fit_normalization - - fit_vocabulary_indices - - normalize - - tokenization - - tensorize - -# Pipeline Structure -stage_configs: - filter_patients: - min_events_per_patient: null - min_measurements_per_patient: null - data_input_dir: ${input_dir}/final_cohort - - add_time_derived_measurements: - age: - DOB_code: ??? 
- age_code: "AGE" - age_unit: "years" - time_of_day: - time_of_day_code: "TIME_OF_DAY" - endpoints: [6, 12, 18, 24] - - preliminary_counts: - aggregations: - - "code/n_occurrences" - - "code/n_patients" - do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts - - filter_codes: - min_patients_per_code: null - min_occurrences_per_code: null - - fit_outlier_detection: - aggregations: - - "values/n_occurrences" - - "values/sum" - - "values/sum_sqd" - - filter_outliers: - stddev_cutoff: 4.5 - - fit_normalization: - aggregations: - - "code/n_occurrences" - - "code/n_patients" - - "values/n_occurrences" - - "values/sum" - - "values/sum_sqd" - - fit_vocabulary_indices: - is_metadata: true - ordering_method: "lexicographic" - output_dir: "${cohort_dir}" diff --git a/pyproject.toml b/pyproject.toml index c50f1ba..941f74b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0"] +requires = ["setuptools>=61.0", "setuptools-scm>=8.0", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -16,7 +16,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy"] +dependencies = ["polars>=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy"] [project.optional-dependencies] examples = ["rootutils"] @@ -25,6 +25,13 @@ tests = ["pytest", "pytest-cov", "rootutils"] local_parallelism = ["hydra-joblib-launcher"] slurm_parallelism = ["hydra-submitit-launcher"] +[project.scripts] +MEDS_extract-split_and_shard_patients = "MEDS_polars_functions.extraction.split_and_shard_patients:main" +MEDS_extract-shard_events = "MEDS_polars_functions.extraction.shard_events:main" +MEDS_extract-convert_to_sharded_events = "MEDS_polars_functions.extraction.convert_to_sharded_events:main" +MEDS_extract-merge_to_MEDS_cohort = "MEDS_polars_functions.extraction.merge_to_MEDS_cohort:main" +MEDS_transform-aggregate_code_metadata = "MEDS_polars_functions.aggregate_code_metadata:main" + [project.urls] Homepage = "https://github.com/mmcdermott/MEDS_polars_functions" Issues = "https://github.com/mmcdermott/MEDS_polars_functions/issues" diff --git a/scripts/extraction/collect_code_metadata.py b/scripts/extraction/collect_code_metadata.py deleted file mode 100755 index 9fb020c..0000000 --- a/scripts/extraction/collect_code_metadata.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python - -import json -import random -import time -from datetime import datetime -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.code_metadata import mapper_fntr, reducer_fntr -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") -def main(cfg: DictConfig): - """Computes code metadata.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - mapper_output_dir = Path(cfg.stage_cfg.mapper_output_dir) - - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) - 
- examine_splits = [f"{sp}/" for sp in cfg.stage_cfg.get("examine_splits", ["train"])] - logger.info(f"Computing metadata over shards with any prefix in {examine_splits}") - shards = {k: v for k, v in shards.items() if any(k.startswith(prefix) for prefix in examine_splits)} - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - mapper_fn = mapper_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - - start = datetime.now() - logger.info("Starting code metadata mapping computation") - - all_out_fps = [] - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = mapper_output_dir / f"{sp}.parquet" - all_out_fps.append(out_fp) - - logger.info( - f"Computing code metadata for {str(in_fp.resolve())} and storing to {str(out_fp.resolve())}" - ) - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - mapper_fn, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Finished mapping in {datetime.now() - start}") - - if cfg.worker != 1: - return - - while not all(fp.is_file() for fp in all_out_fps): - logger.info("Waiting to begin reduction for all files to be written...") - time.sleep(cfg.polling_time) - - start = datetime.now() - logger.info("All map shards complete! Starting code metadata reduction computation.") - reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - - reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]) - write_lazyframe(reduced, output_dir / "code_metadata.parquet") - logger.info(f"Finished reduction in {datetime.now() - start}") - - -if __name__ == "__main__": - main() diff --git a/scripts/extraction/convert_to_sharded_events.py b/scripts/extraction/convert_to_sharded_events.py deleted file mode 100755 index bc1eff3..0000000 --- a/scripts/extraction/convert_to_sharded_events.py +++ /dev/null @@ -1,95 +0,0 @@ -#!/usr/bin/env python - -import copy -import json -import random -from functools import partial -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.event_conversion import convert_to_events -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") -def main(cfg: DictConfig): - """Converts the sub-sharded or raw data into events which are sharded by patient X input shard.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) - - event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) - if not event_conversion_cfg_fp.exists(): - raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") - - logger.info("Starting event conversion.") - - logger.info(f"Reading event conversion config from {event_conversion_cfg_fp}") - event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) - logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") - - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") - - patient_subsharded_dir = Path(cfg.stage_cfg.output_dir) - patient_subsharded_dir.mkdir(parents=True, exist_ok=True) 
- OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") - - patient_splits = list(shards.items()) - random.shuffle(patient_splits) - - event_configs = list(event_conversion_cfg.items()) - random.shuffle(event_configs) - - # Here, we'll be reading files directly, so we'll turn off globbing - read_fn = partial(pl.scan_parquet, glob=False) - - for sp, patients in patient_splits: - for input_prefix, event_cfgs in event_configs: - event_cfgs = copy.deepcopy(event_cfgs) - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) - - event_shards = list((Path(cfg.stage_cfg.data_input_dir) / input_prefix).glob("*.parquet")) - random.shuffle(event_shards) - - for shard_fp in event_shards: - out_fp = patient_subsharded_dir / sp / input_prefix / shard_fp.name - logger.info(f"Converting {shard_fp} to events and saving to {out_fp}") - - def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: - typed_patients = pl.Series(patients, dtype=df.schema[input_patient_id_column]) - - if input_patient_id_column != "patient_id": - df = df.rename({input_patient_id_column: "patient_id"}) - - try: - logger.info(f"Extracting events for {input_prefix}/{shard_fp.name}") - return convert_to_events( - df.filter(pl.col("patient_id").is_in(typed_patients)), - event_cfgs=copy.deepcopy(event_cfgs), - ) - except Exception as e: - raise ValueError( - f"Error converting {str(shard_fp.resolve())} for {sp}/{input_prefix}: {e}" - ) from e - - rwlock_wrap( - shard_fp, out_fp, read_fn, write_lazyframe, compute_fn, do_overwrite=cfg.do_overwrite - ) - - logger.info("Subsharded into converted events.") - - -if __name__ == "__main__": - main() diff --git a/scripts/extraction/merge_to_MEDS_cohort.py b/scripts/extraction/merge_to_MEDS_cohort.py deleted file mode 100755 index ade8d50..0000000 --- a/scripts/extraction/merge_to_MEDS_cohort.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from functools import partial -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init - -pl.enable_string_cache() - - -def read_fn(sp_dir: Path, unique_by: list[str] | str | None) -> pl.LazyFrame: - files_to_read = list(sp_dir.glob("**/*.parquet")) - - if not files_to_read: - raise FileNotFoundError(f"No files found in {sp_dir}/**/*.parquet.") - - file_strs = "\n".join(f" - {str(fp.resolve())}" for fp in files_to_read) - logger.info(f"Reading {len(files_to_read)} files:\n{file_strs}") - - dfs = [pl.scan_parquet(fp, glob=False) for fp in files_to_read] - df = pl.concat(dfs, how="diagonal_relaxed") - - match unique_by: - case None: - pass - case "*": - df = df.unique(maintain_order=False) - case list() if len(unique_by) == 0 and all(isinstance(u, str) for u in unique_by): - subset = [] - for u in unique_by: - if u in df.columns: - subset.append(u) - else: - logger.warning(f"Column {u} not found in dataframe. 
Omitting from unique-by subset.") - df = df.unique(maintain_order=False, subset=subset) - case _: - raise ValueError(f"Invalid unique_by value: {unique_by}") - - return df.sort(by=["patient_id", "timestamp"], multithreaded=False) - - -def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: - df.collect().write_parquet(out_fp, use_pyarrow=True) - - -def identity_fn(df: pl.LazyFrame) -> pl.LazyFrame: - return df - - -@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") -def main(cfg: DictConfig): - """Merges the patient sub-sharded events into a single parquet file per patient shard.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) - - logger.info("Starting patient shard merging.") - - patient_subsharded_dir = Path(cfg.stage_cfg.data_input_dir) - if not patient_subsharded_dir.is_dir(): - raise FileNotFoundError(f"Patient sub-sharded directory not found: {patient_subsharded_dir}") - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - reader = partial(read_fn, unique_by=cfg.stage_cfg.get("unique_by", None)) - - for sp in patient_splits: - in_dir = patient_subsharded_dir / sp - out_fp = Path(cfg.stage_cfg.output_dir) / f"{sp}.parquet" - - shard_fps = sorted(list(in_dir.glob("**/*.parquet"))) - shard_fp_strs = [f" * {str(fp.resolve())}" for fp in shard_fps] - logger.info(f"Merging {len(shard_fp_strs)} shards into {out_fp}:\n" + "\n".join(shard_fp_strs)) - rwlock_wrap(in_dir, out_fp, reader, write_fn, identity_fn, do_return=False) - - logger.info("Output cohort written.") - - -if __name__ == "__main__": - main() diff --git a/scripts/extraction/split_and_shard_patients.py b/scripts/extraction/split_and_shard_patients.py deleted file mode 100755 index 6cd03b7..0000000 --- a/scripts/extraction/split_and_shard_patients.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python - -import json -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.sharding import shard_patients -from MEDS_polars_functions.utils import hydra_loguru_init - - -@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") -def main(cfg: DictConfig): - """Extracts the set of unique patients from the raw data and splits/shards them and saves the result.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - logger.info("Starting patient splitting and sharding") - - MEDS_cohort_dir = Path(cfg.stage_cfg.output_dir) - subsharded_dir = Path(cfg.stage_cfg.data_input_dir) - - event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) - if not event_conversion_cfg_fp.exists(): - raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") - - logger.info( - f"Reading event conversion config from {event_conversion_cfg_fp} (needed for patient ID columns)" - ) - event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) - logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") - - dfs = [] - - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") - for input_prefix, event_cfgs in 
event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.get("patient_id_col", default_patient_id_col) - - input_fps = list((subsharded_dir / input_prefix).glob("**/*.parquet")) - - input_fps_strs = "\n".join(f" - {str(fp.resolve())}" for fp in input_fps) - logger.info(f"Reading patient IDs from {input_prefix} files:\n{input_fps_strs}") - - for input_fp in input_fps: - dfs.append( - pl.scan_parquet(input_fp, glob=False) - .select(pl.col(input_patient_id_column).alias("patient_id")) - .unique() - ) - - logger.info(f"Joining all patient IDs from {len(dfs)} dataframes") - patient_ids = ( - pl.concat(dfs) - .select(pl.col("patient_id").drop_nulls().drop_nans().unique()) - .collect(streaming=True)["patient_id"] - .to_numpy(use_pyarrow=True) - ) - - logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}") - - if cfg.stage_cfg.external_splits_json_fp: - external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp) - if not external_splits_json_fp.exists(): - raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}") - - logger.info(f"Reading external splits from {cfg.external_splits_json_fp}") - external_splits = json.loads(external_splits_json_fp.read_text()) - - size_strs = ", ".join(f"{k}: {len(v)}" for k, v in external_splits.items()) - logger.info(f"Loaded external splits of size: {size_strs}") - else: - external_splits = None - - logger.info("Sharding and splitting patients") - - sharded_patients = shard_patients( - patients=patient_ids, - external_splits=external_splits, - split_fracs_dict=cfg.stage_cfg.split_fracs, - n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, - seed=cfg.seed, - ) - - logger.info(f"Writing sharded patients to {MEDS_cohort_dir}") - MEDS_cohort_dir.mkdir(parents=True, exist_ok=True) - out_fp = MEDS_cohort_dir / "splits.json" - out_fp.write_text(json.dumps(sharded_patients)) - logger.info(f"Done writing sharded patients to {out_fp}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py deleted file mode 100755 index bd1dded..0000000 --- a/scripts/preprocessing/add_time_derived_measurements.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.time_derived_measurements import ( - add_new_events_fntr, - age_fntr, - time_of_day_fntr, -) -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - -INFERRED_STAGE_KEYS = {"is_metadata", "data_input_dir", "metadata_input_dir", "output_dir"} - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """Adds time-derived measurements to a MEDS cohort as separate observations at each unique timestamp.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) - input_dir = Path(cfg.stage_cfg.data_input_dir) - - logger.info(f"Reading data from {str(input_dir.resolve())}") - - patient_splits = list(shards.keys()) - 
random.shuffle(patient_splits) - - compute_fns = [] - # We use the raw stages object as the induced `stage_cfg` has extra properties like the input and output - # directories. - for feature_name, feature_cfg in cfg.stage_cfg.items(): - match feature_name: - case "age": - compute_fns.append(add_new_events_fntr(age_fntr(feature_cfg))) - case "time_of_day": - compute_fns.append(add_new_events_fntr(time_of_day_fntr(feature_cfg))) - case str() if feature_name in INFERRED_STAGE_KEYS: - continue - case _: - raise ValueError(f"Unknown time-derived measurement: {feature_name}") - - logger.info(f"Adding {feature_name} via config: {OmegaConf.to_yaml(feature_cfg)}") - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - - logger.info( - f"Adding time derived measurements to {str(in_fp.resolve())} into {str(out_fp.resolve())}" - ) - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - *compute_fns, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info("Added time-derived measurements.") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py deleted file mode 100755 index d69bbc8..0000000 --- a/scripts/preprocessing/collect_code_metadata.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/env python - -import json -import random -import time -from datetime import datetime -from pathlib import Path - -import hydra -import polars as pl -import polars.selectors as cs -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.code_metadata import mapper_fntr, reducer_fntr -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """Computes code metadata.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - examine_splits = [f"{sp}/" for sp in cfg.stage_cfg.get("examine_splits", ["train"])] - logger.info(f"Computing metadata over shards with any prefix in {examine_splits}") - shards = {k: v for k, v in shards.items() if any(k.startswith(prefix) for prefix in examine_splits)} - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - mapper_fn = mapper_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - - start = datetime.now() - logger.info("Starting code metadata mapping computation") - - all_out_fps = [] - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - all_out_fps.append(out_fp) - - logger.info( - f"Computing code metadata for {str(in_fp.resolve())} and storing to {str(out_fp.resolve())}" - ) - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - mapper_fn, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Finished mapping in {datetime.now() - start}") - - if cfg.worker != 0: - return - - while not all(fp.is_file() for fp in all_out_fps): - logger.info("Waiting to begin reduction for all files to be 
written...") - time.sleep(cfg.polling_time) - - start = datetime.now() - logger.info("All map shards complete! Starting code metadata reduction computation.") - reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) - - reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]).with_columns( - cs.numeric().shrink_dtype().keep_name() - ) - write_lazyframe(reduced, output_dir / "code_metadata.parquet") - logger.info(f"Finished reduction in {datetime.now() - start}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/filter_codes.py b/scripts/preprocessing/filter_codes.py deleted file mode 100755 index d086f5e..0000000 --- a/scripts/preprocessing/filter_codes.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.filter_measurements import filter_codes_fntr -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) - compute_fn = filter_codes_fntr(cfg.stage_cfg, code_metadata) - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - - logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - compute_fn, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/filter_outliers.py b/scripts/preprocessing/filter_outliers.py deleted file mode 100755 index 1643e07..0000000 --- a/scripts/preprocessing/filter_outliers.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.filter_measurements import filter_outliers_fntr -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / 
"splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) - compute_fn = filter_outliers_fntr(cfg.stage_cfg, code_metadata) - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - - logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - compute_fn, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py deleted file mode 100755 index 602ae20..0000000 --- a/scripts/preprocessing/filter_patients.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from functools import partial -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.filter_patients_by_length import ( - filter_patients_by_num_events, - filter_patients_by_num_measurements, -) -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - compute_fns = [] - if cfg.stage_cfg.min_measurements_per_patient: - logger.info( - f"Filtering patients with fewer than {cfg.stage_cfg.min_measurements_per_patient} measurements " - "(observations of any kind)." - ) - compute_fns.append( - partial( - filter_patients_by_num_measurements, - min_measurements_per_patient=cfg.stage_cfg.min_measurements_per_patient, - ) - ) - if cfg.stage_cfg.min_events_per_patient: - logger.info( - f"Filtering patients with fewer than {cfg.stage_cfg.min_events_per_patient} events " - "(unique timepoints)." 
- ) - compute_fns.append( - partial( - filter_patients_by_num_events, min_events_per_patient=cfg.stage_cfg.min_events_per_patient - ) - ) - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - - logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - *compute_fns, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info("Filtered patients.") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/fit_vocabulary_indices.py b/scripts/preprocessing/fit_vocabulary_indices.py deleted file mode 100755 index 73c56e7..0000000 --- a/scripts/preprocessing/fit_vocabulary_indices.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python - -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.get_vocabulary import ( - VOCABULARY_ORDERING, - VOCABULARY_ORDERING_METHODS, -) -from MEDS_polars_functions.utils import hydra_loguru_init - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True) - - ordering_method = cfg.stage_cfg.get("ordering_method", VOCABULARY_ORDERING.LEXICOGRAPHIC) - - if ordering_method not in VOCABULARY_ORDERING_METHODS: - raise ValueError( - f"Invalid ordering method: {ordering_method}. " - f"Expected one of {', '.join(VOCABULARY_ORDERING_METHODS.keys())}" - ) - - logger.info(f"Assigning code vocabulary indices via a {ordering_method} order.") - ordering_fn = VOCABULARY_ORDERING_METHODS[ordering_method] - - code_modifiers = cfg.get("code_modifier_columns", None) - if code_modifiers is None: - code_modifiers = [] - - code_metadata = ordering_fn(code_metadata, code_modifiers) - - output_fp = output_dir / "code_metadata.parquet" - logger.info(f"Indices assigned. 
Writing to {output_fp}") - - code_metadata.write_parquet(output_fp, use_pyarrow=True) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/normalize.py b/scripts/preprocessing/normalize.py deleted file mode 100755 index a80522c..0000000 --- a/scripts/preprocessing/normalize.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from functools import partial -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.normalization import normalize -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True).lazy() - code_modifiers = cfg.get("code_modifier_columns", None) - compute_fn = partial(normalize, code_metadata=code_metadata, code_modifiers=code_modifiers) - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.parquet" - - logger.info(f"Filtering {str(in_fp.resolve())} into {str(out_fp.resolve())}") - - rwlock_wrap( - in_fp, - out_fp, - pl.scan_parquet, - write_lazyframe, - compute_fn, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/tensorize.py b/scripts/preprocessing/tensorize.py deleted file mode 100755 index 5ebd2f7..0000000 --- a/scripts/preprocessing/tensorize.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.tensorize import convert_to_NRT -from MEDS_polars_functions.utils import hydra_loguru_init - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - for sp in patient_splits: - in_fp = input_dir / "event_seqs" / f"{sp}.parquet" - out_fp = output_dir / f"{sp}.nrt" - - logger.info(f"Tensorizing {str(in_fp.resolve())} into {str(out_fp.resolve())}") - - rwlock_wrap( - 
in_fp, - out_fp, - pl.scan_parquet, - JointNestedRaggedTensorDict.save, - convert_to_NRT, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/scripts/preprocessing/tokenization.py b/scripts/preprocessing/tokenization.py deleted file mode 100755 index 0a7cf9a..0000000 --- a/scripts/preprocessing/tokenization.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python - -import json -import random -from pathlib import Path - -import hydra -import polars as pl -from loguru import logger -from omegaconf import DictConfig, OmegaConf - -from MEDS_polars_functions.mapper import wrap as rwlock_wrap -from MEDS_polars_functions.tokenize import ( - extract_seq_of_patient_events, - extract_statics_and_schema, -) -from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe - - -@hydra.main(version_base=None, config_path="../../configs", config_name="preprocess") -def main(cfg: DictConfig): - """TODO.""" - - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) - - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) - - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) - - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - schema_out_fp = output_dir / "schemas" / f"{sp}.parquet" - event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet" - - logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") - - rwlock_wrap( - in_fp, - schema_out_fp, - pl.scan_parquet, - write_lazyframe, - extract_statics_and_schema, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Tokenizing {str(in_fp.resolve())} into event_seqs at {str(event_seq_out_fp.resolve())}") - - rwlock_wrap( - in_fp, - event_seq_out_fp, - pl.scan_parquet, - write_lazyframe, - extract_seq_of_patient_events, - do_return=False, - cache_intermediate=False, - do_overwrite=cfg.do_overwrite, - ) - - logger.info(f"Done with {cfg.stage}") - - -if __name__ == "__main__": - main() diff --git a/src/MEDS_polars_functions/code_metadata.py b/src/MEDS_polars_functions/aggregate_code_metadata.py old mode 100644 new mode 100755 similarity index 94% rename from src/MEDS_polars_functions/code_metadata.py rename to src/MEDS_polars_functions/aggregate_code_metadata.py index 19a079c..3fc287d --- a/src/MEDS_polars_functions/code_metadata.py +++ b/src/MEDS_polars_functions/aggregate_code_metadata.py @@ -1,13 +1,23 @@ +#!/usr/bin/env python """Utilities for grouping and/or reducing MEDS cohort files by code to collect metadata properties.""" +import time from collections.abc import Callable, Sequence +from datetime import datetime from enum import StrEnum +from importlib.resources import files +from pathlib import Path from typing import NamedTuple +import hydra import polars as pl import polars.selectors as cs +from loguru import logger from omegaconf import DictConfig, ListConfig, OmegaConf +from MEDS_polars_functions.mapreduce.mapper import map_over +from MEDS_polars_functions.utils import write_lazyframe + pl.enable_string_cache() @@ -414,7 +424,7 @@ def reducer_fntr( ... "values/n_ints": [4, 0, 1, 3, 1], ... "values/sum": [13.2, 2.2, 6.0, 14.0, 12.5], ... 
"values/sum_sqd": [21.3, 2.42, 36.0, 84.0, 81.25], - ... "values/min": [-1, 0, -1, 2, 2.], + ... "values/min": [-1, 0, -1, 2, 2], ... "values/max": [8.0, 1.1, 6.0, 8.0, 7.5], ... }) >>> df_2 = pl.DataFrame({ @@ -427,7 +437,7 @@ def reducer_fntr( ... "values/n_ints": [0, 1, 3, 1], ... "values/sum": [0., 7.0, 14.0, 12.5], ... "values/sum_sqd": [0., 103.2, 84.0, 81.25], - ... "values/min": [None, -1, 0.2, -2.], + ... "values/min": [None, -1., 0.2, -2.], ... "values/max": [None, 6.2, 1.0, 1.5], ... }) >>> df_3 = pl.DataFrame({ @@ -553,3 +563,44 @@ def reducer(*dfs: Sequence[pl.LazyFrame]) -> pl.LazyFrame: return df.select(*code_key_columns, **agg_operations).sort(code_key_columns) return reducer + + +def run_map_reduce(cfg: DictConfig): + """Stored separately so it can be easily imported into the pre-built extraction pipelines.""" + mapper_fn = mapper_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) + all_out_fps = map_over(cfg, compute_fn=mapper_fn) + + if cfg.worker != 0: + logger.info("Code metadata mapping completed. Exiting") + return + + logger.info("Starting reduction process") + + while not all(fp.is_file() for fp in all_out_fps): + logger.info("Waiting to begin reduction for all files to be written...") + time.sleep(cfg.polling_time) + + start = datetime.now() + logger.info("All map shards complete! Starting code metadata reduction computation.") + reducer_fn = reducer_fntr(cfg.stage_cfg, cfg.get("code_modifier_columns", None)) + + reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps]).with_columns( + cs.numeric().shrink_dtype().name.keep() + ) + logger.debug("For an extraction task specifically, we write out specifically to the cohort dir") + write_lazyframe(reduced, Path(cfg.cohort_dir) / "code_metadata.parquet") + logger.info(f"Finished reduction in {datetime.now() - start}") + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """Computes code metadata.""" + + run_map_reduce(cfg) + + +if __name__ == "__main__": + main() diff --git a/scripts/__init__.py b/src/MEDS_polars_functions/configs/__init__.py similarity index 100% rename from scripts/__init__.py rename to src/MEDS_polars_functions/configs/__init__.py diff --git a/src/MEDS_polars_functions/configs/extraction.yaml b/src/MEDS_polars_functions/configs/extraction.yaml new file mode 100644 index 0000000..f5beba4 --- /dev/null +++ b/src/MEDS_polars_functions/configs/extraction.yaml @@ -0,0 +1,57 @@ +defaults: + - pipeline + - stage_configs: + - shard_events + - split_and_shard_patients + - merge_to_MEDS_cohort + # There is no configuration beyond the global "event_conversion_config_fp" for the + # convert_to_sharded_events stage, so it doesn't have a stage config block here or below. + - _self_ + +description: |- + This pipeline extracts raw MEDS events in longitudinal, sparse form from an input dataset meeting select + criteria and converts them to the flattened, MEDS format. It can be run in its entirety, with controllable + levels of parallelism, or in stages. Arguments: + - `event_conversion_config_fp`: The path to the event conversion configuration file. This file defines + the events to extract from the various rows of the various input files encountered in the global input + directory. + - `input_dir`: The path to the directory containing the raw input files. 
+ - `cohort_dir`: The path to the directory where the output cohort will be written. It will be written in + various subfolders of this dir depending on the stage, as intermediate stages cache their output during + computation for efficiency of re-running and distributing. + +# The event conversion configuration file is used throughout the pipeline to define the events to extract. +event_conversion_config_fp: ??? +# The code modifier columns are in this pipeline only used in the aggregate_code_metadata stage. +code_modifiers: null +# The shards mapping is stored in the root of the final output directory. +shards_map_fp: "${cohort_dir}/splits.json" + +stages: + - shard_events + - split_and_shard_patients + - convert_to_sharded_events + - merge_to_MEDS_cohort + - aggregate_code_metadata + +stage_configs: + aggregate_code_metadata: + description: |- + This stage collects some descriptive metadata about the codes in the cohort. + + Args: + stage_cfg.aggregations: The aggregations to compute over the codes. + Defaults to counts of code occurrences, counts of patients with the code, and counts of value + occurrences per code, as well as the sum and sum of squares of values (for use in computing means + and variances). + aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" + do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts + is_metadata: True + mapper_output_dir: "${cohort_dir}/code_metadata" + output_dir: "${cohort_dir}" + process_shard_prefix: "train/" # we only want to summarize over the training set diff --git a/configs/pipeline.yaml b/src/MEDS_polars_functions/configs/pipeline.yaml similarity index 99% rename from configs/pipeline.yaml rename to src/MEDS_polars_functions/configs/pipeline.yaml index 985866a..eb427fe 100644 --- a/configs/pipeline.yaml +++ b/src/MEDS_polars_functions/configs/pipeline.yaml @@ -15,7 +15,7 @@ stages: ??? # The list of stages to this overall pipeline (in order) stage_configs: ??? # The configurations for each stage, keyed by stage name # Mapreduce information -worker: 1 +worker: 0 polling_time: 300 # wait time in seconds before beginning reduction steps # Filling in the current stage diff --git a/src/MEDS_polars_functions/configs/preprocess.yaml b/src/MEDS_polars_functions/configs/preprocess.yaml new file mode 100644 index 0000000..f86d1d4 --- /dev/null +++ b/src/MEDS_polars_functions/configs/preprocess.yaml @@ -0,0 +1,34 @@ +defaults: + - pipeline + - stage_configs: + - filter_patients + - add_time_derived_measurements + - count_code_occurrences + - filter_codes + - fit_outlier_detection + - filter_outliers + - fit_normalization + - fit_vocabulary_indices + - _self_ + +# Global pipeline parameters: +# 1. Code modifiers will be used as adjoining parts of the `code` columns during group-bys and eventual +# tokenization. +code_modifier_columns: ??? + +# The shards map filepath is stored in the global input directory for model-specific pre-processing. 
+shards_map_fp: "${input_dir}/splits.json" + +# Pipeline Structure +stages: + - filter_patients + - add_time_derived_measurements + - preliminary_counts + - filter_codes + - fit_outlier_detection + - filter_outliers + - fit_normalization + - fit_vocabulary_indices + - normalize + - tokenization + - tensorize diff --git a/scripts/extraction/__init__.py b/src/MEDS_polars_functions/configs/stage_configs/__init__.py similarity index 100% rename from scripts/extraction/__init__.py rename to src/MEDS_polars_functions/configs/stage_configs/__init__.py diff --git a/src/MEDS_polars_functions/configs/stage_configs/add_time_derived_measurements.yaml b/src/MEDS_polars_functions/configs/stage_configs/add_time_derived_measurements.yaml new file mode 100644 index 0000000..acfbe10 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/add_time_derived_measurements.yaml @@ -0,0 +1,8 @@ +add_time_derived_measurements: + age: + DOB_code: ??? + age_code: "AGE" + age_unit: "years" + time_of_day: + time_of_day_code: "TIME_OF_DAY" + endpoints: [6, 12, 18, 24] diff --git a/src/MEDS_polars_functions/configs/stage_configs/count_code_occurrences.yaml b/src/MEDS_polars_functions/configs/stage_configs/count_code_occurrences.yaml new file mode 100644 index 0000000..076a1a0 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/count_code_occurrences.yaml @@ -0,0 +1,5 @@ +count_code_occurrences: + aggregations: + - "code/n_occurrences" + - "code/n_patients" + do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts diff --git a/src/MEDS_polars_functions/configs/stage_configs/filter_codes.yaml b/src/MEDS_polars_functions/configs/stage_configs/filter_codes.yaml new file mode 100644 index 0000000..04b8833 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/filter_codes.yaml @@ -0,0 +1,3 @@ +filter_codes: + min_patients_per_code: null + min_occurrences_per_code: null diff --git a/src/MEDS_polars_functions/configs/stage_configs/filter_outliers.yaml b/src/MEDS_polars_functions/configs/stage_configs/filter_outliers.yaml new file mode 100644 index 0000000..6c08285 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/filter_outliers.yaml @@ -0,0 +1,2 @@ +filter_outliers: + stddev_cutoff: 4.5 diff --git a/src/MEDS_polars_functions/configs/stage_configs/filter_patients.yaml b/src/MEDS_polars_functions/configs/stage_configs/filter_patients.yaml new file mode 100644 index 0000000..be7c5f9 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/filter_patients.yaml @@ -0,0 +1,4 @@ +filter_patients: + min_events_per_patient: null + min_measurements_per_patient: null + data_input_dir: ${input_dir}/final_cohort diff --git a/src/MEDS_polars_functions/configs/stage_configs/fit_normalization.yaml b/src/MEDS_polars_functions/configs/stage_configs/fit_normalization.yaml new file mode 100644 index 0000000..e522470 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/fit_normalization.yaml @@ -0,0 +1,7 @@ +fit_normalization: + aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" diff --git a/src/MEDS_polars_functions/configs/stage_configs/fit_outlier_detection.yaml b/src/MEDS_polars_functions/configs/stage_configs/fit_outlier_detection.yaml new file mode 100644 index 0000000..07dd4a3 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/fit_outlier_detection.yaml @@ -0,0 +1,5 @@ +fit_outlier_detection: + aggregations: + - 
"values/n_occurrences" + - "values/sum" + - "values/sum_sqd" diff --git a/src/MEDS_polars_functions/configs/stage_configs/fit_vocabulary_indices.yaml b/src/MEDS_polars_functions/configs/stage_configs/fit_vocabulary_indices.yaml new file mode 100644 index 0000000..8725010 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/fit_vocabulary_indices.yaml @@ -0,0 +1,4 @@ +fit_vocabulary_indices: + is_metadata: true + ordering_method: "lexicographic" + output_dir: "${cohort_dir}" diff --git a/src/MEDS_polars_functions/configs/stage_configs/merge_to_MEDS_cohort.yaml b/src/MEDS_polars_functions/configs/stage_configs/merge_to_MEDS_cohort.yaml new file mode 100644 index 0000000..aee7a52 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/merge_to_MEDS_cohort.yaml @@ -0,0 +1,4 @@ +merge_to_MEDS_cohort: + output_dir: ${cohort_dir}/final_cohort + unique_by: "*" + additional_sort_by: null diff --git a/src/MEDS_polars_functions/configs/stage_configs/shard_events.yaml b/src/MEDS_polars_functions/configs/stage_configs/shard_events.yaml new file mode 100644 index 0000000..96b2238 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/shard_events.yaml @@ -0,0 +1,3 @@ +shard_events: + row_chunksize: 200000000 + infer_schema_length: 10000 diff --git a/src/MEDS_polars_functions/configs/stage_configs/split_and_shard_patients.yaml b/src/MEDS_polars_functions/configs/stage_configs/split_and_shard_patients.yaml new file mode 100644 index 0000000..b56aeb1 --- /dev/null +++ b/src/MEDS_polars_functions/configs/stage_configs/split_and_shard_patients.yaml @@ -0,0 +1,9 @@ +split_and_shard_patients: + is_metadata: True + output_dir: ${cohort_dir} + n_patients_per_shard: 50000 + external_splits_json_fp: null + split_fracs: + train: 0.8 + tuning: 0.1 + held_out: 0.1 diff --git a/src/MEDS_polars_functions/extraction/README.md b/src/MEDS_polars_functions/extraction/README.md new file mode 100644 index 0000000..d4ed883 --- /dev/null +++ b/src/MEDS_polars_functions/extraction/README.md @@ -0,0 +1,3 @@ +# Extraction + +These functions and scripts are designed to aid in creating a MEDS view of passed raw data. 
diff --git a/scripts/preprocessing/__init__.py b/src/MEDS_polars_functions/extraction/__init__.py similarity index 100% rename from scripts/preprocessing/__init__.py rename to src/MEDS_polars_functions/extraction/__init__.py diff --git a/src/MEDS_polars_functions/event_conversion.py b/src/MEDS_polars_functions/extraction/convert_to_sharded_events.py old mode 100644 new mode 100755 similarity index 89% rename from src/MEDS_polars_functions/event_conversion.py rename to src/MEDS_polars_functions/extraction/convert_to_sharded_events.py index eae9505..9cc5259 --- a/src/MEDS_polars_functions/event_conversion.py +++ b/src/MEDS_polars_functions/extraction/convert_to_sharded_events.py @@ -1,16 +1,33 @@ +#!/usr/bin/env python """Utilities for converting input data structures into MEDS events.""" +import copy +import json +import random from collections.abc import Sequence -from functools import reduce +from functools import partial, reduce +from importlib.resources import files +from pathlib import Path +import hydra import polars as pl from loguru import logger +from omegaconf import DictConfig, OmegaConf from omegaconf.listconfig import ListConfig -from .utils import is_col_field, parse_col_field +from MEDS_polars_functions.mapreduce.mapper import rwlock_wrap +from MEDS_polars_functions.utils import ( + is_col_field, + parse_col_field, + stage_init, + write_lazyframe, +) + +config_yaml = files("MEDS_polars_functions").joinpath("configs/extraction.yaml") def in_format(fmt: str, ts_name: str) -> pl.Expr: + """Returns an expression formatting the column ``ts_name`` in timestamp format ``fmt``.""" return pl.col(ts_name).str.strptime(pl.Datetime, fmt, strict=False) @@ -553,3 +570,88 @@ def convert_to_events( df = pl.concat(event_dfs, how="diagonal_relaxed") return df + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """Converts the event-sharded raw data into MEDS events and storing them in patient subsharded flat files. + + All arguments are specified through the command line into the `cfg` object through Hydra. + + The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific + configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten + directly on the command line, but can be overwritten implicitly by overwriting components of the + `stage_configs.convert_to_sharded_events` key. + + + This stage has no stage-specific configuration arguments. It does, naturally, require the global, + `event_conversion_config_fp` configuration argument to be set to the path of the event conversion yaml + file. 
+ """ + + input_dir, patient_subsharded_dir, metadata_input_dir, shards_map_fn = stage_init(cfg) + + shards = json.loads(shards_map_fn.read_text()) + + event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) + if not event_conversion_cfg_fp.exists(): + raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") + + logger.info("Starting event conversion.") + + logger.info(f"Reading event conversion config from {event_conversion_cfg_fp}") + event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) + logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") + + default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + + patient_subsharded_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") + + patient_splits = list(shards.items()) + random.shuffle(patient_splits) + + event_configs = list(event_conversion_cfg.items()) + random.shuffle(event_configs) + + # Here, we'll be reading files directly, so we'll turn off globbing + read_fn = partial(pl.scan_parquet, glob=False) + + for sp, patients in patient_splits: + for input_prefix, event_cfgs in event_configs: + event_cfgs = copy.deepcopy(event_cfgs) + input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + + event_shards = list((input_dir / input_prefix).glob("*.parquet")) + random.shuffle(event_shards) + + for shard_fp in event_shards: + out_fp = patient_subsharded_dir / sp / input_prefix / shard_fp.name + logger.info(f"Converting {shard_fp} to events and saving to {out_fp}") + + def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: + typed_patients = pl.Series(patients, dtype=df.schema[input_patient_id_column]) + + if input_patient_id_column != "patient_id": + df = df.rename({input_patient_id_column: "patient_id"}) + + try: + logger.info(f"Extracting events for {input_prefix}/{shard_fp.name}") + return convert_to_events( + df.filter(pl.col("patient_id").is_in(typed_patients)), + event_cfgs=copy.deepcopy(event_cfgs), + ) + except Exception as e: + raise ValueError( + f"Error converting {str(shard_fp.resolve())} for {sp}/{input_prefix}: {e}" + ) from e + + rwlock_wrap( + shard_fp, out_fp, read_fn, write_lazyframe, compute_fn, do_overwrite=cfg.do_overwrite + ) + + logger.info("Subsharded into converted events.") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/extraction/merge_to_MEDS_cohort.py b/src/MEDS_polars_functions/extraction/merge_to_MEDS_cohort.py new file mode 100755 index 0000000..f28e8ea --- /dev/null +++ b/src/MEDS_polars_functions/extraction/merge_to_MEDS_cohort.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +from functools import partial +from importlib.resources import files +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig + +from MEDS_polars_functions.mapreduce.mapper import map_over, shard_iterator + +pl.enable_string_cache() + + +def merge_subdirs_and_sort( + sp_dir: Path, unique_by: list[str] | str | None, additional_sort_by: list[str] | None = None +) -> pl.LazyFrame: + """This function reads all parquet files in subdirs of `sp_dir` and merges them into a single dataframe. + + Args: + sp_dir: The directory containing the subdirs with parquet files to be merged. + unique_by: The list of columns that should be ensured to be unique after the dataframes are merged. If + `None`, this is ignored. If `*`, all columns are used. 
If a list of strings, only the columns in
+            the list are used. If a column is not found in the dataframe, it is omitted from the unique-by, a
+            warning is logged, but an error is *not* raised. Which rows are retained if the unique-by columns
+            are not all columns is not guaranteed, but is also *not* random, so this may have statistical
+            implications.
+        additional_sort_by: Additional columns to sort by, in addition to the default sorting by patient ID
+            and timestamp. If `None`, only patient ID and timestamp are used. If a list of strings, these
+            columns are used in addition to the default sorting. If a column is not found in the dataframe, it
+            is omitted from the sort-by, a warning is logged, but an error is *not* raised. This functionality
+            is useful both for deterministic testing and in cases where a data owner wants to impose
+            intra-event measurement ordering in the data, though this is not recommended in general.
+
+    Returns:
+        A single dataframe containing all the data from the parquet files in the subdirs of `sp_dir`. These
+        files will be concatenated diagonally, taking the union of all rows in all dataframes and all unique
+        columns in all dataframes to form the merged output. The returned dataframe will be made unique by the
+        columns specified in `unique_by` and sorted by first patient ID, then timestamp, then all columns in
+        `additional_sort_by`, if any.
+
+    Raises:
+        FileNotFoundError: If no parquet files are found in the subdirs of `sp_dir`.
+        ValueError: If `unique_by` is not `None`, `*`, or a list of strings.
+
+    Examples:
+        >>> from tempfile import TemporaryDirectory
+        >>> df1 = pl.DataFrame({"patient_id": [1, 2], "timestamp": [10, 20], "code": ["A", "B"]})
+        >>> df2 = pl.DataFrame({
+        ...     "patient_id": [1, 1, 3],
+        ...     "timestamp": [2, 1, 8],
+        ...     "code": ["C", "D", "E"],
+        ...     "numerical_value": [None, 2.0, None],
+        ... })
+        >>> df3 = pl.DataFrame({
+        ...     "patient_id": [1, 1, 3],
+        ...     "timestamp": [2, 2, 8],
+        ...     "code": ["C", "D", "E"],
+        ...     "numerical_value": [6.2, 2.0, None],
+        ... })
+        >>> with TemporaryDirectory() as tmpdir:
+        ...     sp_dir = Path(tmpdir)
+        ...     merge_subdirs_and_sort(sp_dir, unique_by=None)
+        Traceback (most recent call last):
+            ...
+        FileNotFoundError: No files found in ...
+        >>> with TemporaryDirectory() as tmpdir:
+        ...     sp_dir = Path(tmpdir)
+        ...     (sp_dir / "subdir1").mkdir()
+        ...     df1.write_parquet(sp_dir / "subdir1" / "file1.parquet")
+        ...     df2.write_parquet(sp_dir / "subdir1" / "file2.parquet")
+        ...     (sp_dir / "subdir2").mkdir()
+        ...     df3.write_parquet(sp_dir / "subdir2" / "df.parquet")
+        ...     merge_subdirs_and_sort(
+        ...         sp_dir,
+        ...         unique_by=None,
+        ...         additional_sort_by=["code", "numerical_value", "missing_col_will_not_error"]
+        ...     ).collect()
+        shape: (8, 4)
+        ┌────────────┬───────────┬──────┬─────────────────┐
+        │ patient_id ┆ timestamp ┆ code ┆ numerical_value │
+        │ ---        ┆ ---       ┆ ---  ┆ ---             │
+        │ i64        ┆ i64       ┆ str  ┆ f64             │
+        ╞════════════╪═══════════╪══════╪═════════════════╡
+        │ 1          ┆ 1         ┆ D    ┆ 2.0             │
+        │ 1          ┆ 2         ┆ C    ┆ null            │
+        │ 1          ┆ 2         ┆ C    ┆ 6.2             │
+        │ 1          ┆ 2         ┆ D    ┆ 2.0             │
+        │ 1          ┆ 10        ┆ A    ┆ null            │
+        │ 2          ┆ 20        ┆ B    ┆ null            │
+        │ 3          ┆ 8         ┆ E    ┆ null            │
+        │ 3          ┆ 8         ┆ E    ┆ null            │
+        └────────────┴───────────┴──────┴─────────────────┘
+        >>> with TemporaryDirectory() as tmpdir:
+        ...     sp_dir = Path(tmpdir)
+        ...     (sp_dir / "subdir1").mkdir()
+        ...     df1.write_parquet(sp_dir / "subdir1" / "file1.parquet")
+        ...     df2.write_parquet(sp_dir / "subdir1" / "file2.parquet")
+        ...     (sp_dir / "subdir2").mkdir()
+        ...     df3.write_parquet(sp_dir / "subdir2" / "df.parquet")
+        ...
merge_subdirs_and_sort( + ... sp_dir, + ... unique_by="*", + ... additional_sort_by=["code", "numerical_value"] + ... ).collect() + shape: (7, 4) + ┌────────────┬───────────┬──────┬─────────────────┐ + │ patient_id ┆ timestamp ┆ code ┆ numerical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ str ┆ f64 │ + ╞════════════╪═══════════╪══════╪═════════════════╡ + │ 1 ┆ 1 ┆ D ┆ 2.0 │ + │ 1 ┆ 2 ┆ C ┆ null │ + │ 1 ┆ 2 ┆ C ┆ 6.2 │ + │ 1 ┆ 2 ┆ D ┆ 2.0 │ + │ 1 ┆ 10 ┆ A ┆ null │ + │ 2 ┆ 20 ┆ B ┆ null │ + │ 3 ┆ 8 ┆ E ┆ null │ + └────────────┴───────────┴──────┴─────────────────┘ + >>> with TemporaryDirectory() as tmpdir: + ... sp_dir = Path(tmpdir) + ... (sp_dir / "subdir1").mkdir() + ... df1.write_parquet(sp_dir / "subdir1" / "file1.parquet") + ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") + ... (sp_dir / "subdir2").mkdir() + ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") + ... # We just display the patient ID, timestamp, and code columns as the numerical value column + ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to + ... # the unique-by constraint. + ... merge_subdirs_and_sort( + ... sp_dir, + ... unique_by=["patient_id", "timestamp", "code"], + ... additional_sort_by=["code", "numerical_value"] + ... ).select("patient_id", "timestamp", "code").collect() + shape: (6, 3) + ┌────────────┬───────────┬──────┐ + │ patient_id ┆ timestamp ┆ code │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ str │ + ╞════════════╪═══════════╪══════╡ + │ 1 ┆ 1 ┆ D │ + │ 1 ┆ 2 ┆ C │ + │ 1 ┆ 2 ┆ D │ + │ 1 ┆ 10 ┆ A │ + │ 2 ┆ 20 ┆ B │ + │ 3 ┆ 8 ┆ E │ + └────────────┴───────────┴──────┘ + """ + + files_to_read = list(sp_dir.glob("**/*.parquet")) + + if not files_to_read: + raise FileNotFoundError(f"No files found in {sp_dir}/**/*.parquet.") + + file_strs = "\n".join(f" - {str(fp.resolve())}" for fp in files_to_read) + logger.info(f"Reading {len(files_to_read)} files:\n{file_strs}") + + dfs = [pl.scan_parquet(fp, glob=False) for fp in files_to_read] + df = pl.concat(dfs, how="diagonal_relaxed") + + df_columns = set(df.collect_schema().names()) + + match unique_by: + case None: + pass + case "*": + df = df.unique(maintain_order=False) + case list() if len(unique_by) > 0 and all(isinstance(u, str) for u in unique_by): + subset = [] + for u in unique_by: + if u in df_columns: + subset.append(u) + else: + logger.warning(f"Column {u} not found in dataframe. Omitting from unique-by subset.") + df = df.unique(maintain_order=False, subset=subset) + case _: + raise ValueError(f"Invalid unique_by value: {unique_by}") + + sort_by = ["patient_id", "timestamp"] + if additional_sort_by is not None: + for s in additional_sort_by: + if s in df_columns: + sort_by.append(s) + else: + logger.warning(f"Column {s} not found in dataframe. Omitting from sort-by list.") + + return df.sort(by=sort_by, multithreaded=False) + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/extraction.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """Merges the patient sub-sharded events into a single parquet file per patient shard. + + This function takes all dataframes (in parquet files) in any subdirs of the `cfg.stage_cfg.input_dir` and + merges them into a single dataframe. All dataframes in the subdirs are assumed to be in the unnested, MEDS + format, and cover the same group of patients (specific to the shard being processed). 
The merged dataframe + will also be sorted by patient ID and timestamp. + + All arguments are specified through the command line into the `cfg` object through Hydra. + + The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific + configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten + directly on the command line, but can be overwritten implicitly by overwriting components of the + `stage_configs.merge_to_MEDS_cohort` key. + + Args: + stage_configs.merge_to_MEDS_cohort.unique_by: The list of columns that should be ensured to be unique + after the dataframes are merged. Defaults to `"*"`, which means all columns are used. + stage_configs.merge_to_MEDS_cohort.additional_sort_by: Additional columns to sort by, in addition to + the default sorting by patient ID and timestamp. Defaults to `None`, which means only patient ID + and timestamp are used. + + Returns: + Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`. + """ + + read_fn = partial( + merge_subdirs_and_sort, + unique_by=cfg.stage_cfg.get("unique_by", None), + additional_sort_by=cfg.stage_cfg.get("additional_sort_by", None), + ) + + map_over( + cfg, + read_fn=read_fn, + shard_iterator_fntr=partial(shard_iterator, in_suffix=""), + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/extraction/shard_events.py b/src/MEDS_polars_functions/extraction/shard_events.py similarity index 64% rename from scripts/extraction/shard_events.py rename to src/MEDS_polars_functions/extraction/shard_events.py index 9ce0ac9..dcdc421 100755 --- a/scripts/extraction/shard_events.py +++ b/src/MEDS_polars_functions/extraction/shard_events.py @@ -1,11 +1,12 @@ #!/usr/bin/env python - import copy import gzip import random +import warnings from collections.abc import Sequence from datetime import datetime from functools import partial +from importlib.resources import files from pathlib import Path import hydra @@ -13,7 +14,7 @@ from loguru import logger from omegaconf import DictConfig, OmegaConf -from MEDS_polars_functions.mapper import wrap as rwlock_wrap +from MEDS_polars_functions.mapreduce.mapper import rwlock_wrap from MEDS_polars_functions.utils import ( get_shard_prefix, hydra_loguru_init, @@ -27,15 +28,100 @@ def kwargs_strs(kwargs: dict) -> str: + """Returns a string representation of the kwargs dictionary for logging. + + Args: + kwargs: A dictionary of keyword arguments. + + Returns: A string with each key-value pair in the dictionary formatted as a bullet point, + newline-separated. The order of the key-value pairs is the order of the dictionary. + + Examples: + >>> print(kwargs_strs({"a": 1, "b": "two", "c": 3.0})) + * a=1 + * b=two + * c=3.0 + >>> print(kwargs_strs({})) + + """ return "\n".join([f" * {k}={v}" for k, v in kwargs.items()]) def scan_with_row_idx(fp: Path, columns: Sequence[str], **scan_kwargs) -> pl.LazyFrame: - """Scans a file with a row index column added. + """Scans a file into a polars lazyframe and adds a row index with name `ROW_IDX_NAME`. Note that we don't put ``row_index_name=ROW_IDX_NAME`` in the kwargs because it is not well supported in polars currently, pending https://github.com/pola-rs/polars/issues/15730. Instead, we add it at the end, which seems to work. + + Args: + fp: The file path to read. Must be either a ".csv", ".csv.gz", or ".parquet" file. + columns: A list of column names to read from the file. 
+ scan_kwargs: Additional keyword arguments to pass to the scan function. The `infer_schema_length` + kwarg is removed for reading parquet files as it is not used for such files. + + Raises: + ValueError: If the file type is not supported. + + Returns: + A LazyFrame with the row index column added. + + Examples: + >>> from tempfile import TemporaryDirectory + >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, schema={"a": pl.UInt8, "b": pl.Int64}) + >>> with TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.csv" + ... df.write_csv(fp) + ... scan_with_row_idx(fp, columns=["a"], infer_schema_length=40).collect() + shape: (3, 2) + ┌─────────────┬─────┐ + │ __row_idx__ ┆ a │ + │ --- ┆ --- │ + │ u32 ┆ i64 │ + ╞═════════════╪═════╡ + │ 0 ┆ 1 │ + │ 1 ┆ 2 │ + │ 2 ┆ 3 │ + └─────────────┴─────┘ + >>> with TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.parquet" + ... df.write_parquet(fp) + ... scan_with_row_idx(fp, columns=["a", "b"], infer_schema_length=40).collect() + shape: (3, 3) + ┌─────────────┬─────┬─────┐ + │ __row_idx__ ┆ a ┆ b │ + │ --- ┆ --- ┆ --- │ + │ u32 ┆ u8 ┆ i64 │ + ╞═════════════╪═════╪═════╡ + │ 0 ┆ 1 ┆ 4 │ + │ 1 ┆ 2 ┆ 5 │ + │ 2 ┆ 3 ┆ 6 │ + └─────────────┴─────┴─────┘ + >>> import gzip + >>> with TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.csv.gz" + ... with gzip.open(fp, mode="wb") as f: + ... with warnings.catch_warnings(): + ... warnings.simplefilter("ignore", category=UserWarning) + ... df.write_csv(f) + ... scan_with_row_idx(fp, columns=["b"]).collect() + shape: (3, 2) + ┌─────────────┬─────┐ + │ __row_idx__ ┆ b │ + │ --- ┆ --- │ + │ u32 ┆ i64 │ + ╞═════════════╪═════╡ + │ 0 ┆ 4 │ + │ 1 ┆ 5 │ + │ 2 ┆ 6 │ + └─────────────┴─────┘ + >>> with TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... df.write_json(fp) + ... scan_with_row_idx(fp, columns=["a", "b"]) + Traceback (most recent call last): + ... + ValueError: Unsupported file type: .json """ kwargs = {**scan_kwargs} @@ -49,7 +135,9 @@ def scan_with_row_idx(fp: Path, columns: Sequence[str], **scan_kwargs) -> pl.Laz ) logger.warning("Reading compressed CSV files may be slow and limit parallelizability.") with gzip.open(fp, mode="rb") as f: - return pl.read_csv(f, **kwargs).with_row_index(ROW_IDX_NAME).lazy() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + return pl.read_csv(f, **kwargs).with_row_index(ROW_IDX_NAME).lazy() case ".csv": logger.debug(f"Reading {str(fp.resolve())} as CSV with kwargs:\n{kwargs_strs(kwargs)}.") df = pl.scan_csv(fp, **kwargs) @@ -70,7 +158,7 @@ def scan_with_row_idx(fp: Path, columns: Sequence[str], **scan_kwargs) -> pl.Laz df = df.with_row_index(ROW_IDX_NAME) - logger.debug(f"Returning df with columns: {', '.join(df.columns)}") + logger.debug(f"Returning df with columns: {', '.join(df.collect_schema().names())}") return df @@ -177,16 +265,67 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: def filter_to_row_chunk(df: pl.LazyFrame, start: int, end: int) -> pl.LazyFrame: + """Filters the input LazyFrame to a specific row chunk. + + This function is a simple helper designed to make other code clearer. The lazyframe must have a row index + column named `ROW_IDX_NAME`. + + Args: + df: The input LazyFrame. + start: The starting row index (inclusive). + end: The ending row index (exclusive). + + Returns: + The dataframe with only the rows in the range [`start`, `end`), and with the row index column dropped. 
+ + Examples: + >>> df = pl.DataFrame({ROW_IDX_NAME: [1, 2, 3, 4, 5], "b": [6, 7, 8, 9, 10]}) + >>> filter_to_row_chunk(df.lazy(), 1, 3).collect() + shape: (2, 1) + ┌─────┐ + │ b │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 6 │ + │ 7 │ + └─────┘ + >>> filter_to_row_chunk(df.lazy(), 100, 300).collect() + shape: (0, 1) + ┌─────┐ + │ b │ + │ --- │ + │ i64 │ + ╞═════╡ + └─────┘ + """ + return df.filter(pl.col(ROW_IDX_NAME).is_between(start, end, closed="left")).drop(ROW_IDX_NAME) -@hydra.main(version_base=None, config_path="../../configs", config_name="extraction") +config_yaml = files("MEDS_polars_functions").joinpath("configs/extraction.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) def main(cfg: DictConfig): """Runs the input data re-sharding process. Can be parallelized across output shards. - Output shards are simply row-chunks of the input data. There is no randomization or re-ordering of the - input data. Read contention on the input files may render additional parallelism beyond one worker per - input file ineffective. + This stage takes the raw input files and splits them into smaller files by taking consecutive chunks of + rows and writing them out to new files. This is useful for parallelizing the processing of the input data. + There is no randomization or re-ordering of the input data, and furthermore read contention on the input + files being split may render additional parallelism beyond one worker per input file ineffective. + + All arguments are specified through the command line into the `cfg` object through Hydra. + + The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific + configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten + directly on the command line, but can be overwritten implicitly by overwriting components of the + `stage_configs.shard_events` key. + + Args: + stage_configs.shard_events.row_chunksize (int): The number of rows to read in at a time. + stage_configs.shard_events.infer_schema_length (int): The number of rows to read in to infer the + schema (only used if the source files are csvs). """ hydra_loguru_init() diff --git a/src/MEDS_polars_functions/sharding.py b/src/MEDS_polars_functions/extraction/split_and_shard_patients.py old mode 100644 new mode 100755 similarity index 53% rename from src/MEDS_polars_functions/sharding.py rename to src/MEDS_polars_functions/extraction/split_and_shard_patients.py index 837a0af..5e4478d --- a/src/MEDS_polars_functions/sharding.py +++ b/src/MEDS_polars_functions/extraction/split_and_shard_patients.py @@ -1,9 +1,16 @@ -"""Utilities for sharding/re-sharding and splitting MEDS datasets.""" - +#!/usr/bin/env python +import json from collections.abc import Sequence +from importlib.resources import files +from pathlib import Path +import hydra import numpy as np +import polars as pl from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.utils import stage_init def shard_patients[ @@ -144,3 +151,113 @@ def shard_patients[ seen[k] = set(pts) return final_shards + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/extraction.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """Extracts the set of unique patients from the raw data and splits/shards them and saves the result. 
+
+    This stage splits the patients into training, tuning, and held-out sets, and further splits those sets
+    into shards.
+
+    All arguments are specified through the command line into the `cfg` object through Hydra.
+
+    The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
+    configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
+    directly on the command line, but can be overwritten implicitly by overwriting components of the
+    `stage_configs.split_and_shard_patients` key.
+
+    Args:
+        stage_configs.split_and_shard_patients.n_patients_per_shard: The maximum number of patients to include
+            in any shard. Realized shards will not necessarily have this many patients, though they will never
+            exceed this number. Instead, the number of shards necessary to include all patients in a split
+            such that no shard exceeds this number will be calculated, then the patients will be evenly,
+            randomly split amongst those shards so that all shards within a split have approximately the same
+            number of patients.
+        stage_configs.split_and_shard_patients.external_splits_json_fp: The path to a json file containing any
+            pre-defined splits for specially held-out test sets beyond the IID held out set that will be
+            produced (e.g., for prospective datasets, etc.).
+        stage_configs.split_and_shard_patients.split_fracs: The fraction of patients to include in the IID
+            training, tuning, and held-out sets. Split fractions can be changed for the default names by
+            adding a hydra-syntax command line argument for the nested name; e.g., `split_fracs.train=0.7
+            split_fracs.tuning=0.1 split_fracs.held_out=0.2`. A split can be removed with the `~` override
+            Hydra syntax. Similarly, a new split name can be added with the standard Hydra `+` override
+            option. E.g., `~split_fracs.held_out +split_fracs.test=0.1`. It is the user's responsibility to
+            ensure that split fractions sum to 1.
+ """ + + subsharded_dir, MEDS_cohort_dir, _, _ = stage_init(cfg) + + event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) + if not event_conversion_cfg_fp.exists(): + raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") + + logger.info( + f"Reading event conversion config from {event_conversion_cfg_fp} (needed for patient ID columns)" + ) + event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) + logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") + + dfs = [] + + default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + for input_prefix, event_cfgs in event_conversion_cfg.items(): + input_patient_id_column = event_cfgs.get("patient_id_col", default_patient_id_col) + + input_fps = list((subsharded_dir / input_prefix).glob("**/*.parquet")) + + input_fps_strs = "\n".join(f" - {str(fp.resolve())}" for fp in input_fps) + logger.info(f"Reading patient IDs from {input_prefix} files:\n{input_fps_strs}") + + for input_fp in input_fps: + dfs.append( + pl.scan_parquet(input_fp, glob=False) + .select(pl.col(input_patient_id_column).alias("patient_id")) + .unique() + ) + + logger.info(f"Joining all patient IDs from {len(dfs)} dataframes") + patient_ids = ( + pl.concat(dfs) + .select(pl.col("patient_id").drop_nulls().drop_nans().unique()) + .collect(streaming=True)["patient_id"] + .to_numpy(use_pyarrow=True) + ) + + logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}") + + if cfg.stage_cfg.external_splits_json_fp: + external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp) + if not external_splits_json_fp.exists(): + raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}") + + logger.info(f"Reading external splits from {cfg.external_splits_json_fp}") + external_splits = json.loads(external_splits_json_fp.read_text()) + + size_strs = ", ".join(f"{k}: {len(v)}" for k, v in external_splits.items()) + logger.info(f"Loaded external splits of size: {size_strs}") + else: + external_splits = None + + logger.info("Sharding and splitting patients") + + sharded_patients = shard_patients( + patients=patient_ids, + external_splits=external_splits, + split_fracs_dict=cfg.stage_cfg.split_fracs, + n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + seed=cfg.seed, + ) + + logger.info(f"Writing sharded patients to {MEDS_cohort_dir}") + MEDS_cohort_dir.mkdir(parents=True, exist_ok=True) + out_fp = MEDS_cohort_dir / "splits.json" + out_fp.write_text(json.dumps(sharded_patients)) + logger.info(f"Done writing sharded patients to {out_fp}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/filters/README.md b/src/MEDS_polars_functions/filters/README.md new file mode 100644 index 0000000..9d582f0 --- /dev/null +++ b/src/MEDS_polars_functions/filters/README.md @@ -0,0 +1,5 @@ +# Filters + +Filters remove wholesale events within the data, either at the patient or event level. For transformations +that simply _occlude_ aspects of the data (e.g., by setting a code variable to `UNK`), see the `transforms` +library section. 
diff --git a/src/MEDS_polars_functions/filters/__init__.py b/src/MEDS_polars_functions/filters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MEDS_polars_functions/filter_measurements.py b/src/MEDS_polars_functions/filters/filter_measurements.py similarity index 58% rename from src/MEDS_polars_functions/filter_measurements.py rename to src/MEDS_polars_functions/filters/filter_measurements.py index a6abc18..a11eaf6 100644 --- a/src/MEDS_polars_functions/filter_measurements.py +++ b/src/MEDS_polars_functions/filters/filter_measurements.py @@ -1,10 +1,15 @@ +#!/usr/bin/env python """A polars-to-polars transformation function for filtering patients by sequence length.""" - from collections.abc import Callable +from importlib.resources import files +from pathlib import Path +import hydra import polars as pl from omegaconf import DictConfig +from MEDS_polars_functions.mapreduce.mapper import map_over + pl.enable_string_cache() @@ -123,11 +128,12 @@ def filter_codes_fn(df: pl.LazyFrame) -> pl.LazyFrame: """ idx_col = "_row_idx" - while idx_col in df.columns: + df_columns = set(df.collect_schema().names()) + while idx_col in df_columns: idx_col = f"_{idx_col}" return ( - df.with_row_count(idx_col) + df.with_row_index(idx_col) .join(allowed_code_metadata, on=join_cols, how="inner") .sort(idx_col) .drop(idx_col) @@ -136,92 +142,20 @@ def filter_codes_fn(df: pl.LazyFrame) -> pl.LazyFrame: return filter_codes_fn -def filter_outliers_fntr( - stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None -) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Filters patient events to only encompass those with a set of permissible codes. +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") - Args: - df: The input DataFrame. - stage_cfg: The configuration for the code filtering stage. - Returns: - The processed DataFrame. +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """TODO.""" - Examples: - >>> code_metadata_df = pl.DataFrame({ - ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 2, 1, 2], - ... "values/n_occurrences": [3, 1, 3, 2], - ... "values/sum": [0.0, 4.0, 12.0, 2.0], - ... "values/sum_sqd": [27.0, 16.0, 75.0, 4.0], - ... # for clarity: ----- mean = [0.0, 4.0, 4.0, 1.0] - ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] - ... }) - >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], - ... "code": pl.Series(["A", "B", "A", "C"], dtype=pl.Categorical), - ... "modifier1": [1, 1, 2, 2], - ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] - ... # for clarity: stddev [3.0, 3.0, 0.0, 1.0] - ... "numerical_value": [15., 16., 3.9, 1.0], - ... 
}).lazy() - >>> stage_cfg = DictConfig({"stddev_cutoff": 4.5}) - >>> fn = filter_outliers_fntr(stage_cfg, code_metadata_df, ["modifier1"]) - >>> fn(data).collect() - shape: (4, 5) - ┌────────────┬──────┬───────────┬─────────────────┬───────────────────────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ numerical_value ┆ numerical_value/is_inlier │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ i64 ┆ cat ┆ i64 ┆ f64 ┆ bool │ - ╞════════════╪══════╪═══════════╪═════════════════╪═══════════════════════════╡ - │ 1 ┆ A ┆ 1 ┆ null ┆ false │ - │ 1 ┆ B ┆ 1 ┆ 16.0 ┆ true │ - │ 2 ┆ A ┆ 2 ┆ null ┆ false │ - │ 2 ┆ C ┆ 2 ┆ 1.0 ┆ true │ - └────────────┴──────┴───────────┴─────────────────┴───────────────────────────┘ - """ + code_metadata = pl.read_parquet( + Path(cfg.stage_cfg.metadata_input_dir) / "code_metadata.parquet", use_pyarrow=True + ) + compute_fn = filter_codes_fntr(cfg.stage_cfg, code_metadata) - stddev_cutoff = stage_cfg.get("stddev_cutoff", None) - if stddev_cutoff is None: - return lambda df: df - - join_cols = ["code"] - if code_modifier_columns: - join_cols.extend(code_modifier_columns) - - cols_to_select = ["code"] - if code_modifier_columns: - cols_to_select.extend(code_modifier_columns) - - mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") - stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 - if "values/mean" not in code_metadata.columns: - cols_to_select.append(mean_col.alias("values/mean")) - if "values/std" not in code_metadata.columns: - cols_to_select.append(stddev_col.alias("values/std")) + map_over(cfg, compute_fn=compute_fn) - code_metadata = code_metadata.lazy().select(cols_to_select) - - def filter_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters out outlier numerical values from patient events. - - In particular, this function filters the DataFrame to only include numerical values that are within - {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. 
- """ - - val = pl.col("numerical_value") - mean = pl.col("values/mean") - stddev = pl.col("values/std") - filter_expr = (val - mean).abs() <= stddev_cutoff * stddev - - return ( - df.join(code_metadata, on=join_cols, how="left") - .with_columns( - filter_expr.alias("numerical_value/is_inlier"), - pl.when(filter_expr).then(pl.col("numerical_value")).alias("numerical_value"), - ) - .drop("values/mean", "values/std") - ) - return filter_outliers_fn +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/filter_patients_by_length.py b/src/MEDS_polars_functions/filters/filter_patients_by_length.py similarity index 80% rename from src/MEDS_polars_functions/filter_patients_by_length.py rename to src/MEDS_polars_functions/filters/filter_patients_by_length.py index 00fd9c5..fc08f19 100644 --- a/src/MEDS_polars_functions/filter_patients_by_length.py +++ b/src/MEDS_polars_functions/filters/filter_patients_by_length.py @@ -1,6 +1,16 @@ +#!/usr/bin/env python """A polars-to-polars transformation function for filtering patients by sequence length.""" +from collections.abc import Callable +from functools import partial +from importlib.resources import files +import hydra import polars as pl +from loguru import logger +from omegaconf import DictConfig + +from MEDS_polars_functions.mapreduce.mapper import map_over +from MEDS_polars_functions.utils import hydra_loguru_init def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: @@ -155,3 +165,48 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) ) return df.filter(pl.col("timestamp").n_unique().over("patient_id") >= min_events_per_patient) + + +def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + compute_fns = [] + if stage_cfg.min_measurements_per_patient: + logger.info( + f"Filtering patients with fewer than {stage_cfg.min_measurements_per_patient} measurements " + "(observations of any kind)." + ) + compute_fns.append( + partial( + filter_patients_by_num_measurements, + min_measurements_per_patient=stage_cfg.min_measurements_per_patient, + ) + ) + if stage_cfg.min_events_per_patient: + logger.info( + f"Filtering patients with fewer than {stage_cfg.min_events_per_patient} events " + "(unique timepoints)." 
+        )
+        compute_fns.append(
+            partial(filter_patients_by_num_events, min_events_per_patient=stage_cfg.min_events_per_patient)
+        )
+
+    def fn(data: pl.LazyFrame) -> pl.LazyFrame:
+        for compute_fn in compute_fns:
+            data = compute_fn(data)
+        return data
+
+    return fn
+
+
+config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml")
+
+
+@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem)
+def main(cfg: DictConfig):
+    """TODO."""
+
+    hydra_loguru_init()
+    compute_fn = filter_patients_fntr(cfg.stage_cfg)
+
+    map_over(cfg, compute_fn=compute_fn)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/MEDS_polars_functions/get_vocabulary.py b/src/MEDS_polars_functions/fit_vocabulary_indices.py
similarity index 84%
rename from src/MEDS_polars_functions/get_vocabulary.py
rename to src/MEDS_polars_functions/fit_vocabulary_indices.py
index 318b9e1..e6f105e 100644
--- a/src/MEDS_polars_functions/get_vocabulary.py
+++ b/src/MEDS_polars_functions/fit_vocabulary_indices.py
@@ -1,9 +1,16 @@
+#!/usr/bin/env python
 """Simple helper functions to define a consistent code vocabulary for normalizing a MEDS dataset."""
-
 from collections.abc import Callable
 from enum import StrEnum
+from importlib.resources import files
+from pathlib import Path
+import hydra
 import polars as pl
+from loguru import logger
+from omegaconf import DictConfig, OmegaConf
+
+from MEDS_polars_functions.utils import hydra_loguru_init
 class VOCABULARY_ORDERING(StrEnum):
@@ -87,7 +94,7 @@ def validate_code_metadata(code_metadata: pl.DataFrame, code_modifiers: list[str
     n_total_rows = len(code_metadata)
     if n_unique_codes != n_total_rows:
-        code_counts = code_metadata.groupby(cols).agg(pl.len().alias("count")).sort("count", descending=True)
+        code_counts = code_metadata.group_by(cols).agg(pl.len().alias("count")).sort("count", descending=True)
         extra_codes = code_counts.filter(pl.col("count") > 1)
         raise ValueError(f"The code and code modifiers are not unique:\n{extra_codes.head(100)}")
@@ -183,3 +190,51 @@ def lexicographic_indices(code_metadata: pl.DataFrame, code_modifiers: list[str]
 VOCABULARY_ORDERING_METHODS: dict[VOCABULARY_ORDERING, INDEX_ASSIGNMENT_FN] = {
     VOCABULARY_ORDERING.LEXICOGRAPHIC: lexicographic_indices,
 }
+
+config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml")
+
+
+@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem)
+def main(cfg: DictConfig):
+    """TODO."""
+
+    hydra_loguru_init()
+
+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )
+
+    metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir)
+    output_dir = Path(cfg.stage_cfg.output_dir)
+
+    code_metadata = pl.read_parquet(metadata_input_dir / "code_metadata.parquet", use_pyarrow=True)
+
+    ordering_method = cfg.stage_cfg.get("ordering_method", VOCABULARY_ORDERING.LEXICOGRAPHIC)
+
+    if ordering_method not in VOCABULARY_ORDERING_METHODS:
+        raise ValueError(
+            f"Invalid ordering method: {ordering_method}. 
" + f"Expected one of {', '.join(VOCABULARY_ORDERING_METHODS.keys())}" + ) + + logger.info(f"Assigning code vocabulary indices via a {ordering_method} order.") + ordering_fn = VOCABULARY_ORDERING_METHODS[ordering_method] + + code_modifiers = cfg.get("code_modifier_columns", None) + if code_modifiers is None: + code_modifiers = [] + + code_metadata = ordering_fn(code_metadata, code_modifiers) + + output_fp = output_dir / "code_metadata.parquet" + logger.info(f"Indices assigned. Writing to {output_fp}") + + code_metadata.write_parquet(output_fp, use_pyarrow=True) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/mapreduce/__init__.py b/src/MEDS_polars_functions/mapreduce/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MEDS_polars_functions/mapreduce/mapper.py b/src/MEDS_polars_functions/mapreduce/mapper.py new file mode 100644 index 0000000..c70b6dd --- /dev/null +++ b/src/MEDS_polars_functions/mapreduce/mapper.py @@ -0,0 +1,60 @@ +"""Basic utilities for parallelizable map operations on sharded MEDS datasets with caching and locking.""" + +from collections.abc import Callable, Generator +from datetime import datetime +from functools import partial +from pathlib import Path +from typing import Any, TypeVar + +import polars as pl +from loguru import logger +from omegaconf import DictConfig + +from ..utils import stage_init, write_lazyframe +from .utils import rwlock_wrap, shard_iterator + +DF_T = TypeVar("DF_T") + +MAP_FN_T = Callable[[DF_T], DF_T] | tuple[Callable[[DF_T], DF_T]] +SHARD_GEN_T = Generator[tuple[Path, Path], None, None] +SHARD_ITR_FNTR_T = Callable[[DictConfig], SHARD_GEN_T] + + +def identity_fn(df: Any) -> Any: + return df + + +def map_over( + cfg: DictConfig, + compute_fn: MAP_FN_T | None = None, + read_fn: Callable[[Path], DF_T] = partial(pl.scan_parquet, glob=False), + write_fn: Callable[[DF_T, Path], None] = write_lazyframe, + shard_iterator_fntr: SHARD_ITR_FNTR_T = shard_iterator, +) -> list[Path]: + stage_init(cfg) + + start = datetime.now() + + if compute_fn is None: + compute_fn = identity_fn + + if not isinstance(compute_fn, tuple): + compute_fn = (compute_fn,) + + all_out_fps = [] + for in_fp, out_fp in shard_iterator_fntr(cfg): + logger.info(f"Processing {str(in_fp.resolve())} into {str(out_fp.resolve())}") + rwlock_wrap( + in_fp, + out_fp, + read_fn, + write_fn, + *compute_fn, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + all_out_fps.append(out_fp) + + logger.info(f"Finished mapping in {datetime.now() - start}") + return all_out_fps diff --git a/src/MEDS_polars_functions/mapper.py b/src/MEDS_polars_functions/mapreduce/utils.py similarity index 68% rename from src/MEDS_polars_functions/mapper.py rename to src/MEDS_polars_functions/mapreduce/utils.py index 34275b8..d419ced 100644 --- a/src/MEDS_polars_functions/mapper.py +++ b/src/MEDS_polars_functions/mapreduce/utils.py @@ -1,12 +1,14 @@ -"""Basic utilities for parallelizable map operations on sharded MEDS datasets with caching and locking.""" +"""Basic utilities for parallelizable mapreduces on sharded MEDS datasets with caching and locking.""" import json +import random import shutil from collections.abc import Callable from datetime import datetime from pathlib import Path from loguru import logger +from omegaconf import DictConfig LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f" @@ -82,7 +84,7 @@ def register_lock(cache_directory: Path) -> tuple[datetime, Path]: return lock_time, lock_fp 
-def wrap[ +def rwlock_wrap[ DF_T ]( in_fp: Path, @@ -149,7 +151,7 @@ def wrap[ ... lambda df: df.with_columns(pl.col("c") * 2), ... lambda df: df.filter(pl.col("c") > 4) ... ] - >>> result_computed = wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns, do_return=False) + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns, do_return=False) >>> assert result_computed >>> print(out_fp.read_text()) a,b,c @@ -163,7 +165,7 @@ def wrap[ ... lambda df: df.with_columns(pl.col("c") * 2), ... lambda df: df.filter(pl.col("d") > 4) ... ] - >>> wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns) + >>> rwlock_wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns) Traceback (most recent call last): ... polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"] @@ -186,7 +188,7 @@ def wrap[ >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame: ... print(f"Lock dir exists? {lock_dir.exists()}") ... return df - >>> result_computed, out_df = wrap( + >>> result_computed, out_df = rwlock_wrap( ... in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True ... ) Lock dir exists? True @@ -276,3 +278,103 @@ def wrap[ logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}") lock_fp.unlink() raise e + + +def shard_iterator( + cfg: DictConfig, + in_suffix: str = ".parquet", + out_suffix: str = ".parquet", + in_prefix: str = "", + out_prefix: str = "", +): + """Provides a generator that yields shard input and output files for mapreduce operations. + + Args: + cfg: The configuration dictionary for the overall pipeline. Should (possibly) contain the following + keys (some are optional, as marked below): + - ``stage_cfg.data_input_dir`` (mandatory): The directory containing the input data. + - ``stage_cfg.output_dir`` (mandatory): The directory to write the output data. + - ``shards_map_fp`` (mandatory): The file path to the shards map JSON file. + - ``stage_cfg.process_shard_prefix`` (optional): The prefix of the shards to process (e.g., + ``"train/"``). If not provided, all shards will be processed. + - ``worker`` (optional): The worker ID for the MR worker; this is also used to seed the + randomization process. If not provided, the randomization process is unseeded. + in_suffix: The suffix of the input files. Defaults to ".parquet". This can be set to "" to process + entire directories. + out_suffix: The suffix of the output files. Defaults to ".parquet". + in_prefix: The prefix of the input files. Defaults to "". This can be used to load files from a + subdirectory of the input directory by including a "/" at the end of the prefix. + out_prefix: The prefix of the output files. Defaults to "". + + Yields: + Randomly shuffled pairs of input and output file paths for each shard. The randomization process is + seeded by the worker ID in ``cfg``, if provided, otherwise it is left unseeded. + + Examples: + >>> from tempfile import NamedTemporaryFile + >>> shards = {"train/0": [1, 2, 3], "train/1": [4, 5, 6], "held_out": [4, 5, 6], "foo": [5]} + >>> with NamedTemporaryFile() as tmp: + ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... cfg = DictConfig({ + ... "stage_cfg": {"data_input_dir": "data/", "output_dir": "output/"}, + ... "shards_map_fp": tmp.name, + ... "worker": 1, + ... }) + ... gen = shard_iterator(cfg) + ... 
list(gen) # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/foo.parquet'), PosixPath('output/foo.parquet')), + (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), + (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), + (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet'))] + >>> with NamedTemporaryFile() as tmp: + ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... cfg = DictConfig({ + ... "stage_cfg": {"data_input_dir": "data/", "output_dir": "output/"}, + ... "shards_map_fp": tmp.name, + ... "worker": 1, + ... }) + ... gen = shard_iterator(cfg, in_suffix="", out_suffix=".csv", in_prefix="a/", out_prefix="b/") + ... list(gen) # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/a/foo'), PosixPath('output/b/foo.csv')), + (PosixPath('data/a/train/0'), PosixPath('output/b/train/0.csv')), + (PosixPath('data/a/held_out'), PosixPath('output/b/held_out.csv')), + (PosixPath('data/a/train/1'), PosixPath('output/b/train/1.csv'))] + >>> with NamedTemporaryFile() as tmp: + ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... cfg = DictConfig({ + ... "stage_cfg": { + ... "data_input_dir": "data/", "output_dir": "output/", "process_shard_prefix": "train/" + ... }, + ... "shards_map_fp": tmp.name, + ... "worker": 1, + ... }) + ... gen = shard_iterator(cfg) + ... list(gen) # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), + (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet'))] + """ + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + shards_map_fn = Path(cfg.shards_map_fp) + + shards = json.loads(shards_map_fn.read_text()) + + if "process_shard_prefix" in cfg.stage_cfg: + logger.info(f'Processing shards with prefix "{cfg.stage_cfg.process_shard_prefix}"') + shards = {k: v for k, v in shards.items() if k.startswith(cfg.stage_cfg.process_shard_prefix)} + + shards = list(shards.keys()) + if "worker" in cfg: + random.seed(cfg.worker) + random.shuffle(shards) + + logger.info(f"Mapping computation over a maximum of {len(shards)} shards") + + for sp in shards: + in_fp = input_dir / f"{in_prefix}{sp}{in_suffix}" + out_fp = output_dir / f"{out_prefix}{sp}{out_suffix}" + + # TODO: Could add checking logic for existence of in_fp and/or out_fp here. 
+ + yield in_fp, out_fp diff --git a/src/MEDS_polars_functions/transforms/__init__.py b/src/MEDS_polars_functions/transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MEDS_polars_functions/time_derived_measurements.py b/src/MEDS_polars_functions/transforms/add_time_derived_measurements.py similarity index 91% rename from src/MEDS_polars_functions/time_derived_measurements.py rename to src/MEDS_polars_functions/transforms/add_time_derived_measurements.py index fba536c..3161337 100644 --- a/src/MEDS_polars_functions/time_derived_measurements.py +++ b/src/MEDS_polars_functions/transforms/add_time_derived_measurements.py @@ -1,9 +1,15 @@ +#!/usr/bin/env python """Transformations for adding time-derived measurements (e.g., a patient's age) to a MEDS dataset.""" - from collections.abc import Callable +from importlib.resources import files +import hydra import polars as pl -from omegaconf import DictConfig +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.mapreduce.mapper import map_over +from MEDS_polars_functions.utils import hydra_loguru_init pl.enable_string_cache() @@ -360,3 +366,46 @@ def tod_code(start: int, end: int) -> str: ) return fn + + +def add_time_derived_measurements_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + INFERRED_STAGE_KEYS = {"is_metadata", "data_input_dir", "metadata_input_dir", "output_dir"} + + compute_fns = [] + # We use the raw stages object as the induced `stage_cfg` has extra properties like the input and output + # directories. + for feature_name, feature_cfg in stage_cfg.items(): + match feature_name: + case "age": + compute_fns.append(add_new_events_fntr(age_fntr(feature_cfg))) + case "time_of_day": + compute_fns.append(add_new_events_fntr(time_of_day_fntr(feature_cfg))) + case str() if feature_name in INFERRED_STAGE_KEYS: + continue + case _: + raise ValueError(f"Unknown time-derived measurement: {feature_name}") + + logger.info(f"Adding {feature_name} via config: {OmegaConf.to_yaml(feature_cfg)}") + + def fn(df: pl.LazyFrame) -> pl.LazyFrame: + for compute_fn in compute_fns: + df = compute_fn(df) + return df + + return fn + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """Adds time-derived measurements to a MEDS cohort as separate observations at each unique timestamp.""" + + hydra_loguru_init() + compute_fn = add_time_derived_measurements_fntr(cfg.stage_cfg) + map_over(cfg, compute_fn=compute_fn) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/normalization.py b/src/MEDS_polars_functions/transforms/normalization.py similarity index 90% rename from src/MEDS_polars_functions/normalization.py rename to src/MEDS_polars_functions/transforms/normalization.py index 28a6ac8..fa22b3a 100644 --- a/src/MEDS_polars_functions/normalization.py +++ b/src/MEDS_polars_functions/transforms/normalization.py @@ -1,6 +1,15 @@ +#!/usr/bin/env python """Transformations for normalizing MEDS datasets, across both categorical and continuous dimensions.""" +from functools import partial +from importlib.resources import files +from pathlib import Path + +import hydra import polars as pl +from omegaconf import DictConfig + +from MEDS_polars_functions.mapreduce.mapper import map_over def normalize( @@ -168,12 +177,13 @@ def normalize( mean_col = pl.col("values/sum") / 
pl.col("values/n_occurrences") stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 - if "values/mean" in code_metadata.columns: + code_metadata_columns = set(code_metadata.collect_schema().names()) + if "values/mean" in code_metadata_columns: cols_to_select.append("values/mean") else: cols_to_select.append(mean_col.alias("values/mean")) - if "values/std" in code_metadata.columns: + if "values/std" in code_metadata_columns: cols_to_select.append("values/std") else: cols_to_select.append(stddev_col.alias("values/std")) @@ -189,3 +199,23 @@ def normalize( pl.col("code/vocab_index").alias("code"), ((pl.col("numerical_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numerical_value"), ) + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """TODO.""" + + code_metadata = pl.read_parquet( + Path(cfg.stage_cfg.metadata_input_dir) / "code_metadata.parquet", use_pyarrow=True + ).lazy() + code_modifiers = cfg.get("code_modifier_columns", None) + compute_fn = partial(normalize, code_metadata=code_metadata, code_modifiers=code_modifiers) + + map_over(cfg, compute_fn=compute_fn) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/transforms/occlude_outliers.py b/src/MEDS_polars_functions/transforms/occlude_outliers.py new file mode 100644 index 0000000..fb215ae --- /dev/null +++ b/src/MEDS_polars_functions/transforms/occlude_outliers.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +"""A polars-to-polars transformation function for filtering patients by sequence length.""" +from collections.abc import Callable +from importlib.resources import files +from pathlib import Path + +import hydra +import polars as pl +from omegaconf import DictConfig + +from MEDS_polars_functions.mapreduce.mapper import map_over + +pl.enable_string_cache() + + +def occlude_outliers_fntr( + stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifier_columns: list[str] | None = None +) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Filters patient events to only encompass those with a set of permissible codes. + + Args: + df: The input DataFrame. + stage_cfg: The configuration for the code filtering stage. + + Returns: + The processed DataFrame. + + Examples: + >>> code_metadata_df = pl.DataFrame({ + ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 2, 1, 2], + ... "values/n_occurrences": [3, 1, 3, 2], + ... "values/sum": [0.0, 4.0, 12.0, 2.0], + ... "values/sum_sqd": [27.0, 16.0, 75.0, 4.0], + ... # for clarity: ----- mean = [0.0, 4.0, 4.0, 1.0] + ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] + ... }) + >>> data = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "code": pl.Series(["A", "B", "A", "C"], dtype=pl.Categorical), + ... "modifier1": [1, 1, 2, 2], + ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] + ... # for clarity: stddev [3.0, 3.0, 0.0, 1.0] + ... "numerical_value": [15., 16., 3.9, 1.0], + ... 
}).lazy() + >>> stage_cfg = DictConfig({"stddev_cutoff": 4.5}) + >>> fn = occlude_outliers_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (4, 5) + ┌────────────┬──────┬───────────┬─────────────────┬───────────────────────────┐ + │ patient_id ┆ code ┆ modifier1 ┆ numerical_value ┆ numerical_value/is_inlier │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ cat ┆ i64 ┆ f64 ┆ bool │ + ╞════════════╪══════╪═══════════╪═════════════════╪═══════════════════════════╡ + │ 1 ┆ A ┆ 1 ┆ null ┆ false │ + │ 1 ┆ B ┆ 1 ┆ 16.0 ┆ true │ + │ 2 ┆ A ┆ 2 ┆ null ┆ false │ + │ 2 ┆ C ┆ 2 ┆ 1.0 ┆ true │ + └────────────┴──────┴───────────┴─────────────────┴───────────────────────────┘ + """ + + stddev_cutoff = stage_cfg.get("stddev_cutoff", None) + if stddev_cutoff is None: + return lambda df: df + + join_cols = ["code"] + if code_modifier_columns: + join_cols.extend(code_modifier_columns) + + cols_to_select = ["code"] + if code_modifier_columns: + cols_to_select.extend(code_modifier_columns) + + mean_col = pl.col("values/sum") / pl.col("values/n_occurrences") + stddev_col = (pl.col("values/sum_sqd") / pl.col("values/n_occurrences") - mean_col**2) ** 0.5 + if "values/mean" not in code_metadata.columns: + cols_to_select.append(mean_col.alias("values/mean")) + if "values/std" not in code_metadata.columns: + cols_to_select.append(stddev_col.alias("values/std")) + + code_metadata = code_metadata.lazy().select(cols_to_select) + + def occlude_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: + f"""Filters out outlier numerical values from patient events. + + In particular, this function filters the DataFrame to only include numerical values that are within + {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. + """ + + val = pl.col("numerical_value") + mean = pl.col("values/mean") + stddev = pl.col("values/std") + filter_expr = (val - mean).abs() <= stddev_cutoff * stddev + + return ( + df.join(code_metadata, on=join_cols, how="left", coalesce=True) + .with_columns( + filter_expr.alias("numerical_value/is_inlier"), + pl.when(filter_expr).then(pl.col("numerical_value")).alias("numerical_value"), + ) + .drop("values/mean", "values/std") + ) + + return occlude_outliers_fn + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """TODO.""" + + code_metadata = pl.read_parquet( + Path(cfg.stage_cfg.metadata_input_dir) / "code_metadata.parquet", use_pyarrow=True + ) + compute_fn = occlude_outliers_fntr(cfg.stage_cfg, code_metadata) + + map_over(cfg, compute_fn=compute_fn) + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/tensorize.py b/src/MEDS_polars_functions/transforms/tensorize.py similarity index 81% rename from src/MEDS_polars_functions/tensorize.py rename to src/MEDS_polars_functions/transforms/tensorize.py index ad96c55..4121e79 100644 --- a/src/MEDS_polars_functions/tensorize.py +++ b/src/MEDS_polars_functions/transforms/tensorize.py @@ -1,10 +1,19 @@ +#!/usr/bin/env python """Functions for tensorizing MEDS datasets. 
TODO
"""
+from functools import partial
+from importlib.resources import files
+
+import hydra
 import polars as pl
 from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
+from omegaconf import DictConfig
+
+from MEDS_polars_functions.mapreduce.mapper import map_over
+from MEDS_polars_functions.mapreduce.utils import shard_iterator
 def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict:
@@ -72,7 +81,7 @@ def convert_to_NRT(tokenized_df: pl.LazyFrame) -> JointNestedRaggedTensorDict:
     # There should only be one time delta column, but this ensures we catch it regardless of the unit of time
     # used to convert the time deltas, and that we verify there is only one such column.
-    time_delta_cols = [c for c in tokenized_df.columns if c.startswith("time_delta_")]
+    time_delta_cols = [c for c in tokenized_df.collect_schema().names() if c.startswith("time_delta_")]
     if len(time_delta_cols) == 0:
         raise ValueError("Expected at least one time delta column, found none")
@@ -84,3 +93,22 @@
     return JointNestedRaggedTensorDict(
         tokenized_df.select(time_delta_col, "code", "numerical_value").collect().to_dict(as_series=False)
     )
+
+
+config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml")
+
+
+@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem)
+def main(cfg: DictConfig):
+    """TODO."""
+
+    map_over(
+        cfg,
+        compute_fn=convert_to_NRT,
+        write_fn=JointNestedRaggedTensorDict.save,
+        shard_iterator_fntr=partial(shard_iterator, in_prefix="event_seqs/", out_suffix=".nrt"),
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/MEDS_polars_functions/tokenize.py b/src/MEDS_polars_functions/transforms/tokenize.py
similarity index 83%
rename from src/MEDS_polars_functions/tokenize.py
rename to src/MEDS_polars_functions/transforms/tokenize.py
index 7233ede..c9386f8 100644
--- a/src/MEDS_polars_functions/tokenize.py
+++ b/src/MEDS_polars_functions/transforms/tokenize.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 """Functions for tokenizing MEDS datasets.
 Here, _tokenization_ refers specifically to the process of converting a longitudinal, irregularly sampled,
@@ -8,7 +9,18 @@ columns of concern here thus are `patient_id`, `timestamp`, `code`, `numerical_value`.
""" +import json +import random +from importlib.resources import files +from pathlib import Path + +import hydra import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.mapreduce.mapper import rwlock_wrap +from MEDS_polars_functions.utils import hydra_loguru_init, write_lazyframe SECONDS_PER_MINUTE = 60.0 SECONDS_PER_HOUR = SECONDS_PER_MINUTE * 60.0 @@ -194,7 +206,7 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: return ( dynamic.group_by("patient_id", "timestamp", maintain_order=True) - .agg(fill_to_nans("code").keep_name(), fill_to_nans("numerical_value").keep_name()) + .agg(fill_to_nans("code").name.keep(), fill_to_nans("numerical_value").name.keep()) .group_by("patient_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta_days"), @@ -202,3 +214,64 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: "numerical_value", ) ) + + +config_yaml = files("MEDS_polars_functions").joinpath("configs/preprocess.yaml") + + +@hydra.main(version_base=None, config_path=str(config_yaml.parent), config_name=config_yaml.stem) +def main(cfg: DictConfig): + """TODO.""" + + hydra_loguru_init() + + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + + patient_splits = list(shards.keys()) + random.shuffle(patient_splits) + + for sp in patient_splits: + in_fp = input_dir / f"{sp}.parquet" + schema_out_fp = output_dir / "schemas" / f"{sp}.parquet" + event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet" + + logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + schema_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_statics_and_schema, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Tokenizing {str(in_fp.resolve())} into event_seqs at {str(event_seq_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + event_seq_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_seq_of_patient_events, + do_return=False, + cache_intermediate=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() diff --git a/src/MEDS_polars_functions/utils.py b/src/MEDS_polars_functions/utils.py index baf7c91..9d19434 100644 --- a/src/MEDS_polars_functions/utils.py +++ b/src/MEDS_polars_functions/utils.py @@ -8,17 +8,63 @@ import hydra import polars as pl from loguru import logger -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf pl.enable_string_cache() -def get_script_docstring() -> str: - """Returns the docstring of the main function of the script that was called. +def write_lazyframe(df: pl.LazyFrame, out_fp: Path) -> None: + if isinstance(df, pl.LazyFrame): + df = df.collect() - Returns: - str: TODO + df.write_parquet(out_fp, use_pyarrow=True) + + +def stage_init(cfg: DictConfig): + """Initializes the stage by logging the configuration and the stage-specific paths. + + Args: + cfg: The global configuration object, which should have a ``cfg.stage_cfg`` attribute containing the + stage specific configuration. 
+ + Returns: The data input directory, stage output directory, metadata input directory, and the shards file + path. """ + hydra_loguru_init() + + logger.info( + f"Running {current_script_name()} with the following configuration:\n{OmegaConf.to_yaml(cfg)}" + ) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) + shards_map_fp = Path(cfg.shards_map_fp) + + def chk(x: Path): + return "✅" if x.exists() else "❌" + + paths_strs = [ + f" - {k}: {chk(v)} {str(v.resolve())}" + for k, v in { + "input_dir": input_dir, + "output_dir": output_dir, + "metadata_input_dir": metadata_input_dir, + "shards_map_fp": shards_map_fp, + }.items() + ] + + logger_strs = [ + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}", + "Paths: (checkbox indicates if it exists)", + ] + logger.debug("\n".join(logger_strs + paths_strs)) + + return input_dir, output_dir, metadata_input_dir, shards_map_fp + + +def get_script_docstring() -> str: + """Returns the docstring of the main function of the script from which this function was called.""" main_module = sys.modules["__main__"] func = getattr(main_module, "main", None) @@ -28,11 +74,18 @@ def get_script_docstring() -> str: def current_script_name() -> str: - """Returns the name of the script that called this function. + """Returns the name of the module that called this function.""" - Returns: - str: The name of the script that called this function. - """ + main_module = sys.modules["__main__"] + main_func = getattr(main_module, "main", None) + if main_func and callable(main_func): + func_module = main_func.__module__ + if func_module == "__main__": + return Path(sys.argv[0]).stem + else: + return func_module.split(".")[-1] + + logger.warning("Can't find main function in __main__ module. Using sys.argv[0] as a fallback.") return Path(sys.argv[0]).stem @@ -170,13 +223,6 @@ def hydra_loguru_init() -> None: logger.add(os.path.join(hydra_path, f"{logfile_name}.log")) -def write_lazyframe(df: pl.LazyFrame, out_fp: Path) -> None: - if isinstance(df, pl.LazyFrame): - df = df.collect() - - df.write_parquet(out_fp, use_pyarrow=True) - - def get_shard_prefix(base_path: Path, fp: Path) -> str: """Extracts the shard prefix from a file path by removing the raw_cohort_dir. diff --git a/tests/test_extraction.py b/tests/test_extraction.py index 75a7404..b85d46b 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -1,9 +1,31 @@ -"""Tests the full end-to-end extraction process.""" +"""Tests the full end-to-end extraction process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +import os import rootutils root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) +code_root = root / "src" / "MEDS_polars_functions" +extraction_root = code_root / "extraction" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" + MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" + AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" +else: + SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" + MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" + AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" + import json import subprocess import tempfile @@ -246,14 +268,22 @@ def get_expected_output(df: str) -> pl.DataFrame: } -def run_command(script: Path, hydra_kwargs: dict[str, str], test_name: str): - script = str(script.resolve()) - command_parts = ["python", script] + [f"{k}={v}" for k, v in hydra_kwargs.items()] - command_out = subprocess.run(" ".join(command_parts), shell=True, capture_output=True) +def run_command( + script: Path | str, hydra_kwargs: dict[str, str], test_name: str, config_name: str | None = None +): + script = ["python", str(script.resolve())] if isinstance(script, Path) else [script] + command_parts = script + if config_name is not None: + command_parts.append(f"--config-name={config_name}") + command_parts.extend([f"{k}={v}" for k, v in hydra_kwargs.items()]) + + full_cmd = " ".join(command_parts) + command_out = subprocess.run(full_cmd, shell=True, capture_output=True) + stderr = command_out.stderr.decode() stdout = command_out.stdout.decode() if command_out.returncode != 0: - raise AssertionError(f"{test_name} failed!\nstdout:\n{stdout}\nstderr:\n{stderr}") + raise AssertionError(f"{test_name} failed!\ncommand:{full_cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}") return stderr, stdout @@ -290,21 +320,8 @@ def test_extraction(): admit_vitals_parquet = raw_cohort_dir / "admit_vitals.parquet" df = pl.read_csv(admit_vitals_csv) - old_shape = df.shape - assert old_shape[0] > 0, "Should have some rows" - df.write_parquet(admit_vitals_parquet, use_pyarrow=True) - df = pl.scan_parquet(admit_vitals_parquet) - df = df.select(list(df.columns)) - assert ( - df.select(pl.len()).collect().item() == old_shape[0] - ), "Should have the same number of rows after select." - assert old_shape == df.collect().shape, "Shapes should be the same after selecting all columns." - - df = pl.scan_parquet(admit_vitals_parquet) - assert old_shape == df.collect().shape, "Shapes should be the same after scanning the parquet file." 
- # Write the event config YAML event_cfgs_yaml.write_text(EVENT_CFGS_YAML) @@ -327,15 +344,11 @@ def test_extraction(): "hydra.verbose": True, } - extraction_root = root / "scripts" / "extraction" - all_stderrs = [] all_stdouts = [] # Step 1: Sub-shard the data - stderr, stdout = run_command( - extraction_root / "shard_events.py", extraction_config_kwargs, "shard_events" - ) + stderr, stdout = run_command(SHARD_EVENTS_SCRIPT, extraction_config_kwargs, "shard_events") all_stderrs.append(stderr) all_stdouts.append(stdout) @@ -380,9 +393,16 @@ def test_extraction(): check_row_order=False, ) + # Step 2: Collect the patient splits + # stderr, stdout = run_command( + # "MEDS_extract_shard_patients", + # {**extraction_config_kwargs, "stage":"split_and_shard_patients"}, + # "split_and_shard_patients", + # ) + # Step 2: Collect the patient splits stderr, stdout = run_command( - extraction_root / "split_and_shard_patients.py", + SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, "split_and_shard_patients", ) @@ -417,7 +437,7 @@ def test_extraction(): # Step 3: Extract the events and sub-shard by patient stderr, stdout = run_command( - extraction_root / "convert_to_sharded_events.py", + CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, "convert_events", ) @@ -455,7 +475,7 @@ def test_extraction(): # Step 4: Merge to the final output stderr, stdout = run_command( - extraction_root / "merge_to_MEDS_cohort.py", + MERGE_TO_MEDS_COHORT_SCRIPT, extraction_config_kwargs, "merge_sharded_events", ) @@ -501,9 +521,10 @@ def test_extraction(): # Step 4: Merge to the final output stderr, stdout = run_command( - extraction_root / "collect_code_metadata.py", + AGGREGATE_CODE_METADATA_SCRIPT, extraction_config_kwargs, - "collect_code_metadata", + "aggregate_code_metadata", + config_name="extraction", ) all_stderrs.append(stderr) all_stdouts.append(stdout) @@ -518,11 +539,11 @@ def test_extraction(): want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( pl.col("code").cast(pl.Categorical), - pl.col("code/n_occurrences").cast(pl.UInt32), - pl.col("code/n_patients").cast(pl.UInt32), - pl.col("values/n_occurrences").cast(pl.UInt32), - pl.col("values/sum").cast(pl.Float64).fill_null(0), - pl.col("values/sum_sqd").cast(pl.Float64).fill_null(0), + pl.col("code/n_occurrences").cast(pl.UInt8), + pl.col("code/n_patients").cast(pl.UInt8), + pl.col("values/n_occurrences").cast(pl.UInt8), + pl.col("values/sum").cast(pl.Float32).fill_null(0), + pl.col("values/sum_sqd").cast(pl.Float32).fill_null(0), ) assert_df_equal(