diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7540f52..1533f74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,6 +38,7 @@ repos: rev: v2.2.0 hooks: - id: autoflake + args: [--in-place, --remove-all-unused-imports] # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 535aa56..406f1f2 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -33,10 +33,9 @@ Download this repository and install the requirements: ```bash git clone git@github.com:mmcdermott/MEDS_polars_functions.git cd MEDS_polars_functions -git checkout MIMIC_IV conda create -n MEDS python=3.12 conda activate MEDS -pip install .[mimic] +pip install .[examples] ``` ## Step 1: Download MIMIC-IV @@ -73,6 +72,8 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ## Step 3: Run the MEDS extraction ETL +### Running locally, serially + We will assume you want to output the final MEDS dataset into a directory we'll denote as `$MIMICIV_MEDS_DIR`. Note this is a different directory than the pre-MEDS directory (though, of course, they can both be subdirectories of the same root directory). @@ -80,12 +81,78 @@ subdirectories of the same root directory). This is a step in 4 parts: 1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers - performing this sub-sharding step. + performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. + +```bash +./scripts/extraction/shard_events.py \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. + +2. Extract and form the patient splits and sub-shards. + +```bash +./scripts/extraction/split_and_shard_patients.py \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. + +3. Extract patient sub-shards and convert to MEDS events. + +```bash +./scripts/extraction/convert_to_sharded_events.py \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml +``` + +In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to +cut the time down by a factor of the number of workers processing the data by simply running the command +multiple times (though this will, of course, consume more resources). If your filesystem is distributed, these +commands can also be launched as separate slurm jobs, for example. For MIMIC-IV, this level of parallelization +and performance is not necessary; however, for larger datasets, it can be. + +4. Merge the MEDS events into a single file per patient sub-shard. + +```bash +./scripts/extraction/merge_to_MEDS_cohort.py \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml +``` + +### Running Locally, in Parallel. + +This step is the exact same commands as above, but leverages Hydra's multirun capabilities with the `joblib` +launcher. 
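+For example (mirroring the invocation used in `MIMIC-IV_Example/joint_script.sh` later in this PR), the
+sub-sharding step can be run with `$N_PARALLEL_WORKERS` local, process-based workers roughly as follows:
+
+```bash
+./scripts/extraction/shard_events.py \
+    --multirun \
+    worker="range(0,$N_PARALLEL_WORKERS)" \
+    hydra/launcher=joblib \
+    input_dir=$MIMICIV_PREMEDS_DIR \
+    cohort_dir=$MIMICIV_MEDS_DIR \
+    event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml
+```
+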
Install this package with the optional `local_parallelism` option (e.g., `pip install -e .[local_parallelism]` and run `./MIMIC-IV_Example/joint_script.sh`. See that script for expected args. + +### Running Each Step over Slurm + +To use slurm, run each command with the number of workers desired using Hydra's multirun capabilities with the +`submitit_slurm` launcher. Install this package with the optional `slurm_parallelism` option. See below for +modified commands. Note these can't be chained in a single script as the jobs will not wait for all slurm jobs +to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the number of desired workers + +1. Sub-shard the raw files. ```bash ./scripts/extraction/shard_events.py \ - raw_cohort_dir=$MIMICIV_PREMEDS_DIR \ - MEDS_cohort_dir=$MIMICIV_MEDS_DIR \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.name="${hydra.job.name}_${worker}" \ + hydra.launcher.partition="short" \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` @@ -95,8 +162,8 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro ```bash ./scripts/extraction/split_and_shard_patients.py \ - raw_cohort_dir=$MIMICIV_PREMEDS_DIR \ - MEDS_cohort_dir=$MIMICIV_MEDS_DIR \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` @@ -106,8 +173,8 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ```bash ./scripts/extraction/convert_to_sharded_events.py \ - raw_cohort_dir=$MIMICIV_PREMEDS_DIR \ - MEDS_cohort_dir=$MIMICIV_MEDS_DIR \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` @@ -121,8 +188,8 @@ and performance is not necessary; however, for larger datasets, it can be. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ - raw_cohort_dir=$MIMICIV_PREMEDS_DIR \ - MEDS_cohort_dir=$MIMICIV_MEDS_DIR \ + input_dir=$MIMICIV_PREMEDS_DIR \ + cohort_dir=$MIMICIV_MEDS_DIR \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml ``` diff --git a/MIMIC-IV_Example/__init__.py b/MIMIC-IV_Example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MIMIC-IV_Example/joint_script.sh b/MIMIC-IV_Example/joint_script.sh new file mode 100755 index 0000000..bf3438e --- /dev/null +++ b/MIMIC-IV_Example/joint_script.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," + echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo + echo "Arguments:" + echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." + echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." + echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." + echo " N_PARALLEL_WORKERS Number of parallel workers for processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." 
+ exit 1 +} + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -lt 4 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +MIMICIV_RAW_DIR="$1" +MIMICIV_PREMEDS_DIR="$2" +MIMICIV_MEDS_DIR="$3" +N_PARALLEL_WORKERS="$4" + +shift 4 + +echo "Running pre-MEDS conversion." +./MIMIC-IV_Example/pre_MEDS.py raw_cohort_dir="$MIMICIV_RAW_DIR" output_dir="$MIMICIV_PREMEDS_DIR" + +echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/shard_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" + +echo "Splitting patients in serial" +./scripts/extraction/split_and_shard_patients.py \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" + +echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/convert_to_sharded_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" + +echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/merge_to_MEDS_cohort.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh new file mode 100755 index 0000000..feb7fd3 --- /dev/null +++ b/MIMIC-IV_Example/joint_script_slurm.sh @@ -0,0 +1,112 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," + echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." + echo + echo "Arguments:" + echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." + echo " MIMICIV_PREMEDS_DIR Output directory for pre-MEDS data." + echo " MIMICIV_MEDS_DIR Output directory for processed MEDS data." + echo " N_PARALLEL_WORKERS Number of parallel workers for processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -ne 4 ]; then + echo "Error: Incorrect number of arguments provided." 
+ display_help +fi + +export MIMICIV_RAW_DIR="$1" +export MIMICIV_PREMEDS_DIR="$2" +export MIMICIV_MEDS_DIR="$3" +export N_PARALLEL_WORKERS="$4" + +shift 4 + +# Note we use `--multirun` throughout here due to ensure the submitit launcher is used throughout, so that +# this doesn't fall back on running anything locally in a setting where only slurm worker nodes have +# sufficient computational resources to run the actual jobs. + +# echo "Running pre-MEDS conversion on one worker." +# ./MIMIC-IV_Example/pre_MEDS.py \ +# --multirun \ +# worker="range(0,1)" \ +# hydra/launcher=submitit_slurm \ +# hydra.launcher.timeout_min=60 \ +# hydra.launcher.cpus_per_task=10 \ +# hydra.launcher.mem_gb=50 \ +# hydra.launcher.partition="short" \ +# raw_cohort_dir="$MIMICIV_RAW_DIR" \ +# output_dir="$MIMICIV_PREMEDS_DIR" + +echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." + +./scripts/extraction/shard_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + "hydra.job.env_copy=[PATH]" \ + input_dir="$MIMICIV_PREMEDS_DIR" \ + cohort_dir="$MIMICIV_MEDS_DIR" \ + event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ + stage=shard_events + +# echo "Splitting patients on one worker" +# ./scripts/extraction/split_and_shard_patients.py \ +# --multirun \ +# worker="range(0,1)" \ +# hydra/launcher=submitit_slurm \ +# hydra.launcher.timeout_min=60 \ +# hydra.launcher.cpus_per_task=10 \ +# hydra.launcher.mem_gb=50 \ +# hydra.launcher.partition="short" \ +# input_dir="$MIMICIV_PREMEDS_DIR" \ +# cohort_dir="$MIMICIV_MEDS_DIR" \ +# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" +# +# echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" +# ./scripts/extraction/convert_to_sharded_events.py \ +# --multirun \ +# worker="range(0,$N_PARALLEL_WORKERS)" \ +# hydra/launcher=submitit_slurm \ +# hydra.launcher.timeout_min=60 \ +# hydra.launcher.cpus_per_task=10 \ +# hydra.launcher.mem_gb=50 \ +# hydra.launcher.partition="short" \ +# input_dir="$MIMICIV_PREMEDS_DIR" \ +# cohort_dir="$MIMICIV_MEDS_DIR" \ +# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" +# +# echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" +# ./scripts/extraction/merge_to_MEDS_cohort.py \ +# --multirun \ +# worker="range(0,$N_PARALLEL_WORKERS)" \ +# hydra/launcher=submitit_slurm \ +# hydra.launcher.timeout_min=60 \ +# hydra.launcher.cpus_per_task=10 \ +# hydra.launcher.mem_gb=50 \ +# hydra.launcher.partition="short" \ +# input_dir="$MIMICIV_PREMEDS_DIR" \ +# cohort_dir="$MIMICIV_MEDS_DIR" \ +# event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 1f2f223..bf99f3a 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -59,7 +59,7 @@ def fix_static_data(raw_static_df: pl.LazyFrame, death_times_df: pl.LazyFrame) - return raw_static_df.join(death_times_df, on="subject_id", how="left").select( "subject_id", - pl.coalesce(pl.col("dod"), pl.col("deathtime")).alias("dod"), + pl.coalesce(pl.col("deathtime"), pl.col("dod")).alias("dod"), (pl.col("anchor_year") - pl.col("anchor_age")).cast(str).alias("year_of_birth"), "gender", ) @@ -94,6 +94,11 @@ def main(cfg: DictConfig): pfx = get_shard_prefix(raw_cohort_dir, in_fp) 
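+        # Compute the mirrored output path; if it already exists, skip this file so that re-running the
+        # pre-MEDS step resumes where it left off rather than re-computing completed outputs.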
out_fp = MEDS_input_dir / in_fp.relative_to(raw_cohort_dir) + + if out_fp.is_file(): + print(f"Done with {pfx}. Continuing") + continue + out_fp.parent.mkdir(parents=True, exist_ok=True) if pfx not in FUNCTIONS: @@ -101,11 +106,15 @@ def main(cfg: DictConfig): f"No function needed for {pfx}: " f"Symlinking {str(in_fp.resolve())} to {str(out_fp.resolve())}" ) - relative_in_fp = in_fp.relative_to(out_fp.parent, walk_up=True) + relative_in_fp = in_fp.relative_to(out_fp.resolve().parent, walk_up=True) out_fp.symlink_to(relative_in_fp) continue else: out_fp = MEDS_input_dir / f"{pfx}.parquet" + if out_fp.is_file(): + print(f"Done with {pfx}. Continuing") + continue + fn, need_df = FUNCTIONS[pfx] if not need_df: st = datetime.now() diff --git a/MIMIC-IV_Example/sbatch_joint_script.sh b/MIMIC-IV_Example/sbatch_joint_script.sh new file mode 100644 index 0000000..75d3281 --- /dev/null +++ b/MIMIC-IV_Example/sbatch_joint_script.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +#SBATCH -c 10 # Request one core +#SBATCH -t 0-03:00 # Runtime in D-HH:MM format +#SBATCH -p short # Partition to run in +#SBATCH --mem=300GB # Memory total in MiB (for all cores) +#SBATCH -o MIMIC_IV_MEDS_%j_sbatch.out # File to which STDOUT will be written, including job ID (%j) +#SBATCH -e MIMIC_IV_MEDS_%j_sbatch.err # File to which STDERR will be written, including job ID (%j) + +cd /n/data1/hms/dbmi/zaklab/mmd/MEDS_polars_functions || exit + +MIMICIV_MEDS_DIR="$3" + +LOG_DIR="$MIMICIV_MEDS_DIR/.logs" + +echo "Running with saving to $LOG_DIR" + +mkdir -p "$LOG_DIR" + +PATH="/home/mbm47/.conda/envs/MEDS_pipelines/bin:$PATH" \ + time mprof run --include-children --exit-code --output "$LOG_DIR/mprofile.dat" \ + ./MIMIC-IV_Example/joint_script.sh "$@" 2> "$LOG_DIR/timings.txt" diff --git a/README.md b/README.md index e03dda6..d57b82b 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,11 @@ more information. This package provides three things: 1. A working, scalable, simple example of how to extract and pre-process MEDS data for downstream modeling. + These examples are provided in the form of: + - A set of integration tests that are run over synthetic data to verify correctness of the ETL pipeline. + See `tests/test_extraction.py` for the ETL tests with the in-built synthetic source data. + - A working MIMIC-IV MEDS ETL pipeline that can be run over MIMIC-IV v2.2 in approximately 1 hour in serial + mode (and much faster if parallelized). See `MIMIC-IV_Example` for more details. 2. A flexible ETL for extracting MEDS data from a variety of source formats. 3. A pre-processing pipeline that can be used for models that require: - Filtering data to only include patients with a certain number of events @@ -27,9 +32,92 @@ This package provides three things: ## Installation -For now, clone this repository and run `pip install -e .` from the repository root. +- For a base installation, clone this repository and run `pip install .` from the repository root. +- For running the MIMIC-IV example, install the optional MIMIC dependencies as well with `pip install .[mimic]`. +- To support same-machine, process-based parallelism, install the optional joblib dependencies with `pip install .[local_parallelism]`. +- To support cluster-based parallelism, install the optional submitit dependencies with `pip install .[slurm_parallelism]`. +- For working on development, install the optional development dependencies with `pip install .[dev,tests]`. 
+- Multiple optional dependencies can be installed together by combining their names with commas inside the
+  square brackets, e.g., `pip install .[mimic,local_parallelism]`.
+
+## Usage -- High Level
+
+The MEDS ETL and pre-processing pipelines are designed to be run in a modular, stage-based manner, with each
+stage of the pipeline being run as a separate script. For a single pipeline, all scripts take the same
+arguments by leveraging the same Hydra configuration file. To run multiple workers on a single stage in
+parallel, the user can launch the same script multiple times _without changing the arguments or configuration
+file_, and the scripts will automatically coordinate the parallelism and avoid duplicative work. This permits
+tremendous flexibility in how these pipelines can be run.
+
+- The user can run the entire pipeline serially, through a single shell script that simply calls each
+  stage's script in sequence.
+- The user can leverage arbitrary scheduling systems (e.g., Slurm, LSF, Kubernetes, etc.) to run each stage
+  in parallel on a cluster, by constructing the appropriate worker scripts to run each stage's script and
+  simply launching as many worker jobs as desired (note this will typically require a distributed file
+  system to work correctly, as these scripts use manually created file locks to avoid duplicative work).
+- The user can run each stage in parallel on a single machine by launching multiple copies of the same
+  script in different terminal sessions. This can result in a significant speedup depending on the machine
+  configuration, as it ensures that parallelism can be used with minimal file read contention.
+
+Two of these methods of parallelism, local-machine parallelism and Slurm-based cluster parallelism, are
+supported explicitly by this package through the `joblib` and `submitit` Hydra launcher plugins and Hydra's
+multirun capabilities, which are discussed in more detail below.
+
+By following this design convention, each individual stage of the pipeline can be kept extremely simple
+(often a single short "dataframe" function), rigorously tested, and cached after completion to permit easy
+resuming or re-running of the pipeline. It also permits extremely flexible and efficient (through
+parallelization) use of the pipeline in a variety of environments, all without imposing significant
+complexity, overhead, or computational dependencies on the user.
+
+Below, we walk through the usage of this mechanism for both the ETL and the model-specific pre-processing
+pipelines in more detail.
+
+### Scripts for the ETL Pipeline
+
+The ETL pipeline (which is more complete and likely to be viable for a wider range of input datasets out of
+the box) relies on the following configuration files and scripts:
+
+Configuration: `configs/extraction.yaml`
 
-## MEDS ETL / Extraction Pipeline
+```yaml
+# The event conversion configuration file is used throughout the pipeline to define the events to extract.
+event_conversion_config_fp: ???
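+
+# Shared arguments such as `input_dir` and `cohort_dir` are not repeated here; the full
+# `configs/extraction.yaml` inherits them from `configs/pipeline.yaml` via its Hydra `defaults:` list.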
+
+stages:
+  - shard_events
+  - split_and_shard_patients
+  - convert_to_sharded_events
+  - merge_to_MEDS_cohort
+
+stage_configs:
+  shard_events:
+    row_chunksize: 200000000
+    infer_schema_length: 10000
+  split_and_shard_patients:
+    is_metadata: true
+    output_dir: ${cohort_dir}
+    n_patients_per_shard: 50000
+    external_splits_json_fp:
+    split_fracs:
+      train: 0.8
+      tuning: 0.1
+      held_out: 0.1
+  merge_to_MEDS_cohort:
+    output_dir: ${cohort_dir}/final_cohort
+```
+
+Scripts:
+
+1. `shard_events.py`: Shards the input data into smaller, event-level shards.
+2. `split_and_shard_patients.py`: Splits the patient population into ML splits and shards these splits into
+   patient-level shards.
+3. `convert_to_sharded_events.py`: Converts the input, event-level shards into the MEDS event format and
+   sub-shards them into patient-level sub-shards.
+4. `merge_to_MEDS_cohort.py`: Merges the patient-level, event-level sub-shards into full patient-level shards.
+
+See the `MIMIC-IV_Example` directory for a full, worked example of the ETL on MIMIC-IV v2.2.
+
+## MEDS ETL / Extraction Pipeline Details
 
 ### Overview
@@ -197,6 +285,39 @@ running multiple copies of the same script on independent workers to process the
 steps again need to happen in a single-threaded manner, but these steps are generally very fast and should
 not be a bottleneck.
 
+## Overview of configuration manipulation
+
+### Pipeline configuration: Stages and OmegaConf Resolvers
+
+The pipeline configuration files for both the provided extraction and pre-processing pipelines are structured
+to permit ease of understanding, flexibility for user-derived modifications, and ease of use in the simple,
+file-in/file-out scripts that this repository promotes. Each pipeline (extraction and pre-processing) defines
+one global configuration file, which is used as the Hydra specification for all scripts in that pipeline.
+This file leverages some generic pipeline configuration options, specified in `pipeline.yaml` and imported
+via the Hydra `defaults:` list, but also defines a list of stages with stage-specific configurations.
+
+The user can specify the stage in question on the command line either manually (e.g., `stage=stage_name`) or
+allow the stage name to be inferred automatically from the script name. Each script receives not only the
+global configuration file but also a sub-configuration (within the `stage_cfg` node of the received global
+configuration) which is pre-populated with the stage-specific configuration for the stage in question and
+with automatically inferred input and output file paths (if not overridden in the config file) based on the
+stage name and its position in the overall pipeline. This makes it easy to reuse the transformations and
+scripts defined here in new pipelines, simply by placing them as a stage in a broader pipeline with a
+different configuration or ordering relative to other stages.
+
+### Running the Pipeline in Parallel via Hydra Multirun
+
+We support two (optional) Hydra multirun job launchers for parallelizing ETL and pre-processing pipeline
+steps: [`joblib`](https://hydra.cc/docs/plugins/joblib_launcher/) (for local parallelism) and
+[`submitit`](https://hydra.cc/docs/plugins/submitit_launcher/) (for cluster parallelism via Slurm).
+
+To use either of these, you need to install additional optional dependencies:
+
+1. `pip install -e .[local_parallelism]` for joblib local parallelism support, or
+2. 
`pip install -e .[slurm_parallelism]` for submitit cluster parallelism support. + ## TODOs: 1. We need to have a vehicle to cleanly separate dataset-specific variables from the general configuration diff --git a/configs/extraction.yaml b/configs/extraction.yaml index 54708d0..1a1c0dd 100644 --- a/configs/extraction.yaml +++ b/configs/extraction.yaml @@ -1,31 +1,67 @@ -# Raw data -raw_cohort_dir: ??? -MEDS_cohort_dir: ??? +defaults: + - pipeline + - _self_ -# Event Conversion -event_conversion_config_fp: ??? - -# Splits -external_splits_json_fp: null -split_fracs: - train: 0.8 - tuning: 0.1 - held_out: 0.1 +description: |- + This pipeline extracts raw MEDS events in longitudinal, sparse form from an input dataset meeting select + criteria and converts them to the flattened, MEDS format. It can be run in its entirety, with controllable + levels of parallelism, or in stages. Arguments: + - `event_conversion_config_fp`: The path to the event conversion configuration file. This file defines + the events to extract from the various rows of the various input files encountered in the global input + directory. + - `input_dir`: The path to the directory containing the raw input files. + - `cohort_dir`: The path to the directory where the output cohort will be written. It will be written in + various subfolders of this dir depending on the stage, as intermediate stages cache their output during + computation for efficiency of re-running and distributing. -# Sharding -row_chunksize: 200000000 -n_patients_per_shard: 50000 -infer_schema_length: 10000 +# The event conversion configuration file is used throughout the pipeline to define the events to extract. +event_conversion_config_fp: ??? -# Misc -do_overwrite: False -seed: 1 +stages: + - shard_events + - split_and_shard_patients + - convert_to_sharded_events + - merge_to_MEDS_cohort -# Hydra -hydra: - job: - name: MEDS_ETL_step_${now:%Y-%m-%d_%H-%M-%S} - run: - dir: ${MEDS_cohort_dir}/.logs/etl/${hydra.job.name} - sweep: - dir: ${MEDS_cohort_dir}/.logs/etl/${hydra.job.name} +stage_configs: + shard_events: + description: |- + This stage shards the raw input events into smaller files for easier processing. Arguments: + - `row_chunksize`: The number of rows to read in at a time. + - `infer_schema_length`: The number of rows to read in to infer the schema (only used if the source + files are csvs) + row_chunksize: 200000000 + infer_schema_length: 10000 + split_and_shard_patients: + description: |- + This stage splits the patients into training, tuning, and held-out sets, and further splits those sets + into shards. Arguments: + - `n_patients_per_shard`: The number of patients to include in a shard. + - `external_splits_json_fp`: The path to a json file containing any pre-defined splits for specially + held-out test sets beyond the IID held out set that will be produced (e.g., for prospective + datasets, etc.). + - `split_fracs`: The fraction of patients to include in the IID training, tuning, and held-out sets. + Split fractions can be changed for the default names by adding a hydra-syntax command line argument + for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1 split_fracs.held_out=0.2`. + A split can be removed with the `~` override Hydra syntax. Similarly, a new split name can be added + with the standard Hydra `+` override option. E.g., `~split_fracs.held_out +split_fracs.test=0.1`. It + is the user's responsibility to ensure that split fractions sum to 1. 
+ is_metadata: True + output_dir: ${cohort_dir} + n_patients_per_shard: 50000 + external_splits_json_fp: null + split_fracs: + train: 0.8 + tuning: 0.1 + held_out: 0.1 + merge_to_MEDS_cohort: + description: |- + This stage splits the patients into training, tuning, and held-out sets, and further splits those sets + into shards. Arguments: + - `n_patients_per_shard`: The number of patients to include in a shard. + - `external_splits_json_fp`: The path to a json file containing any pre-defined splits for specially + held-out test sets beyond the IID held out set that will be produced (e.g., for prospective + datasets, etc.). + - `split_fracs`: The fraction of patients to include in the IID training, tuning, and held-out sets. + output_dir: ${cohort_dir}/final_cohort + unique_by: "*" diff --git a/configs/pipeline.yaml b/configs/pipeline.yaml new file mode 100644 index 0000000..229ea26 --- /dev/null +++ b/configs/pipeline.yaml @@ -0,0 +1,43 @@ +# Global IO +input_dir: ??? +cohort_dir: ??? + +_default_description: |- + This is a MEDS pipeline ETL. Please set a more detailed description at the top of your specific pipeline + configuration file. + +log_dir: "${cohort_dir}/.logs/${stage}" + +# General pipeline variables +do_overwrite: False +seed: 1 +stages: ??? # The list of stages to this overall pipeline (in order) +stage_configs: ??? # The configurations for each stage, keyed by stage name + +# Mapreduce information +worker: 1 +polling_time: 300 # wait time in seconds before beginning reduction steps + +# Filling in the current stage +stage: ${current_script_name:} +stage_cfg: ${oc.create:${populate_stage:${stage}, ${input_dir}, ${cohort_dir}, ${stages}, ${stage_configs}}} + +# Hydra +hydra: + job: + name: "${stage}_${worker}_${now:%Y-%m-%d_%H-%M-%S}" + run: + dir: "${log_dir}" + sweep: + dir: "${log_dir}" + help: + app_name: "MEDS/${stage}" + template: |- + == ${hydra.help.app_name} == + ${hydra.help.app_name} is a command line tool that provides an interface for running MEDS pipelines. + + **Pipeline description:** + ${oc.select:description, ${_default_description}} + + **Stage description:** + ${oc.select:stage_configs.${stage}.description, ${get_script_docstring:}} diff --git a/configs/preprocess.yaml b/configs/preprocess.yaml index 397ff93..d65150b 100644 --- a/configs/preprocess.yaml +++ b/configs/preprocess.yaml @@ -1,59 +1,51 @@ -# Raw data -MEDS_cohort_dir: ??? -output_data_dir: ??? -log_dir: "${output_data_dir}/.logs" - -# Worker / Stage information -stage: ??? -worker: 1 -polling_time: 300 # wait time in seconds before beginning reduction steps - -# Filtering parameters -min_code_occurrences: null -min_events_per_patient: null -min_measurements_per_patient: null - -# Time-derived measurements -time_derived_measurements: - age: - dob_code: ??? - age_code: "AGE" - age_unit: "years" - time_of_day: - bin_endpoints: [6, 12, 18, 24] - -# Code modifiers will be used as adjoining parts of the `code` columns during group-bys and eventual +defaults: + - pipeline + - _self_ + +# Global pipeline parameters: +# 1. Code modifiers will be used as adjoining parts of the `code` columns during group-bys and eventual # tokenization. code_modifier_columns: ??? -# Code metadata extraction. These may contain duplicates because the data may be filtered between different -# stages, depending on the pipeline in question. 
-code_processing_stages: - preliminary_counts: - - "code/n_occurrences" - - "code/n_patients" - outlier_detection: - - "values/n_occurrences" - - "values/sum" - - "values/sum_sqd" - normalization: - - "code/n_occurrences" - - "code/n_patients" - - "values/n_occurrences" - - "values/sum" - - "values/sum_sqd" - -# Outlier detection -outlier_stddev_cutoff: 4.5 - -# Misc -do_overwrite: False - -# Hydra -hydra: - job: - name: "MEDS_Preprocessor/stage_${stage}/worker_${worker}/${now:%Y-%m-%d_%H-%M-%S}" - run: - dir: "${log_dir}/${hydra.job.name}" - sweep: - dir: "${log_dir}/${hydra.job.name}" +# Pipeline Structure +stages: + - name: filter_patients + min_events_per_patient: null + min_measurements_per_patient: null + + - name: add_time_derived_measurements + age: + dob_code: ??? + age_code: "AGE" + age_unit: "years" + time_of_day: + bin_endpoints: [6, 12, 18, 24] + + - name: preliminary_counts + obs_aggregations: + - "code/n_occurrences" + - "code/n_patients" + + - name: filter_codes + min_code_occurrences: null + + - name: fit_outlier_detection + aggregations: + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" + + - name: filter_outliers + stddev_cutoff: 4.5 + + - name: fit_normalization + aggregations: + - "code/n_occurrences" + - "code/n_patients" + - "values/n_occurrences" + - "values/sum" + - "values/sum_sqd" + + - name: normalization + - name: tokenization + - name: tensorization diff --git a/eICU_Example/README.md b/eICU_Example/README.md new file mode 100644 index 0000000..0984b99 --- /dev/null +++ b/eICU_Example/README.md @@ -0,0 +1,225 @@ +# eICU-CRD Example + +This is an example of how to extract a MEDS dataset from [eICU-CRD +v2.0](https://physionet.org/content/eicu-crd/2.0/). All scripts in this README are assumed to +be run **not** from this directory but from the root directory of this entire repository (e.g., one directory +up from this one). + +**Status**: This is a work in progress. The code is not yet functional. Remaining work includes: + +- [ ] Implementing the pre-MEDS processing step. + - [ ] Identifying the pre-MEDS steps for eICU +- [ ] Testing the pre-MEDS processing step on live eICU-CRD. + - [ ] Test that it runs at all. + - [ ] Test that the output is as expected. +- [ ] Check the installation instructions on a fresh client. +- [ ] Testing the `configs/event_configs.yaml` configuration on eICU-CRD +- [ ] Testing the MEDS extraction ETL runs on eICU-CRD (this should be expected to work, but needs + live testing). + - [ ] Sub-sharding + - [ ] Patient split gathering + - [ ] Event extraction + - [ ] Merging +- [ ] Validating the output MEDS cohort + - [ ] Basic validation + - [ ] Detailed validation + +## Step 0: Installation + +Download this repository and install the requirements: + +```bash +git clone git@github.com:mmcdermott/MEDS_polars_functions.git +cd MEDS_polars_functions +conda create -n MEDS python=3.12 +conda activate MEDS +pip install .[examples] +``` + +## Step 1: Download eICU + +Download the eICU-CRD dataset (version 2.0) from https://physionet.org/content/eicu-crd/2.0/ following the +instructions on that page. You will need the raw `.csv.gz` files for this example. We will use +`$EICU_RAW_DIR` to denote the root directory of where the resulting _core data files_ are stored -- e.g., +there should be a `hosp` and `icu` subdirectory of `$EICU_RAW_DIR`. + +## Step 2: Get the data ready for base MEDS extraction + +This is a step in a few parts: + +1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. 
In + particular, we need to join: + - TODO +2. Convert the patient's static data to a more parseable form. This entails: + - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and + `anchor_offset` fields. + - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + +After these steps, modified files or symlinks to the original files will be written in a new directory which +will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this +directory. + +To run this step, you can use the following script (assumed to be run **not** from this directory but from the +root directory of this repository): + +```bash +./eICU_Example/pre_MEDS.py raw_cohort_dir=$EICU_RAW_DIR output_dir=$EICU_PREMEDS_DIR +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. + +## Step 3: Run the MEDS extraction ETL + +Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable memory +burden (e.g., \< 150GB per worker), you will want a smaller shard size, as well as to turn off the final unique +check (which should not be necessary given the structure of eICU and is expensive) in the merge stage. You can +do this by setting the following parameters at the end of the mandatory args when running this script: + +- `stage_configs.split_and_shard_patients.n_patients_per_shard=10000` +- `stage_configs.merge_to_MEDS_cohort.unique_by=null` + +### Running locally, serially + +We will assume you want to output the final MEDS dataset into a directory we'll denote as `$EICU_MEDS_DIR`. +Note this is a different directory than the pre-MEDS directory (though, of course, they can both be +subdirectories of the same root directory). + +This is a step in 4 parts: + +1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers + performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. + +```bash +./scripts/extraction/shard_events.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. + +2. Extract and form the patient splits and sub-shards. + +```bash +./scripts/extraction/split_and_shard_patients.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. + +3. Extract patient sub-shards and convert to MEDS events. + +```bash +./scripts/extraction/convert_to_sharded_events.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to +cut the time down by a factor of the number of workers processing the data by simply running the command +multiple times (though this will, of course, consume more resources). If your filesystem is distributed, these +commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization +and performance is not necessary; however, for larger datasets, it can be. + +4. 
Merge the MEDS events into a single file per patient sub-shard. + +```bash +./scripts/extraction/merge_to_MEDS_cohort.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +### Running Locally, in Parallel. + +This step is the exact same commands as above, but leverages Hydra's multirun capabilities with the `joblib` +launcher. Install this package with the optional `local_parallelism` option (e.g., `pip install -e .[local_parallelism]` and run `./eICU_Example/joint_script.sh`. See that script for expected args. + +### Running Each Step over Slurm + +To use slurm, run each command with the number of workers desired using Hydra's multirun capabilities with the +`submitit_slurm` launcher. Install this package with the optional `slurm_parallelism` option. See below for +modified commands. Note these can't be chained in a single script as the jobs will not wait for all slurm jobs +to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the number of desired workers + +1. Sub-shard the raw files. + +```bash +./scripts/extraction/shard_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.name="${hydra.job.name}_${worker}" \ + hydra.launcher.partition="short" \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. + +2. Extract and form the patient splits and sub-shards. + +```bash +./scripts/extraction/split_and_shard_patients.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. + +3. Extract patient sub-shards and convert to MEDS events. + +```bash +./scripts/extraction/convert_to_sharded_events.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +In practice, serially, this also takes around 20 minutes or more. However, it can be trivially parallelized to +cut the time down by a factor of the number of workers processing the data by simply running the command +multiple times (though this will, of course, consume more resources). If your filesystem is distributed, these +commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization +and performance is not necessary; however, for larger datasets, it can be. + +4. Merge the MEDS events into a single file per patient sub-shard. + +```bash +./scripts/extraction/merge_to_MEDS_cohort.py \ + input_dir=$EICU_PREMEDS_DIR \ + cohort_dir=$EICU_MEDS_DIR \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml +``` + +## Limitations / TO-DOs: + +Currently, some tables are ignored, including: + +1. `admissiondrug`: The [documentation](https://eicu-crd.mit.edu/eicutables/admissiondrug/) notes that this is + extremely infrequently used, so we skip it. +2. + +Lots of questions remain about how to appropriately handle timestamps of the data -- e.g., things like HCPCS +events are stored at the level of the _date_, not the _datetime_. 
How should those be slotted into the +timeline which is otherwise stored at the _datetime_ resolution? + +Other questions: + +1. How to handle merging the deathtimes between the hosp table and the patients table? +2. How to handle the dob nonsense MIMIC has? + +## Future Work + +### Pre-MEDS Processing + +If you wanted, some other processing could also be done here, such as: + +1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. diff --git a/eICU_Example/__init__.py b/eICU_Example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/eICU_Example/configs/event_configs.yaml b/eICU_Example/configs/event_configs.yaml new file mode 100644 index 0000000..77f4023 --- /dev/null +++ b/eICU_Example/configs/event_configs.yaml @@ -0,0 +1,613 @@ +# Note that there is no "patient_id" for eICU -- patients are only differentiable during the course of a +# single health system stay. Accordingly, we set the "patient" id here as the "patientHealthSystemStayID" + +patient_id_col: patienthealthsystemstayid + +patient: + dob: + code: "DOB" + timestamp: col(dateofbirth) + uniquepid: "uniquepid" + gender: + code: ["GENDER", "col(gender)"] + timestamp: null + ethnicity: + code: ["ETHNICITY", "col(ethnicity)"] + timestamp: null + hosp_admission: + code: + - "HOSPITAL_ADMISSION" + - col(hospitaladmitsource) + - col(hospitalregion) + - col(hospitalteachingstatus) + - col(hospitalnumbedscategory) + timestamp: col(hospitaladmittimestamp) + hospital_id: "hospitalid" + hosp_discharge: + code: + - "HOSPITAL_DISCHARGE" + - col(hospitaldischargestatus) + - col(hospitaldischargelocation) + timestamp: col(hospitaldischargetimestamp) + unit_admission: + code: + - "UNIT_ADMISSION" + - col(unitadmitsource) + - col(unitstaytype) + timestamp: col(unitadmittimestamp) + ward_id: "wardid" + unit_stay_id: "patientunitstayid" + unit_admission_weight: + code: + - "UNIT_ADMISSION_WEIGHT" + timestamp: col(unitadmittimestamp) + numerical_value: "unitadmissionweight" + unit_admission_height: + code: + - "UNIT_ADMISSION_HEIGHT" + timestamp: col(unitadmittimestamp) + numerical_value: "unitadmissionheight" + unit_discharge: + code: + - "UNIT_DISCHARGE" + - col(unitdischargestatus) + - col(unitdischargelocation) + timestamp: col(unitdischargetimestamp) + unit_discharge_weight: + code: + - "UNIT_DISCHARGE_WEIGHT" + timestamp: col(unitdischargetimestamp) + numerical_value: "unitdischargeweight" + +admissiondx: + admission_diagnosis: + code: + - "ADMISSION_DX" + - col(admitdxname) + timestamp: col(admitDxEnteredTimestamp) + admission_dx_id: "admitDxID" + unit_stay_id: "patientunitstayid" + +allergy: + allergy: + code: + - "ALLERGY" + - col(allergytype) + - col(allergyname) + timestamp: col(allergyEnteredTimestamp) + +carePlanGeneral: + cplItem: + code: + - "CAREPLAN_GENERAL" + - col(cplgroup) + - col(cplitemvalue) + timestamp: col(carePlanGeneralItemEnteredTimestamp) + +carePlanEOL: + cplEolDiscussion: + code: + - "CAREPLAN_EOL" + timestamp: col(carePlanEolDiscussionOccurredTimestamp) + +carePlanGoal: + cplGoal: + code: + - "CAREPLAN_GOAL" + - col(cplgoalcategory) + - col(cplgoalvalue) + - col(cplgoalstatus) + timestamp: col(carePlanGoalEnteredTimestamp) + +carePlanInfectiousDisease: + cplInfectDisease: + code: + - "CAREPLAN_INFECTIOUS_DISEASE" + - col(infectdiseasesite) + - col(infectdiseaseassessment) + - col(treatment) + - col(responsetotherapy) + timestamp: col(carePlanInfectDiseaseEnteredTimestamp) + +diagnosis: + diagnosis: + code: + - "ICD9CM" + - col(icd9code) + - 
col(diagnosispriority) + timestamp: col(diagnosisEnteredTimestamp) + diagnosis_string: "diagnosisstring" + +infusionDrug: + infusion: + code: + - "INFUSION" + - col(infusiondrugid) + - col(drugname) + timestamp: col(infusionEnteredTimestamp) + drug_rate: "drugrate" + infusion_rate: "infusionrate" + drug_amount: "drugamount" + volume_of_fluid: "volumeoffluid" + patient_weight: + code: + - "INFUSION_PATIENT_WEIGHT" + timestamp: col(infusionEnteredTimestamp) + numerical_value: "patientweight" + +lab: + lab: + code: + - "LAB" + - col(labmeasurenamesystem) + - col(labmeasurenameinterface) + - col(labname) + timestamp: col(labResultDrawnTimestamp) + numerical_value: "labresult" + text_value: "labresulttext" + lab_type_id: "labtypeid" + +medication: + drug_ordered: + code: + - "MEDICATION" + - "ORDERED" + - col(drugname) + timestamp: col(drugordertimestamp) + medication_id: "medicationid" + drug_iv_admixture: "drugivadmixture" + dosage: "dosage" + route_admin: "routeadmin" + frequency: "frequency" + loading_dose: "loadingdose" + prn: "prn" + gtc: "gtc" + drug_started: + code: + - "MEDICATION" + - "STARTED" + - col(drugname) + timestamp: col(drugstarttimestamp) + medication_id: "medicationid" + drug_stopped: + code: + - "MEDICATION" + - "STOPPED" + - col(drugname) + timestamp: col(drugstoptimestamp) + medication_id: "medicationid" + +nurseAssessment: + nurse_assessment_performed: + code: + - "NURSE_ASSESSMENT" + - "PERFORMED" + - NOT YET DONE + timestamp: col(nurseAssessPerformedTimestamp) + nurse_assessment_id: "nurseassessid" + cell_label: "celllabel" + cell_attribute: "cellattribute" + cell_attribute_value: "cellattributevalue" + + nurse_assessment_entered: + code: + - "NURSE_ASSESSMENT" + - "ENTERED" + - NOT YET DONE + timestamp: col(nurseAssessEnteredTimestamp) + nurse_assessment_id: "nurseassessid" + cell_label: "celllabel" + cell_attribute: "cellattribute" + cell_attribute_value: "cellattributevalue" + +nurseCare: + nurse_care_performed: + code: + - "NURSE_CARE" + - "PERFORMED" + - NOT YET DONE + timestamp: col(nurseCarePerformedTimestamp) + nurse_care_id: "nursecareid" + cell_label: "celllabel" + cell_attribute: "cellattribute" + cell_attribute_value: "cellattributevalue" + + nurse_care_entered: + code: + - "NURSE_CARE" + - "ENTERED" + - NOT YET DONE + timestamp: col(nurseCareEnteredTimestamp) + nurse_care_id: "nursecareid" + cell_label: "celllabel" + cell_attribute: "cellattribute" + cell_attribute_value: "cellattributevalue" + +nurseCharting: + nurse_charting_performed: + code: + - "NURSE_CHARTING" + - "PERFORMED" + - NOT YET DONE + timestamp: col(nursingChartPerformedTimestamp) + nurse_charting_id: "nursingchartid" + cell_type_cat: "nursingchartcelltypecat" + cell_type_val_name: "nursingchartcelltypevalname" + cell_type_val_label: "nursingchartcelltypevallabel" + cell_value: "nursingchartvalue" + + nurse_charting_entered: + code: + - "NURSE_CHARTING" + - "ENTERED" + - NOT YET DONE + timestamp: col(nursingChartEnteredTimestamp) + nurse_charting_id: "nursingchartid" + cell_type_cat: "nursingchartcelltypecat" + cell_type_val_name: "nursingchartcelltypevalname" + cell_type_val_label: "nursingchartcelltypevallabel" + cell_value: "nursingchartvalue" + +pastHistory: + past_history_taken: + code: + - "PAST_HISTORY" + - "TAKEN" + - NOT YET DONE + timestamp: col(pastHistoryTakenTimestamp) + past_history_id: "pasthistoryid" + note_type: "pasthistorynotetype" + path: "pasthistorypath" + value: "pasthistoryvalue" + value_text: "pasthistoryvaluetext" + + past_history_entered: + code: + - 
"PAST_HISTORY" + - "ENTERED" + - NOT YET DONE + timestamp: col(pastHistoryEnteredTimestamp) + past_history_id: "pasthistoryid" + note_type: "pasthistorynotetype" + path: "pasthistorypath" + value: "pasthistoryvalue" + value_text: "pasthistoryvaluetext" + +physicalExam: + physical_exam_entered: + code: + - "PHYSICAL_EXAM" + - "ENTERED" + - NOT YET DONE + timestamp: col(physicalExamEnteredTimestamp) + physical_exam_id: "physicalexamid" + text: "physicalexamtext" + path: "physicalexampath" + value: "physicalexamvalue" + +respiratoryCare: + resp_care_status: + code: + - "RESP_CARE" + - "STATUS" + - NOT YET DONE + timestamp: col(respCareStatusEnteredTimestamp) + resp_care_id: "respcareid" + airwaytype: "airwaytype" + airwaysize: "airwaysize" + airwayposition: "airwayposition" + cuffpressure: "cuffpressure" + apneaparms: "apneaparms" + lowexhmvlimit: "lowexhmvlimit" + hiexhmvlimit: "hiexhmvlimit" + lowexhtvlimit: "lowexhtvlimit" + hipeakpreslimit: "hipeakpreslimit" + lowpeakpreslimit: "lowpeakpreslimit" + hirespratelimit: "hirespratelimit" + lowrespratelimit: "lowrespratelimit" + sighpreslimit: "sighpreslimit" + lowironoxlimit: "lowironoxlimit" + highironoxlimit: "highironoxlimit" + meanairwaypreslimit: "meanairwaypreslimit" + peeplimit: "peeplimit" + cpaplimit: "cpaplimit" + setapneainterval: "setapneainterval" + setapneatv: "setapneatv" + setapneaippeephigh: "setapneaippeephigh" + setapnearr: "setapnearr" + setapneapeakflow: "setapneapeakflow" + setapneainsptime: "setapneainsptime" + setapneaie: "setapneaie" + setapneafio2: "setapneafio2" + + vent_start: + code: + - "VENT" + - "START" + - NOT YET DONE + timestamp: col(ventStartTimestamp) + resp_care_id: "respcareid" + + vent_end: + code: + - "VENT" + - "END" + - NOT YET DONE + timestamp: col(ventEndTimestamp) + resp_care_id: "respcareid" + +respiratoryCharting: + resp_charting_performed: + code: + - "RESP_CHARTING" + - "PERFORMED" + - NOT YET DONE + timestamp: col(respChartPerformedTimestamp) + resp_chart_id: "respchartid" + type_cat: "respcharttypecat" + value_label: "respchartvaluelabel" + value: "respchartvalue" + + resp_charting_entered: + code: + - "RESP_CHARTING" + - "ENTERED" + - NOT YET DONE + timestamp: col(respChartEnteredTimestamp) + resp_chart_id: "respchartid" + type_cat: "respcharttypecat" + value_label: "respchartvaluelabel" + value: "respchartvalue" + +treatment: + treatment: + code: + - "TREATMENT" + - "ENTERED" + - col(treatmentstring) + timestamp: col(treatmentEnteredTimestamp) + treatment_id: "treatmentid" + +vitalAperiodic: + non_invasive_systolic: + code: + - "VITALS" + - "APERIODIC" + - "BP" + - "NONINVASIVE_SYSTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "noninvasivesystolic" + non_invasive_diastolic: + code: + - "VITALS" + - "APERIODIC" + - "BP" + - "NONINVASIVE_DIASTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "noninvasivediastolic" + + non_invasive_mean: + code: + - "VITALS" + - "APERIODIC" + - "BP" + - "NONINVASIVE_MEAN" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "noninvasivemean" + + paop: + code: + - "VITALS" + - "APERIODIC" + - "PAOP" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "paop" + + cardiac_output: + code: + - "VITALS" + - "APERIODIC" + - "CARDIAC_OUTPUT" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "cardiacoutput" + + cardiac_input: + code: + - 
"VITALS" + - "APERIODIC" + - "CARDIAC_INPUT" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "cardiacinput" + + svr: + code: + - "VITALS" + - "APERIODIC" + - "SVR" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "svr" + + svri: + code: + - "VITALS" + - "APERIODIC" + - "SVRI" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "svri" + + pvr: + code: + - "VITALS" + - "APERIODIC" + - "PVR" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "pvr" + + pvri: + code: + - "VITALS" + - "APERIODIC" + - "PVRI" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalaperiodicid" + numeric_value: "pvri" + +vitalPeriodic: + temperature: + code: + - "VITALS" + - "PERIODIC" + - "TEMPERATURE" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "temperature" + + saO2: + code: + - "VITALS" + - "PERIODIC" + - "SAO2" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "sao2" + + heartRate: + code: + - "VITALS" + - "PERIODIC" + - "HEARTRATE" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "heartrate" + + respiration: + code: + - "VITALS" + - "PERIODIC" + - "RESPIRATION" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "respiration" + + cvp: + code: + - "VITALS" + - "PERIODIC" + - "CVP" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "cvp" + + etCo2: + code: + - "VITALS" + - "PERIODIC" + - "ETCO2" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "etco2" + + systemic_systolic: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "SYSTEMIC_SYSTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "systemicsystolic" + + systemic_diastolic: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "SYSTEMIC_DIASTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "systemicdiastolic" + + systemic_mean: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "SYSTEMIC_MEAN" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "systemicmean" + + pa_systolic: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "PULM_ART_SYSTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "pasystolic" + + pa_diastolic: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "PULM_ART_DIASTOLIC" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "padiastolic" + + pa_mean: + code: + - "VITALS" + - "PERIODIC" + - "BP" + - "PULM_ART_MEAN" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "pamean" + + st1: + code: + - "VITALS" + - "PERIODIC" + - "ST1" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "st1" + + st2: + code: + - "VITALS" + - "PERIODIC" + - "ST2" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "st2" + + st3: + code: + - "VITALS" + - "PERIODIC" + - "ST3" + timestamp: col(observationEnteredTimestamp) + vital_id: "vitalperiodicid" + numeric_value: "st3" + + ICP: + code: + - "VITALS" + - "PERIODIC" + - "ICP" + timestamp: col(observationEnteredTimestamp) + vital_id: 
"vitalperiodicid" + numeric_value: "icp" diff --git a/eICU_Example/configs/pre_MEDS.yaml b/eICU_Example/configs/pre_MEDS.yaml new file mode 100644 index 0000000..b5cfa4c --- /dev/null +++ b/eICU_Example/configs/pre_MEDS.yaml @@ -0,0 +1,11 @@ +raw_cohort_dir: ??? +output_dir: ??? + +# Hydra +hydra: + job: + name: pre_MEDS_${now:%Y-%m-%d_%H-%M-%S} + run: + dir: ${output_dir}/.logs/${hydra.job.name} + sweep: + dir: ${output_dir}/.logs/${hydra.job.name} diff --git a/eICU_Example/configs/table_preprocessors.yaml b/eICU_Example/configs/table_preprocessors.yaml new file mode 100644 index 0000000..3faf4aa --- /dev/null +++ b/eICU_Example/configs/table_preprocessors.yaml @@ -0,0 +1,283 @@ +admissiondx: + offset_col: "admitdxenteredoffset" + pseudotime_col: "admitDxEnteredTimestamp" + output_data_cols: ["admitdxname", "admitdxid"] + warning_items: + ["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"] + +allergy: + offset_col: "allergyenteredoffset" + pseudotime_col: "allergyEnteredTimestamp" + output_data_cols: ["allergytype", "allergyname"] + warning_items: + - "How should we use `allergyNoteType`?" + - "How should we use `specialtyType`?" + - "How should we use `userType`?" + - >- + Is `drugName` the name of the drug to which the patient is allergic or the drug given to the patient + (docs say 'name of the selected admission drug')? + +carePlanGeneral: + offset_col: "cplitemoffset" + pseudotime_col: "carePlanGeneralItemEnteredTimestamp" + output_data_cols: ["cplgroup", "cplitemvalue"] + +carePlanEOL: + offset_col: "cpleoldiscussionoffset" + pseudotime_col: "carePlanEolDiscussionOccurredTimestamp" + warning_items: + - "Is the DiscussionOffset time actually reliable? Should we fall back on the SaveOffset time?" + +carePlanGoal: + offset_col: "cplgoaloffset" + pseudotime_col: "carePlanGoalEnteredTimestamp" + output_data_cols: ["cplgoalcategory", "cplgoalvalue", "cplgoalstatus"] + +carePlanInfectiousDisease: + offset_col: "cplinfectdiseaseoffset" + pseudotime_col: "carePlanInfectDiseaseEnteredTimestamp" + output_data_cols: + - "infectdiseasesite" + - "infectdiseaseassessment" + - "responsetotherapy" + - "treatment" + +diagnosis: + offset_col: "diagnosisoffset" + pseudotime_col: "diagnosisEnteredTimestamp" + output_data_cols: ["icd9code", "diagnosispriority", "diagnosisstring"] + warning_items: + - "Though we use it, the `diagnosisString` field documentation is unclear -- by what is it separated?" + +infusionDrug: + offset_col: "infusionoffset" + pseudotime_col: "infusionEnteredTimestamp" + output_data_cols: + - "infusiondrugid" + - "drugname" + - "drugrate" + - "infusionrate" + - "drugamount" + - "volumeoffluid" + - "patientweight" + +lab: + offset_col: "labresultoffset" + pseudotime_col: "labResultDrawnTimestamp" + output_data_cols: + - "labname" + - "labresult" + - "labresulttext" + - "labmeasurenamesystem" + - "labmeasurenameinterface" + - "labtypeid" + warning_items: + - "Is this the time the lab was drawn? Entered? The time the result came in?" + - "We **IGNORE** the `labResultRevisedOffset` column -- this may be a mistake!" 
+ +medication: + offset_col: + - "drugorderoffset" + - "drugstartoffset" + - "drugstopoffset" + pseudotime_col: + - "drugordertimestamp" + - "drugstarttimestamp" + - "drugstoptimestamp" + output_data_cols: + - "medicationid" + - "drugivadmixture" + - "drugname" + - "drughiclseqno" + - "dosage" + - "routeadmin" + - "frequency" + - "loadingdose" + - "prn" + - "gtc" + warning_items: + - "We **IGNORE** the `drugOrderCancelled` column -- this may be a mistake!" + +nurseAssessment: + offset_col: + - "nurseassessoffset" + - "nurseassessentryoffset" + pseudotime_col: + - "nurseAssessPerformedTimestamp" + - "nurseAssessEnteredTimestamp" + output_data_cols: + - "nurseassessid" + - "celllabel" + - "cellattribute" + - "cellattributevalue" + warning_items: + - "Should we be using `cellAttributePath` instead of `cellAttribute`?" + - "SOME MAY BE LISTS" + +nurseCare: + offset_col: + - "nursecareoffset" + - "nursecareentryoffset" + pseudotime_col: + - "nurseCarePerformedTimestamp" + - "nurseCareEnteredTimestamp" + output_data_cols: + - "nursecareid" + - "celllabel" + - "cellattribute" + - "cellattributevalue" + warning_items: + - "Should we be using `cellAttributePath` instead of `cellAttribute`?" + - "SOME MAY BE LISTS" + +nurseCharting: + offset_col: + - "nursingchartoffset" + - "nursingchartentryoffset" + pseudotime_col: + - "nursingChartPerformedTimestamp" + - "nursingChartEnteredTimestamp" + output_data_cols: + - "nursingchartid" + - "nursingchartcelltypecat" + - "nursingchartcelltypevalname" + - "nursingchartcelltypevallabel" + - "nursingchartvalue" + warning_items: + - "SOME MAY BE LISTS" + +pastHistory: + offset_col: + - "pasthistoryoffset" + - "pasthistoryenteredoffset" + pseudotime_col: + - "pastHistoryTakenTimestamp" + - "pastHistoryEnteredTimestamp" + output_data_cols: + - "pasthistoryid" + - "pasthistorynotetype" + - "pasthistorypath" + - "pasthistoryvalue" + - "pasthistoryvaluetext" + warning_items: + - "SOME MAY BE LISTS" + - "How should we use `pastHistoryPath` vs. `pastHistoryNoteType`?" + - "How should we use `pastHistoryValue` vs. `pastHistoryValueText`?" + +physicalExam: + offset_col: "physicalexamoffset" + pseudotime_col: "physicalExamEnteredTimestamp" + output_data_cols: + - "physicalexamid" + - "physicalexamtext" + - "physicalexampath" + - "physicalexamvalue" + warning_items: + - "How should we use `physicalExamValue` vs. `physicalExamText`?" + - "I believe the `physicalExamValue` is a **LIST**. This must be processed specially." + +respiratoryCare: + offset_col: + - "respcarestatusoffset" + - "ventstartoffset" + - "ventendoffset" + pseudotime_col: + - "respCareStatusEnteredTimestamp" + - "ventStartTimestamp" + - "ventEndTimestamp" + output_data_cols: + - "respcareid" + - "airwaytype" + - "airwaysize" + - "airwayposition" + - "cuffpressure" + - "apneaparms" + - "lowexhmvlimit" + - "hiexhmvlimit" + - "lowexhtvlimit" + - "hipeakpreslimit" + - "lowpeakpreslimit" + - "hirespratelimit" + - "lowrespratelimit" + - "sighpreslimit" + - "lowironoxlimit" + - "highironoxlimit" + - "meanairwaypreslimit" + - "peeplimit" + - "cpaplimit" + - "setapneainterval" + - "setapneatv" + - "setapneaippeephigh" + - "setapnearr" + - "setapneapeakflow" + - "setapneainsptime" + - "setapneaie" + - "setapneafio2" + warning_items: + - "We ignore the `priorVent*` columns -- this may be a mistake!" + - "There is a lot of data in this table -- what should be incorporated into the event structure?" + - "We might be able to use `priorVent` timestamps to further refine true season of unit admission." 
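+
+# `warning_items` are only surfaced as `logger.warning` messages when the corresponding table is
+# processed; they do not change how the table itself is transformed.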
+ +respiratoryCharting: + offset_col: + - "respchartoffset" + - "respchartentryoffset" + pseudotime_col: + - "respChartPerformedTimestamp" + - "respChartEnteredTimestamp" + output_data_cols: + - "respchartid" + - "respcharttypecat" + - "respchartvaluelabel" + - "respchartvalue" + warning_items: + - "SOME MAY BE LISTS" + +treatment: + offset_col: "treatmentoffset" + pseudotime_col: "treatmentEnteredTimestamp" + output_data_cols: + - "treatmentid" + - "treatmentstring" + warning_items: + - "Absence of entries in table do not indicate absence of treatments" + +vitalAperiodic: + offset_col: "observationoffset" + pseudotime_col: "observationEnteredTimestamp" + output_data_cols: + - "vitalaperiodicid" + - "noninvasivesystolic" + - "noninvasivediastolic" + - "noninvasivemean" + - "paop" + - "cardiacoutput" + - "cardiacinput" + - "svr" + - "svri" + - "pvr" + - "pvri" + +vitalPeriodic: + offset_col: "observationoffset" + pseudotime_col: "observationEnteredTimestamp" + output_data_cols: + - "vitalperiodicid" + - "temperature" + - "sao2" + - "heartrate" + - "respiration" + - "cvp" + - "etco2" + - "systemicsystolic" + - "systemicdiastolic" + - "systemicmean" + - "pasystolic" + - "padiastolic" + - "pamean" + - "st1" + - "st2" + - "st3" + - "icp" + warning_items: + - "These are 5-minute median values. There are going to be a *lot* of events." diff --git a/eICU_Example/joint_script.sh b/eICU_Example/joint_script.sh new file mode 100755 index 0000000..fd76ee2 --- /dev/null +++ b/eICU_Example/joint_script.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes eICU data through several steps, handling raw data conversion," + echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo + echo "Arguments:" + echo " EICU_RAW_DIR Directory containing raw eICU data files." + echo " EICU_PREMEDS_DIR Output directory for pre-MEDS data." + echo " EICU_MEDS_DIR Output directory for processed MEDS data." + echo " N_PARALLEL_WORKERS Number of parallel workers for processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -lt 4 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +EICU_RAW_DIR="$1" +EICU_PREMEDS_DIR="$2" +EICU_MEDS_DIR="$3" +N_PARALLEL_WORKERS="$4" + +shift 4 + +echo "Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable " +echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " +echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " +echo "in the merge stage. You can do this by setting the following parameters at the end of the mandatory " +echo "args when running this script:" +echo " * stage_configs.split_and_shard_patients.n_patients_per_shard=10000" +echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" + +echo "Running pre-MEDS conversion." 
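+
+# Any Hydra overrides supplied after the four mandatory arguments are forwarded via "$@" to the
+# extraction stages below (but not to this pre-MEDS step). A hypothetical invocation using the
+# overrides suggested above (paths and worker count are illustrative) would look like:
+#   ./eICU_Example/joint_script.sh /data/eicu/raw /data/eicu/pre_meds /data/eicu/meds 8 \
+#       stage_configs.split_and_shard_patients.n_patients_per_shard=10000 \
+#       stage_configs.merge_to_MEDS_cohort.unique_by=null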
+./eICU_Example/pre_MEDS.py raw_cohort_dir="$EICU_RAW_DIR" output_dir="$EICU_PREMEDS_DIR" + +echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/shard_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" + +echo "Splitting patients in serial" +./scripts/extraction/split_and_shard_patients.py \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" + +echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/convert_to_sharded_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" + +echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/merge_to_MEDS_cohort.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=joblib \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/joint_script_slurm.sh b/eICU_Example/joint_script_slurm.sh new file mode 100755 index 0000000..7880286 --- /dev/null +++ b/eICU_Example/joint_script_slurm.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes eICU data through several steps, handling raw data conversion," + echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." + echo + echo "Arguments:" + echo " EICU_RAW_DIR Directory containing raw eICU data files." + echo " EICU_PREMEDS_DIR Output directory for pre-MEDS data." + echo " EICU_MEDS_DIR Output directory for processed MEDS data." + echo " N_PARALLEL_WORKERS Number of parallel workers for processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -ne 4 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +EICU_RAW_DIR="$1" +EICU_PREMEDS_DIR="$2" +EICU_MEDS_DIR="$3" +N_PARALLEL_WORKERS="$4" + +shift 4 + +# Note we use `--multirun` throughout here due to ensure the submitit launcher is used throughout, so that +# this doesn't fall back on running anything locally in a setting where only slurm worker nodes have +# sufficient computational resources to run the actual jobs. + +echo "Running pre-MEDS conversion on one worker." +./eICU_Example/pre_MEDS.py \ + --multirun \ + worker="range(0,1)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + raw_cohort_dir="$EICU_RAW_DIR" \ + output_dir="$EICU_PREMEDS_DIR" + +echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
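+
+# The submitit launcher settings below (60-minute timeout, 10 CPUs, 50 GB of RAM, and the "short"
+# partition) are illustrative values for one particular cluster; adjust them to match the partitions
+# and resource limits of your own Slurm environment.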
+ +./scripts/extraction/shard_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + "hydra.job.env_copy=[PATH]" \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml + +echo "Splitting patients on one worker" +./scripts/extraction/split_and_shard_patients.py \ + --multirun \ + worker="range(0,1)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" + +echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/convert_to_sharded_events.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" + +echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" +./scripts/extraction/merge_to_MEDS_cohort.py \ + --multirun \ + worker="range(0,$N_PARALLEL_WORKERS)" \ + hydra/launcher=submitit_slurm \ + hydra.launcher.timeout_min=60 \ + hydra.launcher.cpus_per_task=10 \ + hydra.launcher.mem_gb=50 \ + hydra.launcher.partition="short" \ + input_dir="$EICU_PREMEDS_DIR" \ + cohort_dir="$EICU_MEDS_DIR" \ + event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/pre_MEDS.py b/eICU_Example/pre_MEDS.py new file mode 100755 index 0000000..e5855f4 --- /dev/null +++ b/eICU_Example/pre_MEDS.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python + +"""Performs pre-MEDS data wrangling for eICU. + +See the docstring of `main` for more information. +""" +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +import gzip +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_polars_functions.utils import ( + get_shard_prefix, + hydra_loguru_init, + write_lazyframe, +) + +HEALTH_SYSTEM_STAY_ID = "patienthealthsystemstayid" +UNIT_STAY_ID = "patientunitstayid" +PATIENT_ID = "uniquepid" + +# The end of year date, used for year-only timestamps in eICU. The time is set to midnight as we'll add a +# 24-hour time component from other columns in the data. +END_OF_YEAR = {"month": 12, "day": 31, "hour": 0, "minute": 0, "second": 0} + + +def load_raw_eicu_file(fp: Path, **kwargs) -> pl.LazyFrame: + """Load a raw MIMIC file into a Polars DataFrame. + + Args: + fp: The path to the MIMIC file. + + Returns: + The Polars DataFrame containing the MIMIC data. + """ + + with gzip.open(fp, mode="rb") as f: + return pl.read_csv(f, infer_schema_length=100000000, **kwargs).lazy() + + +def check_timestamps_agree(df: pl.LazyFrame, pseudotime_col: pl.Expr, given_24htime_col: str): + """Checks that the time-of-day portions agree between the pseudotime and given columns. 
+ + Raises a `ValueError` if the times don't match within a minute. + + Args: + TODO + """ + expected_time = pl.col(given_24htime_col).str.strptime(pl.Time, "%H:%M:%S") + + # The use of `.dt.combine` here re-sets the "time-of-day" of the pseudotime_col column + time_deltas_min = (pseudotime_col - pseudotime_col.dt.combine(expected_time)).dt.total_minutes() + + # Check that the time deltas are all within 1 minute + logger.info( + "Checking that stated 24h times are consistent given offsets between {pseudotime_col.name} and " + f"{given_24htime_col}..." + ) + max_time_deltas_min = df.select(time_deltas_min.abs().max()).collect().item() + if max_time_deltas_min > 1: + raise ValueError( + f"Max number of minutes between {pseudotime_col.name} and {given_24htime_col} is " + f"{max_time_deltas_min}. Should be <= 1." + ) + + +def process_patient(df: pl.LazyFrame, hospital_df: pl.LazyFrame) -> pl.LazyFrame: + """Takes the patient table and converts it to a form that includes timestamps. + + As eICU stores only offset times, note here that we add a CONSTANT TIME ACROSS ALL PATIENTS for the true + timestamp of their health system admission. This is acceptable because in eICU ONLY RELATIVE TIME + DIFFERENCES ARE MEANINGFUL, NOT ABSOLUTE TIMES. + + The output of this process is ultimately converted to events via the `patient` key in the + `configs/event_configs.yaml` file. + """ + + hospital_discharge_pseudotime = pl.datetime( + year=pl.col("hospitaldischargeyear"), **END_OF_YEAR + ).dt.combine(pl.col("hospitaldischargetime24").str.strptime(pl.Time, "%H:%M:%S")) + + unit_admit_pseudotime = hospital_discharge_pseudotime - pl.duration( + minutes=pl.col("hospitaldischargeoffset") + ) + + unit_discharge_pseudotime = unit_admit_pseudotime + pl.duration(minutes=pl.col("unitdischargeoffset")) + + hospital_admit_pseudotime = unit_admit_pseudotime + pl.duration(minutes=pl.col("hospitaladmitoffset")) + + age_in_years = ( + pl.when(pl.col("age") == "> 89").then(90).otherwise(pl.col("age").cast(pl.UInt16, strict=False)) + ) + age_in_days = age_in_years * 365.25 + # We assume that the patient was born at the midpoint of the year as we don't know the actual birthdate + pseudo_date_of_birth = unit_admit_pseudotime - pl.duration(days=(age_in_days - 365.25 / 2)) + + # Check the times + start = datetime.now() + logger.info( + "Checking that the 24h times are consistent. If this is extremely slow, consider refactoring to have " + "only one `.collect()` call." + ) + check_timestamps_agree(df, hospital_discharge_pseudotime, "hospitaldischargetime24") + check_timestamps_agree(df, hospital_admit_pseudotime, "hospitaladmittime24") + check_timestamps_agree(df, unit_admit_pseudotime, "unitadmittime24") + check_timestamps_agree(df, unit_discharge_pseudotime, "unitdischargetime24") + logger.info(f"Validated 24h times in {datetime.now() - start}") + + logger.warning("NOT validating the `unitVisitNumber` column as that isn't implemented yet.") + + logger.warning( + "NOT SURE ABOUT THE FOLLOWING. Check with the eICU team:\n" + " - `apacheAdmissionDx` is not selected from the patients table as we grab it from `admissiondx`. " + "Is this right?\n" + " - `admissionHeight` and `admissionWeight` are interpreted as **unit** admission height/weight, " + "not hospital admission height/weight. Is this right?\n" + " - `age` is interpreted as the age at the time of the unit stay, not the hospital stay. " + "Is this right?\n" + " - `What is the actual mean age for those > 89? 
Here we assume 90.\n" + " - Note that all the column names appear to be all in lowercase for the csv versions, vs. the docs" + ) + + return df.join(hospital_df, left_on="hospitalid", right_on="hospitalid", how="left").select( + # 1. Static variables + PATIENT_ID, + "gender", + pseudo_date_of_birth.alias("dateofbirth"), + "ethnicity", + # 2. Health system stay parameters + HEALTH_SYSTEM_STAY_ID, + "hospitalid", + pl.col("numbedscategory").alias("hospitalnumbedscategory"), + pl.col("teachingstatus").alias("hospitalteachingstatus"), + pl.col("region").alias("hospitalregion"), + # 2.1 Admission parameters + hospital_admit_pseudotime.alias("hospitaladmittimestamp"), + "hospitaladmitsource", + # 2.2 Discharge parameters + hospital_discharge_pseudotime.alias("hospitaldischargetimestamp"), + "hospitaldischargelocation", + "hospitaldischargestatus", + # 3. Unit stay parameters + UNIT_STAY_ID, + "wardid", + # 3.1 Admission parameters + unit_admit_pseudotime.alias("unitadmittimestamp"), + "unitadmitsource", + "unitstaytype", + pl.col("admissionheight").alias("unitadmissionheight"), + pl.col("admissionweight").alias("unitadmissionweight"), + # 3.2 Discharge parameters + unit_discharge_pseudotime.alias("unitdischargetimestamp"), + "unitdischargelocation", + "unitdischargestatus", + pl.col("dischargeweight").alias("unitdischargeweight"), + ) + + +def join_and_get_pseudotime_fntr( + table_name: str, + offset_col: str | list[str], + pseudotime_col: str | list[str], + output_data_cols: list[str] | None = None, + warning_items: list[str] | None = None, +) -> Callable[[pl.LazyFrame, pl.LazyFrame], pl.LazyFrame]: + """Returns a function that joins a dataframe to the `patient` table and adds pseudotimes. + + Also raises specified warning strings via the logger for uncertain columns. + + TODO + """ + + if output_data_cols is None: + output_data_cols = [] + + if isinstance(offset_col, str): + offset_col = [offset_col] + if isinstance(pseudotime_col, str): + pseudotime_col = [pseudotime_col] + + if len(offset_col) != len(pseudotime_col): + raise ValueError( + "There must be the same number of `offset_col`s and `pseudotime_col`s specified. Got " + f"{len(offset_col)} and {len(pseudotime_col)}, respectively." + ) + + def fn(df: pl.LazyFrame, patient_df: pl.LazyFrame) -> pl.LazyFrame: + f"""Takes the {table_name} table and converts it to a form that includes pseudo-timestamps. + + The output of this process is ultimately converted to events via the `{table_name}` key in the + `configs/event_configs.yaml` file. + """ + + pseudotimes = [ + (pl.col("unitadmittimestamp") + pl.duration(minutes=pl.col(offset))).alias(pseudotime) + for pseudotime, offset in zip(pseudotime_col, offset_col) + ] + + if warning_items: + warning_lines = [ + f"NOT SURE ABOUT THE FOLLOWING for {table_name} table. Check with the eICU team:", + *(f" - {item}" for item in warning_items), + ] + logger.warning("\n".join(warning_lines)) + + return df.join(patient_df, on=UNIT_STAY_ID, how="inner").select( + HEALTH_SYSTEM_STAY_ID, + UNIT_STAY_ID, + *pseudotimes, + *output_data_cols, + ) + + return fn + + +NEEDED_PATIENT_COLS = [UNIT_STAY_ID, HEALTH_SYSTEM_STAY_ID, "unitadmittimestamp"] + + +@hydra.main(version_base=None, config_path="configs", config_name="pre_MEDS") +def main(cfg: DictConfig): + """Performs pre-MEDS data wrangling for eICU. + + Inputs are the raw eICU files, read from the `raw_cohort_dir` config parameter. 
Output files are either + symlinked (if they are not modified) or written in processed form to the `MEDS_input_dir` config + parameter. Hydra is used to manage configuration parameters and logging. + + Note that eICU has only a tentative ability to identify true relative admission times for even the same + patient, as health system stay IDs are only temporally ordered at the *year* level. As such, to properly + parse this dataset in a longitudinal form, you must do one of the following: + 1. Not operate at the level of patients at all, but instead at the level of health system stays, as + individual events within a health system stay can be well ordered. + 2. Restrict the analysis to only patients who do not have multiple health system stays within a single + year (as health system stays across years can be well ordered, provided we assume to distinct stays + within a single health system cannot overlap). + + In this pipeline, we choose to operate at the level of health system stays, as this is the most general + approach. The only downside is that we lose the ability to track individual patients across health system + stays, and thus can only explore questions of limited longitudinal scope. + + We ignore the following tables for the given reasons: + 1. `admissiondrug`: This table is noted in the + [documentation](https://eicu-crd.mit.edu/eicutables/admissiondrug/) as being "Extremely infrequently + used". + 2. `apacheApsVar`: This table is a sort of "meta-table" that contains variables used to compute the + APACHE score; we won't use these raw variables from this table, but instead will use the raw data. + 3. `apachePatientResult`: This table has pre-computed APACHE score variables; we won't use these and + will use the raw data directly. + 4. `apachePredVar`: This table contains variables used to compute the APACHE score; we won't use these + in favor of the raw data directly. + 5. `carePlanCareProvider`: This table contains information about the provider for given care-plan + entries; however, as we can't link this table to the particular care-plan entries, we don't use it + here. It also is not clear (to the author of this script; the eICU team may know more) how reliable + the time-offsets are for this table as they merely denote when a provider was entered into the care + plan. + 6. `customLab`: The documentation for this table is very sparse, so we skip it. + 7. `intakeOutput`: There are a number of significant warnings about duplicates, cumulative values, and + more in the documentation for this table, so for now we skip it. + 8. `microLab`: We don't use this because the culture taken time != culture result time, so seeing this + data would give a model an advantage over any possible real-world implementation. Plus, the docs say + it is not well populated. + 9. `note`: This table is largely duplicated with structured data due to the fact that primarily + narrative notes were removed due to PHI constraints (see the docs). + + There are other notes for this pipeline: + 1. Many fields here are, I believe, **lists**, not simple categoricals, and should be split and + processed accordingly. This is not yet done. + + Args (all as part of the config file): + raw_cohort_dir: The directory containing the raw eICU files. + output_dir: The directory to write the processed files to. 
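+
+    An illustrative invocation (paths are placeholders), mirroring the call made in
+    `eICU_Example/joint_script.sh`, is:
+
+        ./eICU_Example/pre_MEDS.py raw_cohort_dir=/path/to/eICU/raw output_dir=/path/to/eICU/pre_MEDS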
+ """ + + hydra_loguru_init() + + table_preprocessors_config_fp = Path("./eICU_Example/configs/table_preprocessors.yaml") + logger.info(f"Loading table preprocessors from {str(table_preprocessors_config_fp.resolve())}...") + preprocessors = OmegaConf.load(table_preprocessors_config_fp) + functions = {} + for table_name, preprocessor_cfg in preprocessors.items(): + logger.info(f" Adding preprocessor for {table_name}:\n{OmegaConf.to_yaml(preprocessor_cfg)}") + functions[table_name] = join_and_get_pseudotime_fntr(table_name=table_name, **preprocessor_cfg) + + raw_cohort_dir = Path(cfg.raw_cohort_dir) + MEDS_input_dir = Path(cfg.output_dir) + + patient_out_fp = MEDS_input_dir / "patient.parquet" + + if patient_out_fp.is_file(): + logger.info(f"Reloading processed patient df from {str(patient_out_fp.resolve())}") + patient_df = pl.read_parquet(patient_out_fp, columns=NEEDED_PATIENT_COLS, use_pyarrow=True).lazy() + else: + logger.info("Processing patient table first...") + + hospital_fp = raw_cohort_dir / "hospital.csv.gz" + patient_fp = raw_cohort_dir / "patient.csv.gz" + logger.info(f"Loading {str(hospital_fp.resolve())}...") + hospital_df = load_raw_eicu_file( + hospital_fp, columns=["hospitalid", "numbedscategory", "teachingstatus", "region"] + ) + logger.info(f"Loading {str(patient_fp.resolve())}...") + raw_patient_df = load_raw_eicu_file(patient_fp) + + logger.info("Processing patient table...") + patient_df = process_patient(raw_patient_df, hospital_df) + write_lazyframe(patient_df, MEDS_input_dir / "patient.parquet") + + all_fps = [ + fp for fp in raw_cohort_dir.glob("*.csv.gz") if fp.name not in {"hospital.csv.gz", "patient.csv.gz"} + ] + + unused_tables = { + "admissiondrug", + "apacheApsVar", + "apachePatientResult", + "apachePredVar", + "carePlanCareProvider", + "customLab", + "intakeOutput", + "microLab", + "note", + } + + for in_fp in all_fps: + pfx = get_shard_prefix(raw_cohort_dir, in_fp) + if pfx in unused_tables: + logger.warning(f"Skipping {pfx} as it is not supported in this pipeline.") + continue + elif pfx not in functions: + logger.warning(f"No function needed for {pfx}. For eICU, THIS IS UNEXPECTED") + continue + + out_fp = MEDS_input_dir / f"{pfx}.parquet" + + if out_fp.is_file(): + print(f"Done with {pfx}. Continuing") + continue + + out_fp.parent.mkdir(parents=True, exist_ok=True) + + fn = functions[pfx] + + st = datetime.now() + logger.info(f"Processing {pfx}...") + df = load_raw_eicu_file(in_fp) + logger.info(f" * Loaded raw {in_fp} in {datetime.now() - st}") + processed_df = fn(df, patient_df) + write_lazyframe(processed_df, out_fp) + logger.info(f" * Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - st}") + + logger.info(f"Done! 
All dataframes processed and written to {str(MEDS_input_dir.resolve())}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index bbf8dee..25b9527 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,11 @@ classifiers = [ dependencies = ["polars", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy"] [project.optional-dependencies] -mimic = ["rootutils"] +examples = ["rootutils"] dev = ["pre-commit"] tests = ["pytest", "pytest-cov[toml]", "rootutils"] +local_parallelism = ["hydra-joblib-launcher"] +slurm_parallelism = ["hydra-submitit-launcher"] [project.urls] Homepage = "https://github.com/mmcdermott/MEDS_polars_functions" diff --git a/scripts/extraction/convert_to_sharded_events.py b/scripts/extraction/convert_to_sharded_events.py index 50fcab2..bc1eff3 100755 --- a/scripts/extraction/convert_to_sharded_events.py +++ b/scripts/extraction/convert_to_sharded_events.py @@ -22,10 +22,13 @@ def main(cfg: DictConfig): hydra_loguru_init() - Path(cfg.raw_cohort_dir) - MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir) + logger.info( + f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) - shards = json.loads((MEDS_cohort_dir / "splits.json").read_text()) + shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text()) event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) if not event_conversion_cfg_fp.exists(): @@ -39,7 +42,7 @@ def main(cfg: DictConfig): default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") - patient_subsharded_dir = MEDS_cohort_dir / "patient_sub_sharded_events" + patient_subsharded_dir = Path(cfg.stage_cfg.output_dir) patient_subsharded_dir.mkdir(parents=True, exist_ok=True) OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") @@ -57,7 +60,7 @@ def main(cfg: DictConfig): event_cfgs = copy.deepcopy(event_cfgs) input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) - event_shards = list((MEDS_cohort_dir / "sub_sharded" / input_prefix).glob("*.parquet")) + event_shards = list((Path(cfg.stage_cfg.data_input_dir) / input_prefix).glob("*.parquet")) random.shuffle(event_shards) for shard_fp in event_shards: diff --git a/scripts/extraction/merge_to_MEDS_cohort.py b/scripts/extraction/merge_to_MEDS_cohort.py index cc69d2f..ade8d50 100755 --- a/scripts/extraction/merge_to_MEDS_cohort.py +++ b/scripts/extraction/merge_to_MEDS_cohort.py @@ -2,12 +2,13 @@ import json import random +from functools import partial from pathlib import Path import hydra import polars as pl from loguru import logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from MEDS_polars_functions.mapper import wrap as rwlock_wrap from MEDS_polars_functions.utils import hydra_loguru_init @@ -15,7 +16,7 @@ pl.enable_string_cache() -def read_fn(sp_dir: Path) -> pl.LazyFrame: +def read_fn(sp_dir: Path, unique_by: list[str] | str | None) -> pl.LazyFrame: files_to_read = list(sp_dir.glob("**/*.parquet")) if not files_to_read: @@ -25,7 +26,25 @@ def read_fn(sp_dir: Path) -> pl.LazyFrame: logger.info(f"Reading {len(files_to_read)} files:\n{file_strs}") dfs = [pl.scan_parquet(fp, glob=False) for fp in files_to_read] - return pl.concat(dfs, how="diagonal").unique(maintain_order=False).sort(by=["patient_id", "timestamp"]) + df = pl.concat(dfs, how="diagonal_relaxed") + + match unique_by: + case None: + pass + case "*": + df = 
df.unique(maintain_order=False)
+        case list() if len(unique_by) > 0 and all(isinstance(u, str) for u in unique_by):
+            subset = []
+            for u in unique_by:
+                if u in df.columns:
+                    subset.append(u)
+                else:
+                    logger.warning(f"Column {u} not found in dataframe. Omitting from unique-by subset.")
+            df = df.unique(maintain_order=False, subset=subset)
+        case _:
+            raise ValueError(f"Invalid unique_by value: {unique_by}")
+
+    return df.sort(by=["patient_id", "timestamp"], multithreaded=False)


 def write_fn(df: pl.LazyFrame, out_fp: Path) -> None:
@@ -42,27 +61,33 @@ def main(cfg: DictConfig):

     hydra_loguru_init()

-    MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )

-    shards = json.loads((MEDS_cohort_dir / "splits.json").read_text())
+    shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text())

     logger.info("Starting patient shard merging.")

-    patient_subsharded_dir = MEDS_cohort_dir / "patient_sub_sharded_events"
+    patient_subsharded_dir = Path(cfg.stage_cfg.data_input_dir)
     if not patient_subsharded_dir.is_dir():
         raise FileNotFoundError(f"Patient sub-sharded directory not found: {patient_subsharded_dir}")

     patient_splits = list(shards.keys())
     random.shuffle(patient_splits)

+    reader = partial(read_fn, unique_by=cfg.stage_cfg.get("unique_by", None))
+
     for sp in patient_splits:
         in_dir = patient_subsharded_dir / sp
-        out_fp = MEDS_cohort_dir / "final_cohort" / f"{sp}.parquet"
+        out_fp = Path(cfg.stage_cfg.output_dir) / f"{sp}.parquet"

         shard_fps = sorted(list(in_dir.glob("**/*.parquet")))

         shard_fp_strs = [f" * {str(fp.resolve())}" for fp in shard_fps]
         logger.info(f"Merging {len(shard_fp_strs)} shards into {out_fp}:\n" + "\n".join(shard_fp_strs))

-        rwlock_wrap(in_dir, out_fp, read_fn, write_fn, identity_fn, do_return=False)
+        rwlock_wrap(in_dir, out_fp, reader, write_fn, identity_fn, do_return=False)

     logger.info("Output cohort written.")
diff --git a/scripts/extraction/shard_events.py b/scripts/extraction/shard_events.py
index 15737c1..9ce0ac9 100755
--- a/scripts/extraction/shard_events.py
+++ b/scripts/extraction/shard_events.py
@@ -190,9 +190,14 @@ def main(cfg: DictConfig):
     """
     hydra_loguru_init()

-    raw_cohort_dir = Path(cfg.raw_cohort_dir)
-    MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
-    row_chunksize = cfg.row_chunksize
+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )
+
+    raw_cohort_dir = Path(cfg.stage_cfg.data_input_dir)
+    row_chunksize = cfg.stage_cfg.row_chunksize

     event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp)
     if not event_conversion_cfg_fp.exists():
@@ -217,6 +222,9 @@ def main(cfg: DictConfig):
             input_files_to_subshard.append(f)
             seen_files.add(get_shard_prefix(raw_cohort_dir, f))

+    if not input_files_to_subshard:
+        raise FileNotFoundError(f"Can't find any files in {str(raw_cohort_dir.resolve())} to sub-shard!")
+
     random.shuffle(input_files_to_subshard)

     subsharding_files_strs = "\n".join([f" * {str(fp.resolve())}" for fp in input_files_to_subshard])
@@ -226,19 +234,21 @@ def main(cfg: DictConfig):
     )
     logger.info(
         f"Will read raw data from {str(raw_cohort_dir.resolve())}/$IN_FILE.parquet and write sub-sharded "
-        f"data to {str(MEDS_cohort_dir.resolve())}/sub_sharded/$IN_FILE/$ROW_START-$ROW_END.parquet"
+        f"data to {cfg.stage_cfg.output_dir}/$IN_FILE/$ROW_START-$ROW_END.parquet"
     )

     start = datetime.now()
     for input_file in input_files_to_subshard:
         columns = prefix_to_columns[get_shard_prefix(raw_cohort_dir, input_file)]

-        out_dir = MEDS_cohort_dir / "sub_sharded" / get_shard_prefix(raw_cohort_dir, input_file)
+        out_dir = Path(cfg.stage_cfg.output_dir) / get_shard_prefix(raw_cohort_dir, input_file)
         out_dir.mkdir(parents=True, exist_ok=True)
         logger.info(f"Processing {input_file} to {out_dir}.")

         logger.info(f"Performing preliminary read of {str(input_file.resolve())} to determine row count.")
-        df = scan_with_row_idx(input_file, columns=columns, infer_schema_length=cfg["infer_schema_length"])
+        df = scan_with_row_idx(
+            input_file, columns=columns, infer_schema_length=cfg.stage_cfg.infer_schema_length
+        )

         row_count = df.select(pl.len()).collect().item()

@@ -272,7 +282,9 @@ def main(cfg: DictConfig):
             rwlock_wrap(
                 input_file,
                 out_fp,
-                partial(scan_with_row_idx, columns=columns, infer_schema_length=cfg["infer_schema_length"]),
+                partial(
+                    scan_with_row_idx, columns=columns, infer_schema_length=cfg.stage_cfg.infer_schema_length
+                ),
                 write_lazyframe,
                 compute_fn,
                 do_overwrite=cfg.do_overwrite,
diff --git a/scripts/extraction/split_and_shard_patients.py b/scripts/extraction/split_and_shard_patients.py
index fa5c1c2..f618da5 100755
--- a/scripts/extraction/split_and_shard_patients.py
+++ b/scripts/extraction/split_and_shard_patients.py
@@ -18,10 +18,16 @@ def main(cfg: DictConfig):

     hydra_loguru_init()

+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )
+
     logger.info("Starting patient splitting and sharding")

-    MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
-    subsharded_dir = MEDS_cohort_dir / "sub_sharded"
+    MEDS_cohort_dir = Path(cfg.stage_cfg.output_dir)
+    subsharded_dir = Path(cfg.stage_cfg.data_input_dir)

     event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp)
     if not event_conversion_cfg_fp.exists():
@@ -61,8 +67,8 @@ def main(cfg: DictConfig):

     logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}")

-    if cfg.external_splits_json_fp:
-        external_splits_json_fp = Path(cfg.external_splits_json_fp)
+    if cfg.stage_cfg.external_splits_json_fp:
+        external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp)

         if not external_splits_json_fp.exists():
             raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}")
@@ -79,8 +85,8 @@ def main(cfg: DictConfig):
     sharded_patients = shard_patients(
         patients=patient_ids,
         external_splits=external_splits,
-        split_fracs_dict=cfg.split_fracs,
-        n_patients_per_shard=cfg.n_patients_per_shard,
+        split_fracs_dict=cfg.stage_cfg.split_fracs,
+        n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard,
         seed=cfg.seed,
     )
diff --git a/scripts/preprocessing/add_time_derived_measurements.py b/scripts/preprocessing/add_time_derived_measurements.py
index e5cae0d..1e01067 100644
--- a/scripts/preprocessing/add_time_derived_measurements.py
+++ b/scripts/preprocessing/add_time_derived_measurements.py
@@ -24,12 +24,17 @@ def main(cfg: DictConfig):

     hydra_loguru_init()

-    MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
-    output_dir = Path(cfg.output_data_dir)
+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )

-    shards = json.loads((MEDS_cohort_dir / "splits.json").read_text())
+    output_dir = Path(cfg.stage_cfg.output_dir)

-    final_cohort_dir = MEDS_cohort_dir / "final_cohort"
+    shards = json.loads((Path(cfg.stage_cfg.metadata_input_dir) / "splits.json").read_text())
+
+    final_cohort_dir = Path(cfg.stage_cfg.data_input_dir) / "final_cohort"
     filtered_patients_dir = output_dir / "patients_above_length_threshold"
     with_time_derived_dir = output_dir / "with_time_derived_measurements"
diff --git a/scripts/preprocessing/collect_code_metadata.py b/scripts/preprocessing/collect_code_metadata.py
index 36f4b77..fa25bcb 100644
--- a/scripts/preprocessing/collect_code_metadata.py
+++ b/scripts/preprocessing/collect_code_metadata.py
@@ -9,7 +9,7 @@
 import hydra
 import polars as pl
 from loguru import logger
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf

 from MEDS_polars_functions.code_metadata import mapper_fntr, reducer_fntr
 from MEDS_polars_functions.mapper import wrap as rwlock_wrap
@@ -22,6 +22,12 @@ def main(cfg: DictConfig):

     hydra_loguru_init()

+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )
+
     MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
     output_dir = Path(cfg.output_data_dir)
diff --git a/scripts/preprocessing/filter_patients.py b/scripts/preprocessing/filter_patients.py
index f926401..a2b6308 100644
--- a/scripts/preprocessing/filter_patients.py
+++ b/scripts/preprocessing/filter_patients.py
@@ -8,7 +8,7 @@
 import hydra
 import polars as pl
 from loguru import logger
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf

 from MEDS_polars_functions.filter_patients_by_length import (
     filter_patients_by_num_events,
@@ -24,6 +24,12 @@ def main(cfg: DictConfig):

     hydra_loguru_init()

+    logger.info(
+        f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n"
+        f"Stage: {cfg.stage}\n\n"
+        f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
+    )
+
     MEDS_cohort_dir = Path(cfg.MEDS_cohort_dir)
     output_dir = Path(cfg.output_data_dir)
diff --git a/src/MEDS_polars_functions/event_conversion.py b/src/MEDS_polars_functions/event_conversion.py
index 15f1e9a..eae9505 100644
--- a/src/MEDS_polars_functions/event_conversion.py
+++ b/src/MEDS_polars_functions/event_conversion.py
@@ -278,6 +278,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
         ...
         ValueError: Source column 'discharge_time' for event column foobar is not numeric or categorical! Cannot be used as an event col.
     """  # noqa: E501
+    df = df
     event_exprs = {"patient_id": pl.col("patient_id")}

     if "code" not in event_cfg:
@@ -550,5 +551,5 @@ def convert_to_events(
         except Exception as e:
             raise ValueError(f"Error extracting event {event_name}: {e}") from e

-    df = pl.concat(event_dfs, how="diagonal")
+    df = pl.concat(event_dfs, how="diagonal_relaxed")
     return df
diff --git a/src/MEDS_polars_functions/mapper.py b/src/MEDS_polars_functions/mapper.py
index deefd0d..34275b8 100644
--- a/src/MEDS_polars_functions/mapper.py
+++ b/src/MEDS_polars_functions/mapper.py
@@ -8,6 +8,79 @@
 from loguru import logger

+LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f"
+
+
+def get_earliest_lock(cache_directory: Path) -> datetime | None:
+    """Returns the earliest start time of any lock file present in a cache directory, or None if none exist.
+
+    Args:
+        cache_directory: The cache directory to check for the presence of a lock file.
+ + Examples: + >>> import tempfile + >>> directory = tempfile.TemporaryDirectory() + >>> root = Path(directory.name) + >>> empty_directory = root / "cache_empty" + >>> empty_directory.mkdir(exist_ok=True, parents=True) + >>> cache_directory = root / "cache_with_locks" + >>> locks_directory = cache_directory / "locks" + >>> locks_directory.mkdir(exist_ok=True, parents=True) + >>> time_1 = datetime(2021, 1, 1) + >>> time_1_str = time_1.strftime(LOCK_TIME_FMT) # "2021-01-01T00:00:00.000000" + >>> lock_fp_1 = locks_directory / f"{time_1_str}.json" + >>> _ = lock_fp_1.write_text(json.dumps({"start": time_1_str})) + >>> time_2 = datetime(2021, 1, 2, 3, 4, 5) + >>> time_2_str = time_2.strftime(LOCK_TIME_FMT) # "2021-01-02T03:04:05.000000" + >>> lock_fp_2 = locks_directory / f"{time_2_str}.json" + >>> _ = lock_fp_2.write_text(json.dumps({"start": time_2_str})) + >>> get_earliest_lock(cache_directory) + datetime.datetime(2021, 1, 1, 0, 0) + >>> get_earliest_lock(empty_directory) is None + True + >>> lock_fp_1.unlink() + >>> get_earliest_lock(cache_directory) + datetime.datetime(2021, 1, 2, 3, 4, 5) + >>> directory.cleanup() + """ + locks_directory = cache_directory / "locks" + + lock_times = [ + datetime.strptime(json.loads(lock_fp.read_text())["start"], LOCK_TIME_FMT) + for lock_fp in locks_directory.glob("*.json") + ] + + return min(lock_times) if lock_times else None + + +def register_lock(cache_directory: Path) -> tuple[datetime, Path]: + """Register a lock file in a cache directory. + + Args: + cache_directory: The cache directory to register a lock file in. + + Examples: + >>> import tempfile + >>> directory = tempfile.TemporaryDirectory() + >>> root = Path(directory.name) + >>> cache_directory = root / "cache_with_locks" + >>> lock_time, lock_fp = register_lock(cache_directory) + >>> assert (datetime.now() - lock_time).total_seconds() < 1, "Lock time should be ~ now." + >>> lock_fp.is_file() + True + >>> lock_fp.read_text() == f'{{"start": "{lock_time.strftime(LOCK_TIME_FMT)}"}}' + True + >>> directory.cleanup() + """ + + lock_directory = cache_directory / "locks" + lock_directory.mkdir(exist_ok=True, parents=True) + + lock_time = datetime.now() + lock_fp = lock_directory / f"{lock_time.strftime(LOCK_TIME_FMT)}.json" + lock_fp.write_text(json.dumps({"start": lock_time.strftime(LOCK_TIME_FMT)})) + return lock_time, lock_fp + def wrap[ DF_T @@ -108,15 +181,15 @@ def wrap[ │ 3 ┆ 5 ┆ 12 │ └─────┴─────┴─────┘ >>> shutil.rmtree(cache_directory) - >>> lock_fp = cache_directory / "lock.json" - >>> assert not lock_fp.is_file() - >>> def lock_fp_checker_fn(df: pl.DataFrame) -> pl.DataFrame: - ... print(f"Lock fp exists? {lock_fp.is_file()}") + >>> lock_dir = cache_directory / "locks" + >>> assert not lock_dir.exists() + >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame: + ... print(f"Lock dir exists? {lock_dir.exists()}") ... return df >>> result_computed, out_df = wrap( - ... in_fp, out_fp, read_fn, write_fn, lock_fp_checker_fn, do_return=True + ... in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True ... ) - Lock fp exists? True + Lock dir exists? True >>> assert result_computed >>> out_df shape: (3, 3) @@ -146,21 +219,19 @@ def wrap[ cache_directory = out_fp.parent / f".{out_fp.stem}_cache" cache_directory.mkdir(exist_ok=True, parents=True) - st_time = datetime.now() - runtime_info = {"start": str(st_time)} + earliest_lock_time = get_earliest_lock(cache_directory) + if earliest_lock_time is not None: + logger.info(f"{out_fp} is in progress as of {earliest_lock_time}. 
Returning.") + return False, None if do_return else False - lock_fp = cache_directory / "lock.json" - if lock_fp.is_file(): - started_at = json.loads(lock_fp.read_text())["start"] - logger.info( - f"{out_fp} is under construction as of {started_at} as {lock_fp} exists. " "Returning None." - ) - if do_return: - return False, None - else: - return False + st_time, lock_fp = register_lock(cache_directory) - lock_fp.write_text(json.dumps(runtime_info)) + logger.info(f"Registered lock at {st_time}. Double checking no earlier locks have been registered.") + earliest_lock_time = get_earliest_lock(cache_directory) + if earliest_lock_time < st_time: + logger.info(f"Earlier lock found at {earliest_lock_time}. Deleting current lock and returning.") + lock_fp.unlink() + return False, None if do_return else False logger.info(f"Reading input dataframe from {in_fp}") df = read_fn(in_fp) diff --git a/src/MEDS_polars_functions/utils.py b/src/MEDS_polars_functions/utils.py index 7899653..a307bae 100644 --- a/src/MEDS_polars_functions/utils.py +++ b/src/MEDS_polars_functions/utils.py @@ -1,11 +1,166 @@ """Core utilities for MEDS pipelines built with these tools.""" +import inspect import os +import sys from pathlib import Path import hydra import polars as pl -from loguru import logger as log +from loguru import logger +from omegaconf import OmegaConf + +pl.enable_string_cache() + + +def get_script_docstring() -> str: + """Returns the docstring of the main function of the script that was called. + + Returns: + str: TODO + """ + + main_module = sys.modules["__main__"] + func = getattr(main_module, "main", None) + if func and callable(func): + return inspect.getdoc(func) or "" + return "" + + +def current_script_name() -> str: + """Returns the name of the script that called this function. + + Returns: + str: The name of the script that called this function. + """ + return Path(sys.argv[0]).stem + + +def populate_stage( + stage_name: str, + input_dir: str, + cohort_dir: str, + stages: list[str], + stage_configs: dict[str, dict], + pre_parsed_stages: dict[str, dict] | None = None, +) -> dict: + """Populates a stage in the stages configuration with inferred stage parameters. + + Infers and adds (unless already present, in which case the provided value is used) the following + parameters to the stage configuration: + - `is_metadata`: Whether the stage is a metadata stage, which is determined to be `False` if the stage + does not have an `aggregations` parameter. + - `data_input_dir`: The input directory for the stage (either the global input directory or the previous + data stage's output directory). + - `metadata_input_dir`: The input directory for the stage (either the global input directory or the + previous metadata stage's output directory). + - `output_dir`: The output directory for the stage (the cohort directory with the stage name appended). + + Args: + stage_name: The name of the stage to populate. + input_dir: The global input directory. + cohort_dir: The cohort directory into which this overall pipeline is writing data. + stages: The names of the stages processed by this pipeline in order. + stage_configs: The raw, unresolved stage configuration dictionaries for any stages with specific + arguments, keyed by stage name. + pre_parsed_stages: The stages configuration dictionaries (resolved), keyed by stage name. If + specified, the function will not re-resolve the stages in this list. + + Returns: + dict: The populated stage configuration. 
+ + Raises: + ValueError: If the stage is not present in the stages configuration. + + Examples: + >>> from omegaconf import DictConfig + >>> root_config = DictConfig({ + ... "input_dir": "/a/b", + ... "cohort_dir": "/c/d", + ... "stages": ["stage1", "stage2", "stage3", "stage4", "stage5", "stage6"], + ... "stage_configs": { + ... "stage2": {"is_metadata": True}, + ... "stage3": {"is_metadata": None}, + ... "stage4": {"data_input_dir": "/e/f", "output_dir": "/g/h"}, + ... "stage5": {"aggregations": ["foo"]}, + ... }, + ... }) + >>> args = [root_config[k] for k in ["input_dir", "cohort_dir", "stages", "stage_configs"]] + >>> populate_stage("stage1", *args) # doctest: +NORMALIZE_WHITESPACE + {'is_metadata': False, 'data_input_dir': '/a/b', 'metadata_input_dir': '/a/b', + 'output_dir': '/c/d/stage1'} + >>> populate_stage("stage2", *args) # doctest: +NORMALIZE_WHITESPACE + {'is_metadata': True, 'data_input_dir': '/c/d/stage1', 'metadata_input_dir': '/a/b', + 'output_dir': '/c/d/stage2'} + >>> populate_stage("stage3", *args) # doctest: +NORMALIZE_WHITESPACE + {'is_metadata': False, 'data_input_dir': '/c/d/stage1', + 'metadata_input_dir': '/c/d/stage2', 'output_dir': '/c/d/stage3'} + >>> populate_stage("stage4", *args) # doctest: +NORMALIZE_WHITESPACE + {'data_input_dir': '/e/f', 'output_dir': '/g/h', 'is_metadata': False, + 'metadata_input_dir': '/c/d/stage2'} + >>> populate_stage("stage5", *args) # doctest: +NORMALIZE_WHITESPACE + {'aggregations': ['foo'], 'is_metadata': True, 'data_input_dir': '/g/h', + 'metadata_input_dir': '/c/d/stage2', 'output_dir': '/c/d/stage5'} + >>> populate_stage("stage6", *args) # doctest: +NORMALIZE_WHITESPACE + {'is_metadata': False, 'data_input_dir': '/g/h', + 'metadata_input_dir': '/c/d/stage5', 'output_dir': '/c/d/stage6'} + >>> populate_stage("stage7", *args) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: 'stage7' is not a valid stage name. Options are: stage1, stage2, stage3, stage4, stage5, + stage6 + """ + + for s in stage_configs.keys(): + if s not in stages: + raise ValueError( + f"stage config key '{s}' is not a valid stage name. Options are: {list(stages.keys())}" + ) + + if stage_name not in stages: + raise ValueError(f"'{stage_name}' is not a valid stage name. 
Options are: {', '.join(stages)}") + + if pre_parsed_stages is None: + pre_parsed_stages = {} + + stage = None + prior_data_stage = None + prior_metadata_stage = None + for s in stages: + if s == stage_name: + stage = stage_configs.get(s, {}) + break + elif s in pre_parsed_stages: + s_resolved = pre_parsed_stages[s] + else: + s_resolved = populate_stage(s, input_dir, cohort_dir, stages, stage_configs, pre_parsed_stages) + + pre_parsed_stages[s] = s_resolved + if s_resolved["is_metadata"]: + prior_metadata_stage = s_resolved + else: + prior_data_stage = s_resolved + + inferred_keys = { + "is_metadata": "aggregations" in stage, + "data_input_dir": input_dir if prior_data_stage is None else prior_data_stage["output_dir"], + "metadata_input_dir": ( + input_dir if prior_metadata_stage is None else prior_metadata_stage["output_dir"] + ), + "output_dir": os.path.join(cohort_dir, stage_name), + } + + out = {**stage} + for key, val in inferred_keys.items(): + if key not in out or out[key] is None: + out[key] = val + + return out + + +OmegaConf.register_new_resolver("get_script_docstring", get_script_docstring, replace=False) +OmegaConf.register_new_resolver("current_script_name", current_script_name, replace=False) +OmegaConf.register_new_resolver("populate_stage", populate_stage, replace=False) def hydra_loguru_init() -> None: @@ -14,11 +169,14 @@ def hydra_loguru_init() -> None: Must be called from a hydra main! """ hydra_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - log.add(os.path.join(hydra_path, "main.log")) + logger.add(os.path.join(hydra_path, "main.log")) def write_lazyframe(df: pl.LazyFrame, out_fp: Path) -> None: - df.collect().write_parquet(out_fp, use_pyarrow=True) + if isinstance(df, pl.LazyFrame): + df = df.collect() + + df.write_parquet(out_fp, use_pyarrow=True) def get_shard_prefix(base_path: Path, fp: Path) -> str: diff --git a/tests/test_extraction.py b/tests/test_extraction.py index a256864..9343d17 100644 --- a/tests/test_extraction.py +++ b/tests/test_extraction.py @@ -245,14 +245,14 @@ def test_extraction(): # 4. Merge to the final output. extraction_config_kwargs = { - "raw_cohort_dir": str(raw_cohort_dir.resolve()), - "MEDS_cohort_dir": str(MEDS_cohort_dir.resolve()), + "input_dir": str(raw_cohort_dir.resolve()), + "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "split_fracs.train": 4 / 6, - "split_fracs.tuning": 1 / 6, - "split_fracs.held_out": 1 / 6, - "row_chunksize": 10, - "n_patients_per_shard": 2, + "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.shard_events.row_chunksize": 10, + "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, "hydra.verbose": True, } @@ -269,7 +269,7 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - subsharded_dir = MEDS_cohort_dir / "sub_sharded" + subsharded_dir = MEDS_cohort_dir / "shard_events" try: out_files = list(subsharded_dir.glob("**/*.parquet")) @@ -319,24 +319,30 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - splits_fp = MEDS_cohort_dir / "splits.json" - assert splits_fp.is_file(), f"Expected splits @ {str(splits_fp.resolve())} to exist." + try: + splits_fp = MEDS_cohort_dir / "splits.json" + assert splits_fp.is_file(), f"Expected splits @ {str(splits_fp.resolve())} to exist." 
- splits = json.loads(splits_fp.read_text()) - expected_keys = ["train/0", "train/1", "tuning/0", "held_out/0"] + splits = json.loads(splits_fp.read_text()) + expected_keys = ["train/0", "train/1", "tuning/0", "held_out/0"] - expected_keys_str = ", ".join(f"'{k}'" for k in expected_keys) - got_keys_str = ", ".join(f"'{k}'" for k in splits.keys()) + expected_keys_str = ", ".join(f"'{k}'" for k in expected_keys) + got_keys_str = ", ".join(f"'{k}'" for k in splits.keys()) - assert set(splits.keys()) == set(expected_keys), ( - f"Expected splits to have keys {expected_keys_str}.\n" f"Got keys: {got_keys_str}" - ) + assert set(splits.keys()) == set(expected_keys), ( + f"Expected splits to have keys {expected_keys_str}.\n" f"Got keys: {got_keys_str}" + ) - assert splits == EXPECTED_SPLITS, ( - f"Expected splits to be {EXPECTED_SPLITS}, got {splits}. NOTE THIS MAY CHANGE IF THE SEED OR " - "DATA CHANGES -- FAILURE HERE MAY BE JUST DUE TO A NON-DETERMINISTIC SPLIT AND THE TEST NEEDING " - "TO BE UPDATED." - ) + assert splits == EXPECTED_SPLITS, ( + f"Expected splits to be {EXPECTED_SPLITS}, got {splits}. NOTE THIS MAY CHANGE IF THE SEED OR " + "DATA CHANGES -- FAILURE HERE MAY BE JUST DUE TO A NON-DETERMINISTIC SPLIT AND THE TEST " + "NEEDING TO BE UPDATED." + ) + except AssertionError as e: + print("Failed to split patients") + print(f"stderr:\n{stderr}") + print(f"stdout:\n{stdout}") + raise e # Step 3: Extract the events and sub-shard by patient stderr, stdout = run_command( @@ -347,7 +353,7 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "patient_sub_sharded_events" + patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items():