diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md
index c003860..6bf348d 100644
--- a/MIMIC-IV_Example/README.md
+++ b/MIMIC-IV_Example/README.md
@@ -76,10 +76,10 @@ This is a step in a few parts:
     - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`.
     - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`.
-2. Convert the patient's static data to a more parseable form. This entails:
-   - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and
+2. Convert the subject's static data to a more parseable form. This entails:
+   - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and
     `anchor_offset` fields.
-   - Merge the patient's `dod` with the `deathtime` from the `admissions` table.
+   - Merge the subject's `dod` with the `deathtime` from the `admissions` table.
 
 After these steps, modified files or symlinks to the original files will be written in a new directory which
 will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this
@@ -109,14 +109,14 @@ This is a step in 4 parts:
     This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected
     format of the command.
 
-2. Extract and form the patient splits and sub-shards. The `./scripts/extraction/split_and_shard_patients.py`
+2. Extract and form the subject splits and sub-shards. The `./scripts/extraction/split_and_shard_subjects.py`
    script is used for this step. See `joint_script*.sh` for the expected format of the command.
 
-3. Extract patient sub-shards and convert to MEDS events. The
+3. Extract subject sub-shards and convert to MEDS events. The
    `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh`
    for the expected format of the command.
 
-4. Merge the MEDS events into a single file per patient sub-shard. The
+4. Merge the MEDS events into a single file per subject sub-shard. The
    `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for
    the expected format of the command.
 
@@ -139,7 +139,7 @@ timeline which is otherwise stored at the _datetime_ resolution?
 
 Other questions:
 
-1. How to handle merging the deathtimes between the hosp table and the patients table?
+1. How to handle merging the deathtimes between the hosp table and the subjects table?
 2. How to handle the dob nonsense MIMIC has?
 
 ## Notes
@@ -153,4 +153,4 @@ may need to run `unset SLURM_CPU_BIND` in your terminal first to avoid errors.
 
 If you wanted, some other processing could also be done here, such as:
 
-1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field.
+1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field.
diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml
index 0cd0381..619d5a2 100644
--- a/MIMIC-IV_Example/configs/event_configs.yaml
+++ b/MIMIC-IV_Example/configs/event_configs.yaml
@@ -1,4 +1,4 @@
-patient_id_col: subject_id
+subject_id_col: subject_id
 hosp/admissions:
   ed_registration:
     code: ED_REGISTRATION
@@ -27,7 +27,7 @@ hosp/admissions:
       time: col(dischtime)
      time_format: "%Y-%m-%d %H:%M:%S"
      hadm_id: hadm_id
-  # We omit the death event here as it is joined to the data in the patients table in the pre-MEDS step.
+ # We omit the death event here as it is joined to the data in the subjects table in the pre-MEDS step. hosp/diagnoses_icd: diagnosis: @@ -108,7 +108,7 @@ hosp/omr: time: col(chartdate) time_format: "%Y-%m-%d" -hosp/patients: +hosp/subjects: gender: code: - GENDER @@ -295,18 +295,18 @@ icu/inputevents: description: ["omop_concept_name", "label"] # List of strings are columns to be collated itemid: "itemid (omop_source_code)" parent_codes: "{omop_vocabulary_id}/{omop_concept_code}" - patient_weight: + subject_weight: code: - - PATIENT_WEIGHT_AT_INFUSION + - SUBJECT_WEIGHT_AT_INFUSION - KG time: col(starttime) time_format: "%Y-%m-%d %H:%M:%S" - numeric_value: patientweight + numeric_value: subjectweight icu/outputevents: output: code: - - PATIENT_FLUID_OUTPUT + - SUBJECT_FLUID_OUTPUT - col(itemid) - col(valueuom) time: col(charttime) diff --git a/MIMIC-IV_Example/joint_script.sh b/MIMIC-IV_Example/joint_script.sh index a98fee7..dd1459c 100755 --- a/MIMIC-IV_Example/joint_script.sh +++ b/MIMIC-IV_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo echo "Arguments:" echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." @@ -88,11 +88,11 @@ MEDS_extract-shard_events \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects in serial" +MEDS_extract-split_and_shard_subjects \ input_dir="$MIMICIV_PREMEDS_DIR" \ cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="split_and_shard_patients" \ + stage="split_and_shard_subjects" \ etl_metadata.dataset_name="MIMIC-IV" \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh index 3ff9684..e13fb7e 100755 --- a/MIMIC-IV_Example/joint_script_slurm.sh +++ b/MIMIC-IV_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -72,8 +72,8 @@ MEDS_extract-shard_events \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ stage=shard_events -echo "Splitting patients on one worker" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects on one worker" +MEDS_extract-split_and_shard_subjects \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/README.md b/README.md index e797d3f..a9e057d 100644 --- a/README.md +++ b/README.md @@ -45,12 +45,12 @@ directories. The fundamental design philosophy of this repository can be summarized as follows: 1. 
_(The MEDS Assumption)_: All structured electronic health record (EHR) data can be represented as a - series of events, each of which is associated with a patient, a time, and a set of codes and + series of events, each of which is associated with a subject, a time, and a set of codes and numeric values. This representation is the Medical Event Data Standard (MEDS) format, and in this - repository we use it in the "flat" format, where data is organized as rows of `patient_id`, + repository we use it in the "flat" format, where data is organized as rows of `subject_id`, `time`, `code`, `numeric_value` columns. 2. _Easy Efficiency through Sharding_: MEDS datasets in this repository are sharded into smaller, more - manageable pieces (organized as separate files) at the patient level (and, during the raw-data extraction + manageable pieces (organized as separate files) at the subject level (and, during the raw-data extraction process, the event level). This enables users to scale up their processing capabilities ad nauseum by leveraging more workers to process these shards in parallel. This parallelization is seamlessly enabled with the configuration schema used in the scripts in this repository. This style of parallelization @@ -62,7 +62,7 @@ The fundamental design philosophy of this repository can be summarized as follow the others, and each stage is designed to do a small amount of work and be easily testable in isolation. This design philosophy ensures that the pipeline is robust to changes, easy to debug, and easy to extend. In particular, to add new operations specific to a given model or dataset, the user need only write - simple functions that take in a flat MEDS dataframe (representing a single patient level shard) and + simple functions that take in a flat MEDS dataframe (representing a single subject level shard) and return a new flat MEDS dataframe, and then wrap that function in a script by following the examples provided in this repository. These individual functions can use the same configuration schema as other stages in the pipeline or include a separate, stage-specific configuration, and can use whatever @@ -198,9 +198,9 @@ To use this repository as a template, the user should follow these steps: Assumptions: 1. Your data is organized in a set of parquet files on disk such that each row of each file corresponds to - one or more measurements per patient and has all necessary information in that row to extract said - measurement, organized in a simple, columnar format. Each of these parquet files stores the patient's ID in - a column called `patient_id` in the same type. + one or more measurements per subject and has all necessary information in that row to extract said + measurement, organized in a simple, columnar format. Each of these parquet files stores the subject's ID in + a column called `subject_id` in the same type. 2. You have a pre-defined or can externally define the requisite MEDS base `code_metadata` file that describes the codes in your data as necessary. This file is not used in the provided pre-processing pipeline in this package, but is necessary for other uses of the MEDS data. @@ -221,16 +221,16 @@ The provided ETL consists of the following steps, which can be performed as need degree of parallelism is desired per step. 1. It re-shards the input data into a set of smaller, event-level shards to facilitate parallel processing. 
- This can be skipped if your input data is already suitably sharded at either a per-patient or per-event + This can be skipped if your input data is already suitably sharded at either a per-subject or per-event level. 2. It extracts the subject IDs from the sharded data and computes the set of ML splits and (per split) the - patient shards. These are stored in a JSON file in the output cohort directory. + subject shards. These are stored in a JSON file in the output cohort directory. 3. It converts the input, event level shards into the MEDS flat format and joins and shards these data into - patient-level shards for MEDS use and stores them in a nested format in the output cohort directory, + subject-level shards for MEDS use and stores them in a nested format in the output cohort directory, again in the flat format. This step can be broken down into two sub-steps: - - First, each input shard is converted to the MEDS flat format and split into sub patient-level shards. - - Second, the appropriate sub patient-level shards are joined and re-organized into the final - patient-level shards. This method ensures that we minimize the amount of read contention on the input + - First, each input shard is converted to the MEDS flat format and split into sub subject-level shards. + - Second, the appropriate sub subject-level shards are joined and re-organized into the final + subject-level shards. This method ensures that we minimize the amount of read contention on the input shards during the join process and can maximize parallel throughput, as (theoretically, with sufficient workers) all input shards can be sub-sharded in parallel and then all output shards can be joined in parallel. @@ -239,7 +239,7 @@ The ETL scripts all use [Hydra](https://hydra.cc/) for configuration management, `configs/extraction.yaml` file for configuration. The user can override any of these settings in the normal way for Hydra configurations. -If desired, appropriate scripts can be written and run at a per-patient shard level to convert between the +If desired, appropriate scripts can be written and run at a per-subject shard level to convert between the flat format and any of the other valid nested MEDS format, but for now we leave that up to the user. #### Input Event Extraction @@ -250,11 +250,11 @@ dataframes should be parsed into different event formats. The YAML file stores a following structure: ```yaml -patient_id: $GLOBAL_PATIENT_ID_OVERWRITE # Optional, if you want to overwrite the patient ID column name for - # all inputs. If not specified, defaults to "patient_id". +subject_id: $GLOBAL_SUBJECT_ID_OVERWRITE # Optional, if you want to overwrite the subject ID column name for + # all inputs. If not specified, defaults to "subject_id". $INPUT_FILE_STEM: - patient_id: $INPUT_FILE_PATIENT_ID # Optional, if you want to overwrite the patient ID column name for - # this input. IF not specified, defaults to the global patient ID. + subject_id: $INPUT_FILE_SUBJECT_ID # Optional, if you want to overwrite the subject ID column name for + # this input. IF not specified, defaults to the global subject ID. $EVENT_NAME: code: - $CODE_PART_1 @@ -287,18 +287,18 @@ script is a functional test that is also run with `pytest` to verify correctness 1. `scripts/extraction/shard_events.py` shards the input data into smaller, event-level shards by splitting raw files into chunks of a configurable number of rows. Files are split sequentially, with no regard for - data content or patient boundaries. 
The resulting files are stored in the `subsharded_events` + data content or subject boundaries. The resulting files are stored in the `subsharded_events` subdirectory of the output directory. -2. `scripts/extraction/split_and_shard_patients.py` splits the patient population into ML splits and shards - these splits into patient-level shards. The result of this process is only a simple `JSON` file - containing the patient IDs belonging to individual splits and shards. This file is stored in the +2. `scripts/extraction/split_and_shard_subjects.py` splits the subject population into ML splits and shards + these splits into subject-level shards. The result of this process is only a simple `JSON` file + containing the subject IDs belonging to individual splits and shards. This file is stored in the `output_directory/splits.json` file. 3. `scripts/extraction/convert_to_sharded_events.py` converts the input, event-level shards into the MEDS - event format and splits them into patient-level sub-shards. So, the resulting files are sharded into - patient-level, then event-level groups and are not merged into full patient-level shards or appropriately + event format and splits them into subject-level sub-shards. So, the resulting files are sharded into + subject-level, then event-level groups and are not merged into full subject-level shards or appropriately sorted for downstream use. -4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the patient-level, event-level shards into full - patient-level shards and sorts them appropriately for downstream use. The resulting files are stored in +4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the subject-level, event-level shards into full + subject-level shards and sorts them appropriately for downstream use. The resulting files are stored in the `output_directory/final_cohort` directory. ## MEDS Pre-processing Transformations @@ -308,9 +308,9 @@ contains a variety of pre-processing transformations and scripts that can be app in various ways to prepare them for downstream modeling. Broadly speaking, the pre-processing pipeline can be broken down into the following steps: -1. Filtering the dataset by criteria that do not require cross-patient analyses, e.g., +1. Filtering the dataset by criteria that do not require cross-subject analyses, e.g., - - Filtering patients by the number of events or unique times they have. + - Filtering subjects by the number of events or unique times they have. - Removing numeric values that fall outside of pre-specified, per-code ranges (e.g., for outlier removal). @@ -318,9 +318,9 @@ broken down into the following steps: - Adding time-derived measurements, e.g., - The time since the last event of a certain type. - - The patient's age as of each unique timepoint. + - The subject's age as of each unique timepoint. - The time-of-day of each event. - - Adding a "dummy" event to the dataset for each patient that occurs at the end of the observation + - Adding a "dummy" event to the dataset for each subject that occurs at the end of the observation period. 3. Iteratively (a) grouping the dataset by `code` and associated code modifier columns and collecting @@ -344,11 +344,11 @@ broken down into the following steps: 5. Normalizing the data to convert codes to indices and numeric values to the desired form (either categorical indices or normalized numeric values). -6. 
Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between patients, - patient sequence elements, and measurements per sequence element (note that various of these delineations +6. Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between subjects, + subject sequence elements, and measurements per sequence element (note that various of these delineations may be fully flat/trivial for unnested formats). -7. Tensorizing the data to permit efficient retrieval from disk of patient data for deep-learning modeling +7. Tensorizing the data to permit efficient retrieval from disk of subject data for deep-learning modeling via PyTorch. Much like how the entire MEDS ETL pipeline is controlled by a single configuration file, the pre-processing @@ -363,7 +363,7 @@ be a bottleneck. Tokenization is the process of producing dataframes that are arranged into the sequences that will eventually be processed by deep-learning methods. Generally, these dataframes will be arranged such that each row -corresponds to a unique patient, with nested list-type columns corresponding either to _events_ (unique +corresponds to a unique subject, with nested list-type columns corresponding either to _events_ (unique timepoints), themselves with nested, list-type measurements, or to _measurements_ (unique measurements within a timepoint) directly. Importantly, _tokenized files are generally not ideally suited to direct ingestion by PyTorch datasets_. Instead, they should undergo a _tensorization_ process to be converted into a format that @@ -379,7 +379,7 @@ does not inhibit rapid training, and (3) be organized such that CPU and GPU reso during training. Similarly, by _scalability_, we mean that the three desiderata above should hold true even as the dataset size grows much larger---while total training time can increase, time to begin training, to process the data per-item, and CPU/GPU resources required should remain constant, or only grow negligibly, -such as the cost of maintaining a larger index of patient IDs to file offsets or paths (though disk space will +such as the cost of maintaining a larger index of subject IDs to file offsets or paths (though disk space will of course increase). Depending on one's performance needs and dataset sizes, there are 3 modes of deep learning training that can @@ -398,7 +398,7 @@ on an as-needed basis. This mode is extremely scalable, because the entire datas loaded or stored in memory in its entirety. When done properly, retrieving data from disk can be done in a manner that is independent of the total dataset size as well, thereby rendering the load time similarly unconstrained by total dataset size. This mode is also extremely flexible, because different cohorts can be -loaded from the same base dataset simply by changing which patients and what offsets within patient data are +loaded from the same base dataset simply by changing which subjects and what offsets within subject data are read on any given cohort, all without changing the base files or underlying code. However, this mode does require ragged dataset collation which can be more resource intensive than pre-batched iteration, so it is slower than the "Fixed-batch retrieval" approach. This mode is what is currently supported by this repository. diff --git a/eICU_Example/README.md b/eICU_Example/README.md index c0494c9..37eb9d0 100644 --- a/eICU_Example/README.md +++ b/eICU_Example/README.md @@ -19,7 +19,7 @@ up from this one). 
- [ ] Testing the MEDS extraction ETL runs on eICU-CRD (this should be expected to work, but needs live testing). - [ ] Sub-sharding - - [ ] Patient split gathering + - [ ] Subject split gathering - [ ] Event extraction - [ ] Merging - [ ] Validating the output MEDS cohort @@ -58,10 +58,10 @@ This is a step in a few parts: 1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In particular, we need to join: - TODO -2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and +2. Convert the subject's static data to a more parseable form. This entails: + - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Merge the subject's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this @@ -78,12 +78,12 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ## Step 3: Run the MEDS extraction ETL -Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable memory +Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable memory burden (e.g., \< 150GB per worker), you will want a smaller shard size, as well as to turn off the final unique check (which should not be necessary given the structure of eICU and is expensive) in the merge stage. You can do this by setting the following parameters at the end of the mandatory args when running this script: -- `stage_configs.split_and_shard_patients.n_patients_per_shard=10000` +- `stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000` - `stage_configs.merge_to_MEDS_cohort.unique_by=null` ### Running locally, serially @@ -106,10 +106,10 @@ This is a step in 4 parts: In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -117,7 +117,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -132,7 +132,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. 
```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -172,10 +172,10 @@ to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the n In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -183,7 +183,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -198,7 +198,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -221,7 +221,7 @@ timeline which is otherwise stored at the _datetime_ resolution? Other questions: -1. How to handle merging the deathtimes between the hosp table and the patients table? +1. How to handle merging the deathtimes between the hosp table and the subjects table? 2. How to handle the dob nonsense MIMIC has? ## Future Work @@ -230,4 +230,4 @@ Other questions: If you wanted, some other processing could also be done here, such as: -1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. +1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field. diff --git a/eICU_Example/configs/event_configs.yaml b/eICU_Example/configs/event_configs.yaml index e6f2e7a..fb7901c 100644 --- a/eICU_Example/configs/event_configs.yaml +++ b/eICU_Example/configs/event_configs.yaml @@ -1,7 +1,7 @@ -# Note that there is no "patient_id" for eICU -- patients are only differentiable during the course of a +# Note that there is no "subject_id" for eICU -- patients are only differentiable during the course of a # single health system stay. Accordingly, we set the "patient" id here as the "patientHealthSystemStayID" -patient_id_col: patienthealthsystemstayid +subject_id_col: patienthealthsystemstayid patient: dob: @@ -131,7 +131,7 @@ infusionDrug: volume_of_fluid: "volumeoffluid" patient_weight: code: - - "INFUSION_PATIENT_WEIGHT" + - "INFUSION_SUBJECT_WEIGHT" time: col(infusionEnteredTimestamp) numeric_value: "patientweight" diff --git a/eICU_Example/joint_script.sh b/eICU_Example/joint_script.sh index fd76ee2..0b3ad6c 100755 --- a/eICU_Example/joint_script.sh +++ b/eICU_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." 
echo echo "Arguments:" echo " EICU_RAW_DIR Directory containing raw eICU data files." @@ -39,12 +39,12 @@ N_PARALLEL_WORKERS="$4" shift 4 -echo "Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable " +echo "Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable " echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " echo "in the merge stage. You can do this by setting the following parameters at the end of the mandatory " echo "args when running this script:" -echo " * stage_configs.split_and_shard_patients.n_patients_per_shard=10000" +echo " * stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000" echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" echo "Running pre-MEDS conversion." @@ -59,8 +59,8 @@ echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects in serial" +./scripts/extraction/split_and_shard_subjects.py \ input_dir="$EICU_PREMEDS_DIR" \ cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/joint_script_slurm.sh b/eICU_Example/joint_script_slurm.sh index 7880286..bdd7abe 100755 --- a/eICU_Example/joint_script_slurm.sh +++ b/eICU_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -71,8 +71,8 @@ echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml -echo "Splitting patients on one worker" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects on one worker" +./scripts/extraction/split_and_shard_subjects.py \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/pyproject.toml b/pyproject.toml index cc193b6..c9f4906 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3", + "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.2", ] [tool.setuptools_scm] @@ -35,7 +35,7 @@ docs = [ [project.scripts] # MEDS_extract -MEDS_extract-split_and_shard_patients = "MEDS_transforms.extract.split_and_shard_patients:main" +MEDS_extract-split_and_shard_subjects = "MEDS_transforms.extract.split_and_shard_subjects:main" MEDS_extract-shard_events = "MEDS_transforms.extract.shard_events:main" MEDS_extract-convert_to_sharded_events = "MEDS_transforms.extract.convert_to_sharded_events:main" MEDS_extract-merge_to_MEDS_cohort = "MEDS_transforms.extract.merge_to_MEDS_cohort:main" @@ -50,7 +50,7 @@ MEDS_transform-fit_vocabulary_indices = "MEDS_transforms.fit_vocabulary_indices: MEDS_transform-reshard_to_split = "MEDS_transforms.reshard_to_split:main" ## Filters MEDS_transform-filter_measurements = "MEDS_transforms.filters.filter_measurements:main" -MEDS_transform-filter_patients = "MEDS_transforms.filters.filter_patients:main" +MEDS_transform-filter_subjects = "MEDS_transforms.filters.filter_subjects:main" ## Transforms MEDS_transform-reorder_measurements = "MEDS_transforms.transforms.reorder_measurements:main" MEDS_transform-add_time_derived_measurements = "MEDS_transforms.transforms.add_time_derived_measurements:main" diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 38f2ac9..8d0ffd6 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -2,22 +2,23 @@ from importlib.resources import files import polars as pl +from meds import code_field, subject_id_field, time_field __package_name__ = "MEDS_transforms" try: __version__ = version(__package_name__) -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover __version__ = "unknown" PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") -MANDATORY_COLUMNS = ["patient_id", "time", "code", "numeric_value"] +MANDATORY_COLUMNS = [subject_id_field, time_field, code_field, "numeric_value"] MANDATORY_TYPES = { - "patient_id": pl.Int64, - "time": pl.Datetime("us"), - "code": pl.String, + subject_id_field: pl.Int64, + time_field: pl.Datetime("us"), + code_field: pl.String, "numeric_value": pl.Float32, "categorical_value": pl.String, "text_value": pl.String, @@ -29,7 +30,7 @@ "category_value": "categoric_value", "textual_value": "text_value", "timestamp": "time", - "subject_id": "patient_id", + "patient_id": subject_id_field, } INFERRED_STAGE_KEYS = { diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 1f6828b..1ac8cdf 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -12,6 +12,7 @@ import polars as pl import polars.selectors as cs 
from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, ListConfig, OmegaConf from MEDS_transforms import PREPROCESS_CONFIG_YAML @@ -26,7 +27,7 @@ class METADATA_FN(StrEnum): This enumeration contains the supported code-metadata collection and aggregation function names that can be applied to codes (or, rather, unique code & modifier units) in a MEDS cohort. Each function name is mapped, in the below `CODE_METADATA_AGGREGATIONS` dictionary, to mapper and reducer functions that (a) - collect the raw data at a per code-modifier level from MEDS patient-level shards and (b) aggregates two or + collect the raw data at a per code-modifier level from MEDS subject-level shards and (b) aggregates two or more per-shard metadata files into a single metadata file, which can be used to merge metadata across all shards into a single file. @@ -45,14 +46,14 @@ class METADATA_FN(StrEnum): or on the command line. Args: - "code/n_patients": Collects the number of unique patients who have (anywhere in their record) the code + "code/n_subjects": Collects the number of unique subjects who have (anywhere in their record) the code & modifiers group. "code/n_occurrences": Collects the total number of occurrences of the code & modifiers group across - all observations for all patients. - "values/n_patients": Collects the number of unique patients who have a non-null, non-nan + all observations for all subjects. + "values/n_subjects": Collects the number of unique subjects who have a non-null, non-nan numeric_value field for the code & modifiers group. "values/n_occurrences": Collects the total number of non-null, non-nan numeric_value occurrences for - the code & modifiers group across all observations for all patients. + the code & modifiers group across all observations for all subjects. "values/n_ints": Collects the number of times the observed, non-null numeric_value for the code & modifiers group is an integral value (i.e., a whole number, not an integral type). "values/sum": Collects the sum of the non-null, non-nan numeric_value values for the code & @@ -67,9 +68,9 @@ class METADATA_FN(StrEnum): the configuration file using the dictionary syntax for the aggregation. """ - CODE_N_PATIENTS = "code/n_patients" + CODE_N_PATIENTS = "code/n_subjects" CODE_N_OCCURRENCES = "code/n_occurrences" - VALUES_N_PATIENTS = "values/n_patients" + VALUES_N_PATIENTS = "values/n_subjects" VALUES_N_OCCURRENCES = "values/n_occurrences" VALUES_N_INTS = "values/n_ints" VALUES_SUM = "values/sum" @@ -157,10 +158,10 @@ def quantile_reducer(cols: cs._selector_proxy_, quantiles: list[float]) -> pl.Ex PRESENT_VALS = VAL.filter(VAL_PRESENT) CODE_METADATA_AGGREGATIONS: dict[METADATA_FN, MapReducePair] = { - METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col("patient_id").n_unique(), pl.sum_horizontal), + METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col(subject_id_field).n_unique(), pl.sum_horizontal), METADATA_FN.CODE_N_OCCURRENCES: MapReducePair(pl.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_PATIENTS: MapReducePair( - pl.col("patient_id").filter(VAL_PRESENT).n_unique(), pl.sum_horizontal + pl.col(subject_id_field).filter(VAL_PRESENT).n_unique(), pl.sum_horizontal ), METADATA_FN.VALUES_N_OCCURRENCES: MapReducePair(PRESENT_VALS.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_INTS: MapReducePair(VAL.filter(VAL_PRESENT & IS_INT).len(), pl.sum_horizontal), @@ -203,9 +204,9 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[ Traceback (most recent call last): ... 
ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are: - code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, + code/n_subjects, code/n_occurrences, values/n_subjects, values/n_occurrences, values/n_ints, values/sum, values/sum_sqd, values/min, values/max, values/quantiles - >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", {"name": "values/n_ints"}]}) + >>> valid_cfg = DictConfig({"aggregations": ["code/n_subjects", {"name": "values/n_ints"}]}) >>> validate_args_and_get_code_cols(valid_cfg, 33) Traceback (most recent call last): ... @@ -264,7 +265,7 @@ def mapper_fntr( A function that extracts the specified metadata from a MEDS cohort shard after grouping by the specified code & modifier columns. **Note**: The output of this function will, if ``stage_cfg.do_summarize_over_all_codes`` is True, contain the metadata summarizing all observations - across all codes and patients in the shard, with both ``code`` and all ``code_modifiers`` set + across all codes and subjects in the shard, with both ``code`` and all ``code_modifiers`` set to `None` in the output dataframe, in the same format as the code/modifier specific rows with non-null values. @@ -274,13 +275,13 @@ def mapper_fntr( ... "code": ["A", "B", "A", "B", "C", "A", "C", "B", "D"], ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], - ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], + ... "subject_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], ... "numeric_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, float('nan'), None], ... }) >>> df shape: (9, 5) ┌──────┬───────────┬──────────────────┬────────────┬───────────────┐ - │ code ┆ modifier1 ┆ modifier_ignored ┆ patient_id ┆ numeric_value │ + │ code ┆ modifier1 ┆ modifier_ignored ┆ subject_id ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 ┆ f64 │ ╞══════╪═══════════╪══════════════════╪════════════╪═══════════════╡ @@ -295,14 +296,14 @@ def mapper_fntr( │ D ┆ null ┆ 7 ┆ 1 ┆ null │ └──────┴───────────┴──────────────────┴────────────┴───────────────┘ >>> stage_cfg = DictConfig({ - ... "aggregations": ["code/n_patients", "values/n_ints"], + ... "aggregations": ["code/n_subjects", "values/n_ints"], ... "do_summarize_over_all_codes": True ... 
}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (5, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -312,12 +313,12 @@ def mapper_fntr( │ C ┆ 2 ┆ 1 │ │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (4, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -327,12 +328,12 @@ def mapper_fntr( │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, ListConfig(code_modifiers)) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -376,12 +377,12 @@ def mapper_fntr( │ C ┆ 1 ┆ 2 ┆ 12.5 │ │ D ┆ null ┆ 1 ┆ 0.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> mapper = mapper_fntr(stage_cfg, code_modifiers) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ @@ -455,7 +456,7 @@ def mapper_fntr( def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame: return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns) - def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: + def all_subjects_mapper(df: pl.LazyFrame) -> pl.LazyFrame: local_agg_operations = agg_operations.copy() if METADATA_FN.VALUES_QUANTILES in agg_operations: local_agg_operations[METADATA_FN.VALUES_QUANTILES] = agg_operations[ @@ -467,8 +468,8 @@ def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: def mapper(df: pl.LazyFrame) -> pl.LazyFrame: by_code = by_code_mapper(df) - all_patients = all_patients_mapper(df) - return pl.concat([all_patients, by_code], how="diagonal_relaxed").select( + all_subjects = all_subjects_mapper(df) + return pl.concat([all_subjects, by_code], how="diagonal_relaxed").select( *code_key_columns, *agg_operations.keys() ) @@ -502,9 +503,9 @@ def reducer_fntr( >>> df_1 = pl.DataFrame({ ... "code": [None, "A", "A", "B", "C"], ... "modifier1": [None, 1, 2, 1, 2], - ... "code/n_patients": [10, 1, 1, 2, 2], + ... "code/n_subjects": [10, 1, 1, 2, 2], ... "code/n_occurrences": [13, 2, 1, 3, 2], - ... "values/n_patients": [8, 1, 1, 2, 2], + ... 
"values/n_subjects": [8, 1, 1, 2, 2], ... "values/n_occurrences": [12, 2, 1, 3, 2], ... "values/n_ints": [4, 0, 1, 3, 1], ... "values/sum": [13.2, 2.2, 6.0, 14.0, 12.5], @@ -516,9 +517,9 @@ def reducer_fntr( >>> df_2 = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, None], - ... "code/n_patients": [3, 3, 4, 4], + ... "code/n_subjects": [3, 3, 4, 4], ... "code/n_occurrences": [10, 11, 8, 11], - ... "values/n_patients": [0, 1, 2, 2], + ... "values/n_subjects": [0, 1, 2, 2], ... "values/n_occurrences": [0, 4, 3, 2], ... "values/n_ints": [0, 1, 3, 1], ... "values/sum": [0., 7.0, 14.0, 12.5], @@ -530,9 +531,9 @@ def reducer_fntr( >>> df_3 = pl.DataFrame({ ... "code": ["D"], ... "modifier1": [1], - ... "code/n_patients": [2], + ... "code/n_subjects": [2], ... "code/n_occurrences": [2], - ... "values/n_patients": [1], + ... "values/n_subjects": [1], ... "values/n_occurrences": [3], ... "values/n_ints": [3], ... "values/sum": [2], @@ -542,12 +543,12 @@ def reducer_fntr( ... "values/quantiles": [[]], ... }) >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -562,9 +563,9 @@ def reducer_fntr( >>> cfg = DictConfig({ ... "code_modifiers": ["modifier1"], ... "code_processing_stages": { - ... "stage1": ["code/n_patients", "values/n_ints"], + ... "stage1": ["code/n_subjects", "values/n_ints"], ... "stage2": ["code/n_occurrences", "values/sum"], - ... "stage3.A": ["values/n_patients", "values/n_occurrences"], + ... "stage3.A": ["values/n_subjects", "values/n_occurrences"], ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], ... "stage4": ["INVALID"], ... 
} @@ -586,12 +587,12 @@ def reducer_fntr( │ C ┆ 2 ┆ 2 ┆ 12.5 │ │ D ┆ 1 ┆ 2 ┆ 2.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ @@ -730,5 +731,5 @@ def main(cfg: DictConfig): run_map_reduce(cfg) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/configs/extract.yaml b/src/MEDS_transforms/configs/extract.yaml index ae3bf07..3abd498 100644 --- a/src/MEDS_transforms/configs/extract.yaml +++ b/src/MEDS_transforms/configs/extract.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - merge_to_MEDS_cohort - extract_code_metadata - finalize_MEDS_metadata @@ -32,7 +32,7 @@ shards_map_fp: "${cohort_dir}/metadata/.shards.json" stages: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - convert_to_sharded_events - merge_to_MEDS_cohort - extract_code_metadata diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/preprocess.yaml index ea509cd..dab87a9 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/preprocess.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - reshard_to_split - - filter_patients + - filter_subjects - add_time_derived_measurements - count_code_occurrences - filter_measurements @@ -24,7 +24,7 @@ code_modifiers: ??? 
 # Pipeline Structure
 stages:
-  - filter_patients
+  - filter_subjects
   - add_time_derived_measurements
   - preliminary_counts
   - filter_measurements
diff --git a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml
index 076a1a0..b17b74e 100644
--- a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml
+++ b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml
@@ -1,5 +1,5 @@
 count_code_occurrences:
   aggregations:
     - "code/n_occurrences"
-    - "code/n_patients"
+    - "code/n_subjects"
   do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts
diff --git a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml
index 12ff62f..0d0a5bd 100644
--- a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml
+++ b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml
@@ -1,3 +1,3 @@
 filter_measurements:
-  min_patients_per_code: null
+  min_subjects_per_code: null
   min_occurrences_per_code: null
diff --git a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml b/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml
deleted file mode 100644
index 70332b1..0000000
--- a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-filter_patients:
-  min_events_per_patient: null
-  min_measurements_per_patient: null
diff --git a/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml
new file mode 100644
index 0000000..2706ffc
--- /dev/null
+++ b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml
@@ -0,0 +1,3 @@
+filter_subjects:
+  min_events_per_subject: null
+  min_measurements_per_subject: null
diff --git a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml
index e522470..6bd90cb 100644
--- a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml
+++ b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml
@@ -1,7 +1,7 @@
 fit_normalization:
   aggregations:
     - "code/n_occurrences"
-    - "code/n_patients"
+    - "code/n_subjects"
     - "values/n_occurrences"
     - "values/sum"
     - "values/sum_sqd"
diff --git a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml
index 16dc505..fd0dc8a 100644
--- a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml
+++ b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml
@@ -1,2 +1,2 @@
 reshard_to_split:
-  n_patients_per_shard: 50000
+  n_subjects_per_shard: 50000
diff --git a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml
similarity index 73%
rename from src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml
rename to src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml
index c4015bd..7dbed11 100644
--- a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml
+++ b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml
@@ -1,7 +1,7 @@
-split_and_shard_patients:
+split_and_shard_subjects:
   is_metadata: True
   output_dir: ${cohort_dir}/metadata
-  n_patients_per_shard: 50000
+  n_subjects_per_shard: 50000
   external_splits_json_fp: null
 split_fracs:
   train: 0.8
diff --git a/src/MEDS_transforms/extract/README.md b/src/MEDS_transforms/extract/README.md
index 47f7e7d..a60a8c8 100644
--- a/src/MEDS_transforms/extract/README.md
+++ b/src/MEDS_transforms/extract/README.md
@@ -4,8 +4,8 @@ This directory contains the scripts and functions used to extract raw data into
 dataset is:
 
 1. Arranged in a series of files on disk of an allowed format (e.g., `.csv`, `.csv.gz`, `.parquet`)...
-2. Such that each file stores a dataframe containing data about patients such that each row of any given
-   table corresponds to zero or more observations about a patient at a given time...
+2. Such that each file stores a dataframe containing data about subjects such that each row of any given
+   table corresponds to zero or more observations about a subject at a given time...
 3. And you can configure how to extract those observations in the time, code, and numeric value format of
    MEDS in the event conversion `yaml` file format specified below, then...
 
 this tool can automatically extract your raw data into a MEDS dataset for you in an efficient, reproducible,
@@ -53,7 +53,7 @@ step](#step-0-pre-meds) and the [Data Cleaning step](#step-3-data-cleanup), for
 ### Event Conversion Configuration
 
 The event conversion configuration file tells MEDS Extract how to convert each row of a file among your raw
-data files into one or more MEDS measurements (meaning a tuple of a patient ID, a time, a categorical
+data files into one or more MEDS measurements (meaning a tuple of a subject ID, a time, a categorical
 code, and/or various other value or properties columns, most commonly a numeric value). This file is written
 in yaml and has the following format:
@@ -93,11 +93,11 @@ each row of the file will be converted into a MEDS event according to the logic
 here, as string literals _cannot_ be used for these columns.
 
 There are several more nuanced aspects to the configuration file that have not yet been discussed. First, the
-configuration file also specifies how to identify the patient ID from the raw data. This can be done either by
-specifying a global `patient_id_col` field at the top level of the configuration file, or by specifying a
-`patient_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously,
-with the most local taking precedent. If no patient ID column is specified, the patient ID will be assumed to
-be stored in a `patient_id` column. If the patient ID column is not found, an error will be raised.
+configuration file also specifies how to identify the subject ID from the raw data. This can be done either by
+specifying a global `subject_id_col` field at the top level of the configuration file, or by specifying a
+`subject_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously,
+with the most local taking precedence. If no subject ID column is specified, the subject ID will be assumed to
+be stored in a `subject_id` column. If the subject ID column is not found, an error will be raised.
 
 Second, you can also specify how to link the codes constructed for each event block to code-specific metadata
 in these blocks. This is done by specifying a `_metadata` block in the event block.
The format of this block @@ -117,7 +117,7 @@ the [Partial MIMIC-IV Example](#partial-mimic-iv-example) below for an example o ```yaml subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,7 +144,7 @@ admit_vitals: ##### Partial MIMIC-IV Example ```yaml -patient_id_col: subject_id +subject_id_col: subject_id hosp/admissions: admission: code: @@ -259,4 +259,4 @@ Note that this tool is _not_: TODO: Add issues for all of these. 1. Single event blocks for files should be specifiable directly, without an event block name. -2. Time format should be specifiable at the file or global level, like patient ID. +2. Time format should be specifiable at the file or global level, like subject ID. diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index ee4e9d7..8ac66ac 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -97,7 +97,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy """Extracts a single event dataframe from the raw data. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfg` configuration dictionary. event_cfg: A dictionary containing the configuration for the event. This must contain two critical keys (`"code"` and `"time"`) and may contain additional keys for other columns to include @@ -128,7 +128,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy A DataFrame containing the event data extracted from the raw data, containing only unique rows across all columns. If the raw data has no duplicates when considering the event column space, the output dataframe will have the same number of rows as the raw data and be in the same order. The output - dataframe will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. If the + dataframe will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If the event has additional columns, they will be included in the output dataframe as well. **_Events that would be extracted with a null code or a time that should be specified via a column with or without a formatting option but in practice is null will be dropped._** Note that this dropping logic @@ -145,7 +145,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> _ = pl.Config.set_tbl_rows(20) >>> _ = pl.Config.set_tbl_cols(20) >>> raw_data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "C", "D"], ... "code_modifier": ["1", "2", "3", "4"], ... 
"time": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04"], @@ -160,7 +160,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(raw_data, event_cfg) shape: (4, 4) ┌────────────┬───────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪═══════════╪═════════════════════╪═══════════════╡ @@ -170,7 +170,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy │ 2 ┆ FOO//D//4 ┆ 2021-01-04 00:00:00 ┆ 4 │ └────────────┴───────────┴─────────────────────┴───────────────┘ >>> data_with_nulls = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", None, "C", "D"], ... "code_modifier": ["1", "2", "3", None], ... "time": [None, "2021-01-02", "2021-01-03", "2021-01-04"], @@ -185,7 +185,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(data_with_nulls, event_cfg) shape: (2, 4) ┌────────────┬────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪════════╪═════════════════════╪═══════════════╡ @@ -195,7 +195,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -227,7 +227,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... "admission_type": pl.Utf8, @@ -263,7 +263,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> complex_raw_data shape: (6, 9) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬──────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ str ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪══════════════════╪════════════════╪════════════╪═══════════╡ @@ -277,7 +277,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_admission_event_cfg) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -294,7 +294,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... 
) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -311,7 +311,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... ) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -325,7 +325,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_discharge_event_cfg) shape: (6, 5) ┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐ - │ patient_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ + │ subject_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ str │ ╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡ @@ -339,7 +339,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_death_event_cfg) shape: (3, 3) ┌────────────┬───────┬─────────────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪═══════╪═════════════════════╡ @@ -351,7 +351,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_static_event_cfg) shape: (3, 3) ┌────────────┬──────────────────┬──────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪══════════════════╪══════════════╡ @@ -371,10 +371,10 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy Traceback (most recent call last): ... ValueError: Invalid time literal: 12-01-23 - >>> extract_event(complex_raw_data, {"code": "test", "time": None, "patient_id": 3}) + >>> extract_event(complex_raw_data, {"code": "test", "time": None, "subject_id": 3}) Traceback (most recent call last): ... - KeyError: "Event column name 'patient_id' cannot be overridden." + KeyError: "Event column name 'subject_id' cannot be overridden." >>> extract_event(complex_raw_data, {"code": "test", "time": None, "foobar": "fuzz"}) Traceback (most recent call last): ... @@ -389,7 +389,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ValueError: Source column 'discharge_time' for event column foobar is not numeric, string, or categorical! Cannot be used as an event col. """ # noqa: E501 event_cfg = copy.deepcopy(event_cfg) - event_exprs = {"patient_id": pl.col("patient_id")} + event_exprs = {"subject_id": pl.col("subject_id")} if "code" not in event_cfg: raise KeyError( @@ -401,8 +401,8 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy "Event configuration dictionary must contain 'time' key. " f"Got: [{', '.join(event_cfg.keys())}]." 
) - if "patient_id" in event_cfg: - raise KeyError("Event column name 'patient_id' cannot be overridden.") + if "subject_id" in event_cfg: + raise KeyError("Event column name 'subject_id' cannot be overridden.") code_expr, code_null_filter_expr, needed_cols = get_code_expr(event_cfg.pop("code")) @@ -502,7 +502,7 @@ def convert_to_events( """Converts a DataFrame of raw data into a DataFrame of events. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfgs` configuration dictionary. For the precise mechanism of column determination, see the `extract_event` function. event_cfgs: A dictionary containing the configurations for the events to extract. The keys of this @@ -518,7 +518,7 @@ def convert_to_events( events extracted from the raw data, with the rows from each event DataFrame concatenated together. After concatenation, this dataframe will not be deduplicated, so if the raw data results in duplicates across events of different name, these will be preserved in the output DataFrame. - The output DataFrame will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. + The output DataFrame will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If any events have additional columns, these will be included in the output DataFrame as well. All columns across all event configurations will be included in the output DataFrame, with `null` values filled in for events that do not have a particular column. @@ -533,7 +533,7 @@ def convert_to_events( >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -564,7 +564,7 @@ def convert_to_events( ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... 
"admission_type": pl.Utf8, @@ -602,7 +602,7 @@ def convert_to_events( >>> complex_raw_data shape: (6, 8) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪════════════════╪════════════╪═══════════╡ @@ -616,7 +616,7 @@ def convert_to_events( >>> convert_to_events(complex_raw_data, event_cfgs) shape: (18, 7) ┌────────────┬───────────┬─────────────────────┬────────────────┬───────────────────────┬────────────────────┬───────────┐ - │ patient_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ + │ subject_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ f64 ┆ cat ┆ cat │ ╞════════════╪═══════════╪═════════════════════╪════════════════╪═══════════════════════╪════════════════════╪═══════════╡ @@ -666,7 +666,7 @@ def convert_to_events( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Converts the event-sharded raw data into MEDS events and storing them in patient subsharded flat files. + """Converts the event-sharded raw data into MEDS events and storing them in subject subsharded flat files. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -680,7 +680,7 @@ def main(cfg: DictConfig): file. 
""" - input_dir, patient_subsharded_dir, metadata_input_dir = stage_init(cfg) + input_dir, subject_subsharded_dir, metadata_input_dir = stage_init(cfg) shards = json.loads(Path(cfg.shards_map_fp).read_text()) @@ -694,13 +694,13 @@ def main(cfg: DictConfig): event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") - patient_subsharded_dir.mkdir(parents=True, exist_ok=True) - OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") + subject_subsharded_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(event_conversion_cfg, subject_subsharded_dir / "event_conversion_config.yaml") - patient_splits = list(shards.items()) - random.shuffle(patient_splits) + subject_splits = list(shards.items()) + random.shuffle(subject_splits) event_configs = list(event_conversion_cfg.items()) random.shuffle(event_configs) @@ -708,28 +708,28 @@ def main(cfg: DictConfig): # Here, we'll be reading files directly, so we'll turn off globbing read_fn = partial(pl.scan_parquet, glob=False) - for sp, patients in patient_splits: + for sp, subjects in subject_splits: for input_prefix, event_cfgs in event_configs: event_cfgs = copy.deepcopy(event_cfgs) - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) event_shards = list((input_dir / input_prefix).glob("*.parquet")) random.shuffle(event_shards) for shard_fp in event_shards: - out_fp = patient_subsharded_dir / sp / input_prefix / shard_fp.name + out_fp = subject_subsharded_dir / sp / input_prefix / shard_fp.name logger.info(f"Converting {shard_fp} to events and saving to {out_fp}") def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: - typed_patients = pl.Series(patients, dtype=df.schema[input_patient_id_column]) + typed_subjects = pl.Series(subjects, dtype=df.schema[input_subject_id_column]) - if input_patient_id_column != "patient_id": - df = df.rename({input_patient_id_column: "patient_id"}) + if input_subject_id_column != "subject_id": + df = df.rename({input_subject_id_column: "subject_id"}) try: logger.info(f"Extracting events for {input_prefix}/{shard_fp.name}") return convert_to_events( - df.filter(pl.col("patient_id").is_in(typed_patients)), + df.filter(pl.col("subject_id").is_in(typed_subjects)), event_cfgs=copy.deepcopy(event_cfgs), ) except Exception as e: @@ -744,5 +744,5 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: logger.info("Subsharded into converted events.") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index e9133eb..31d883c 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -250,9 +250,9 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> Examples: >>> event_configs = { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "icu/procedureevents": { - ... "patient_id_col": "subject_id", + ... "subject_id_col": "subject_id", ... "start": { ... "code": ["PROCEDURE", "START", "col(itemid)"], ... 
"_metadata": { @@ -304,11 +304,11 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> out = {} for file_pfx, event_cfgs_for_pfx in event_configs.items(): - if file_pfx == "patient_id_col": + if file_pfx == "subject_id_col": continue for event_key, event_cfg in event_cfgs_for_pfx.items(): - if event_key == "patient_id_col": + if event_key == "subject_id_col": continue for metadata_pfx, metadata_cfg in event_cfg.get("_metadata", {}).items(): @@ -449,5 +449,5 @@ def reducer_fn(*dfs): logger.info(f"Finished reduction in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_data.py b/src/MEDS_transforms/extract/finalize_MEDS_data.py index f9d6873..e8a2b6a 100644 --- a/src/MEDS_transforms/extract/finalize_MEDS_data.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_data.py @@ -38,7 +38,7 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'time' column of type Datetime(time_unit='us', time_zone=None). MEDS Data DataFrame must have a 'code' column of type String. @@ -46,28 +46,28 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), {}) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'code' column of type String. >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": pl.Series([1, 2], dtype=pl.UInt32), + ... "subject_id": pl.Series([1, 2], dtype=pl.UInt32), ... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2)], ... "code": ["A", "B"], "text_value": ["1", None], "numeric_value": [None, 34.2] ... }) >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data 'patient_id' column must be of type Int64. Got UInt32. + ValueError: MEDS Data 'subject_id' column must be of type Int64. Got UInt32. MEDS Data 'numeric_value' column must be of type Float32. Got Float64. >>> get_and_validate_data_schema(df.lazy(), {}) pyarrow.Table - patient_id: int64 + subject_id: int64 time: timestamp[us] code: string numeric_value: float text_value: large_string ---- - patient_id: [[1,2]] + subject_id: [[1,2]] time: [[2021-01-01 00:00:00.000000,2021-01-02 00:00:00.000000]] code: [["A","B"]] numeric_value: [[null,34.2]] @@ -111,7 +111,7 @@ def main(cfg: DictConfig): """Writes out schema compliant MEDS data files for the extracted dataset. 
In particular, this script ensures that all shard files are MEDS compliant with the mandatory columns - - `patient_id` (Int64) + - `subject_id` (Int64) - `time` (DateTime) - `code` (String) - `numeric_value` (Float32) @@ -134,5 +134,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=get_and_validate_data_schema, write_fn=pq.write_table) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 366d89a..6554930 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -15,7 +15,8 @@ code_metadata_schema, dataset_metadata_schema, held_out_split, - patient_split_schema, + subject_id_field, + subject_split_schema, train_split, tuning_split, ) @@ -121,8 +122,8 @@ def main(cfg: DictConfig): - `etl_name` (string) - `etl_version` (string) - `meds_version` (string) - (3) a `metadata/patient_splits.parquet` file exists that has the mandatory columns - - `patient_id` (Int64) + (3) a `metadata/subject_splits.parquet` file exists that has the mandatory columns + - `subject_id` (Int64) - `split` (string) This stage *_should almost always be the last metadata stage in an extraction pipeline._* @@ -151,9 +152,9 @@ def main(cfg: DictConfig): output_code_metadata_fp = output_metadata_dir / "codes.parquet" dataset_metadata_fp = output_metadata_dir / "dataset.json" - patient_splits_fp = output_metadata_dir / "patient_splits.parquet" + subject_splits_fp = output_metadata_dir / "subject_splits.parquet" - for out_fp in [output_code_metadata_fp, dataset_metadata_fp, patient_splits_fp]: + for out_fp in [output_code_metadata_fp, dataset_metadata_fp, subject_splits_fp]: out_fp.parent.mkdir(parents=True, exist_ok=True) if out_fp.exists() and cfg.do_overwrite: out_fp.unlink() @@ -194,29 +195,29 @@ def main(cfg: DictConfig): # Split creation shards_map_fp = Path(cfg.shards_map_fp) - logger.info("Creating patient splits from {str(shards_map_fp.resolve())}") + logger.info("Creating subject splits from {str(shards_map_fp.resolve())}") shards_map = json.loads(shards_map_fp.read_text()) - patient_splits = [] + subject_splits = [] seen_splits = {train_split: 0, tuning_split: 0, held_out_split: 0} - for shard, patient_ids in shards_map.items(): + for shard, subject_ids in shards_map.items(): split = "/".join(shard.split("/")[:-1]) if split not in seen_splits: seen_splits[split] = 0 - seen_splits[split] += len(patient_ids) + seen_splits[split] += len(subject_ids) - patient_splits.extend([{"patient_id": pid, "split": split} for pid in patient_ids]) + subject_splits.extend([{subject_id_field: pid, "split": split} for pid in subject_ids]) for split, cnt in seen_splits.items(): if cnt: - logger.info(f"Split {split} has {cnt} patients") + logger.info(f"Split {split} has {cnt} subjects") else: logger.warning(f"Split {split} not found in shards map") - patient_splits_tbl = pa.Table.from_pylist(patient_splits, schema=patient_split_schema) - logger.info(f"Writing finalized patient splits to {str(patient_splits_fp.resolve())}") - pq.write_table(patient_splits_tbl, patient_splits_fp) + subject_splits_tbl = pa.Table.from_pylist(subject_splits, schema=subject_split_schema) + logger.info(f"Writing finalized subject splits to {str(subject_splits_fp.resolve())}") + pq.write_table(subject_splits_tbl, subject_splits_fp) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git 
a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 49a45e1..2b75eb0 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -30,8 +30,8 @@ def merge_subdirs_and_sort( warning is logged, but an error is *not* raised. Which rows are retained if the uniqeu-by columns are not all columns is not guaranteed, but is also *not* random, so this may have statistical implications. - additional_sort_by: Additional columns to sort by, in addition to the default sorting by patient ID - and time. If `None`, only patient ID and time are used. If a list of strings, these + additional_sort_by: Additional columns to sort by, in addition to the default sorting by subject ID + and time. If `None`, only subject ID and time are used. If a list of strings, these columns are used in addition to the default sorting. If a column is not found in the dataframe, it is omitted from the sort-by, a warning is logged, but an error is *not* raised. This functionality is useful both for deterministic testing and in cases where a data owner wants to impose @@ -41,7 +41,7 @@ def merge_subdirs_and_sort( A single dataframe containing all the data from the parquet files in the subdirs of `sp_dir`. These files will be concatenated diagonally, taking the union of all rows in all dataframes and all unique columns in all dataframes to form the merged output. The returned dataframe will be made unique by the - columns specified in `unique_by` and sorted by first patient ID, then time, then all columns in + columns specified in `unique_by` and sorted by first subject ID, then time, then all columns in `additional_sort_by`, if any. Raises: @@ -50,15 +50,15 @@ def merge_subdirs_and_sort( Examples: >>> from tempfile import TemporaryDirectory - >>> df1 = pl.DataFrame({"patient_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) + >>> df1 = pl.DataFrame({"subject_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) >>> df2 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 1, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [None, 2.0, None], ... }) >>> df3 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 2, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [6.2, 2.0, None], @@ -84,7 +84,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (8, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -112,7 +112,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (7, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -131,18 +131,18 @@ def merge_subdirs_and_sort( ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") ... (sp_dir / "subdir2").mkdir() ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") - ... # We just display the patient ID, time, and code columns as the numeric value column + ... # We just display the subject ID, time, and code columns as the numeric value column ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to ... # the unique-by constraint. ... 
merge_subdirs_and_sort( ... sp_dir, ... event_subsets=["subdir1", "subdir2"], - ... unique_by=["patient_id", "time", "code"], + ... unique_by=["subject_id", "time", "code"], ... additional_sort_by=["code", "numeric_value"] - ... ).select("patient_id", "time", "code").collect() + ... ).select("subject_id", "time", "code").collect() shape: (6, 3) ┌────────────┬──────┬──────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪══════╡ @@ -188,7 +188,7 @@ def merge_subdirs_and_sort( case _: raise ValueError(f"Invalid unique_by value: {unique_by}") - sort_by = ["patient_id", "time"] + sort_by = ["subject_id", "time"] if additional_sort_by is not None: for s in additional_sort_by: if s in df_columns: @@ -201,12 +201,12 @@ def merge_subdirs_and_sort( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Merges the patient sub-sharded events into a single parquet file per patient shard. + """Merges the subject sub-sharded events into a single parquet file per subject shard. This function takes all dataframes (in parquet files) in any subdirs of the `cfg.stage_cfg.input_dir` and merges them into a single dataframe. All dataframes in the subdirs are assumed to be in the unnested, MEDS - format, and cover the same group of patients (specific to the shard being processed). The merged dataframe - will also be sorted by patient ID and time. + format, and cover the same group of subjects (specific to the shard being processed). The merged dataframe + will also be sorted by subject ID and time. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -219,14 +219,14 @@ def main(cfg: DictConfig): stage_configs.merge_to_MEDS_cohort.unique_by: The list of columns that should be ensured to be unique after the dataframes are merged. Defaults to `"*"`, which means all columns are used. stage_configs.merge_to_MEDS_cohort.additional_sort_by: Additional columns to sort by, in addition to - the default sorting by patient ID and time. Defaults to `None`, which means only patient ID + the default sorting by subject ID and time. Defaults to `None`, which means only subject ID and time are used. Returns: Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`. """ event_conversion_cfg = OmegaConf.load(cfg.event_conversion_config_fp) - event_conversion_cfg.pop("patient_id_col", None) + event_conversion_cfg.pop("subject_id_col", None) read_fn = partial( merge_subdirs_and_sort, @@ -242,5 +242,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/shard_events.py b/src/MEDS_transforms/extract/shard_events.py index 18450bb..cd8e474 100755 --- a/src/MEDS_transforms/extract/shard_events.py +++ b/src/MEDS_transforms/extract/shard_events.py @@ -11,6 +11,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, OmegaConf from MEDS_transforms.extract import CONFIG_YAML @@ -169,7 +170,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: event conversion configurations that are specific to each file based on its stem (filename without the extension). 
It compiles a list of column names needed for each file from the configuration, which includes both general - columns like row index and patient ID, as well as specific columns defined + columns like row index and subject ID, as well as specific columns defined for medical events and times formatted in a special 'col(column_name)' syntax. Args: @@ -185,7 +186,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: Examples: >>> cfg = DictConfig({ - ... "patient_id_col": "patient_id_global", + ... "subject_id_col": "subject_id_global", ... "hosp/patients": { ... "eye_color": { ... "code": ["EYE_COLOR", "col(eye_color)"], "time": None, "mod": "mod_col" @@ -195,7 +196,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }, ... "icu/chartevents": { - ... "patient_id_col": "patient_id_icu", + ... "subject_id_col": "subject_id_icu", ... "heart_rate": { ... "code": "HEART_RATE", "time": "charttime", "numeric_value": "HR" ... }, @@ -212,19 +213,19 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }) >>> retrieve_columns(cfg) # doctest: +NORMALIZE_WHITESPACE - {'hosp/patients': ['eye_color', 'height', 'mod_col', 'patient_id_global'], - 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'patient_id_icu', 'value', 'valuenum', + {'hosp/patients': ['eye_color', 'height', 'mod_col', 'subject_id_global'], + 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'subject_id_icu', 'value', 'valuenum', 'valueuom'], - 'icu/meds': ['medication', 'medtime', 'patient_id_global']} + 'icu/meds': ['medication', 'medtime', 'subject_id_global']} >>> cfg = DictConfig({ ... "subjects": { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "eye_color": {"code": ["col(eye_color)"], "time": None}, ... }, ... "labs": {"lab": {"code": "col(labtest)", "time": "charttime"}}, ... }) >>> retrieve_columns(cfg) - {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'patient_id']} + {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'subject_id']} """ event_conversion_cfg = copy.deepcopy(event_conversion_cfg) @@ -232,11 +233,11 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: # Initialize a dictionary to store file paths as keys and lists of column names as values. prefix_to_columns = {} - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", subject_id_field) for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) - prefix_to_columns[input_prefix] = {input_patient_id_column} + prefix_to_columns[input_prefix] = {input_subject_id_column} for event_cfg in event_cfgs.values(): # If the config has a 'code' key and it contains column fields, parse and add them. 
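The hunk above also changes how the default subject-ID column is resolved: a table-level `subject_id_col` still wins, and the file-wide default now falls back to the MEDS `subject_id` field rather than a hard-coded string. A minimal, self-contained sketch of that precedence rule follows; it is not the package's `retrieve_columns` helper, plain dicts stand in for the OmegaConf config object, and the table and column names are illustrative.

```python
SUBJECT_ID_FIELD = "subject_id"  # stands in for meds.subject_id_field

def resolve_subject_id_cols(event_conversion_cfg: dict) -> dict[str, str]:
    # Shallow copy so popping the file-wide key does not mutate the caller's config.
    cfg = dict(event_conversion_cfg)
    default_col = cfg.pop("subject_id_col", SUBJECT_ID_FIELD)
    # A table-level override, when present, takes precedence over the file-wide default.
    return {prefix: table_cfg.get("subject_id_col", default_col) for prefix, table_cfg in cfg.items()}

example_cfg = {
    "subjects": {
        "subject_id_col": "MRN",  # table-level override
        "eye_color": {"code": ["col(eye_color)"], "time": None},
    },
    "labs": {"lab": {"code": "col(labtest)", "time": "charttime"}},
}
print(resolve_subject_id_cols(example_cfg))  # {'subjects': 'MRN', 'labs': 'subject_id'}
```

This mirrors the second `retrieve_columns` doctest above, which lists `MRN` for `subjects` but falls back to `subject_id` for `labs`.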
@@ -429,5 +430,5 @@ def main(cfg: DictConfig): logger.info(f"Sub-sharding completed in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_subjects.py similarity index 71% rename from src/MEDS_transforms/extract/split_and_shard_patients.py rename to src/MEDS_transforms/extract/split_and_shard_subjects.py index a385c73..0dc3342 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_subjects.py @@ -14,75 +14,75 @@ from MEDS_transforms.utils import stage_init -def shard_patients( - patients: np.ndarray, - n_patients_per_shard: int = 50000, +def shard_subjects( + subjects: np.ndarray, + n_subjects_per_shard: int = 50000, external_splits: dict[str, Sequence[int]] | None = None, split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, seed: int = 1, ) -> dict[str, list[int]]: - """Shard a list of patients, nested within train/tuning/held-out splits. + """Shard a list of subjects, nested within train/tuning/held-out splits. - This function takes a list of patients and shards them into train/tuning/held-out splits, with the shards + This function takes a list of subjects and shards them into train/tuning/held-out splits, with the shards of a consistent size, nested within the splits. The function will also respect external splits, if provided, such that mandated splits (such as prospective held out sets or pre-existing, task-specific held out sets) are with-held and sharded as separate splits from the IID splits defined by `split_fracs_dict`. It returns a dictionary mapping the split and shard names (realized as f"{split}/{shard}") to the list of - patients in that shard. + subjects in that shard. Args: - patients: The list of patients to shard. - n_patients_per_shard: The maximum number of patients to include in each shard. + subjects: The list of subjects to shard. + n_subjects_per_shard: The maximum number of subjects to include in each shard. external_splits: The externally defined splits to respect. If provided, the keys of this dictionary - will be used as split names, and the values as the list of patients in that split. These + will be used as split names, and the values as the list of subjects in that split. These pre-defined splits will be excluded from IID splits generated by this function, but will be sharded like normal. Note that this is largely only appropriate for held-out sets for pre-defined - tasks or test cases (e.g., prospective tests); training patients should often still be included in + tasks or test cases (e.g., prospective tests); training subjects should often still be included in the IID splits to maximize the amount of data that can be used for training. - split_fracs_dict: A dictionary mapping the split name to the fraction of patients to include in that + split_fracs_dict: A dictionary mapping the split name to the fraction of subjects to include in that split. Defaults to 80% train, 10% tuning, 10% held-out. This can be None or empty only when external splits fully specify the population. - seed: The random seed to use for shuffling the patients before seeding and sharding. This is useful + seed: The random seed to use for shuffling the subjects before seeding and sharding. This is useful for ensuring reproducibility. 
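    To make the interaction between the split fractions and `n_subjects_per_shard` concrete, the short
    standalone sketch below mirrors the per-split shard-count arithmetic used in the body of this function
    (ceil division followed by `np.array_split`); it is an illustration with made-up sizes, not a call into
    this function.

    ```python
    import numpy as np

    def expected_shard_sizes(n_subjects_in_split: int, n_subjects_per_shard: int) -> list[int]:
        # A split that already fits within the cap stays in a single shard.
        if n_subjects_in_split <= n_subjects_per_shard:
            return [n_subjects_in_split]
        n_shards = int(np.ceil(n_subjects_in_split / n_subjects_per_shard))
        # np.array_split spreads the remainder so shard sizes differ by at most one.
        return [len(chunk) for chunk in np.array_split(np.arange(n_subjects_in_split), n_shards)]

    print(expected_shard_sizes(10, 3))   # [3, 3, 2, 2]
    print(expected_shard_sizes(40, 50))  # [40]
    ```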
Returns: - A dictionary mapping f"{split}/{shard}" to the list of patients in that shard. This may include - overlapping patients across a subset of these splits, but never across shards within a split. Any + A dictionary mapping f"{split}/{shard}" to the list of subjects in that shard. This may include + overlapping subjects across a subset of these splits, but never across shards within a split. Any overlap will solely occur between the an external split and another external split. Raises: ValueError: If the sum of the split fractions in `split_fracs_dict` is not equal to 1. Examples: - >>> patients = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) - >>> shard_patients(patients, n_patients_per_shard=3) + >>> subjects = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) + >>> shard_subjects(subjects, n_subjects_per_shard=3) {'train/0': [9, 4, 8], 'train/1': [2, 1, 10], 'train/2': [6, 5], 'tuning/0': [3], 'held_out/0': [7]} >>> external_splits = { ... 'taskA/held_out': np.array([8, 9, 10], dtype=int), ... 'taskB/held_out': np.array([10, 8, 9], dtype=int), ... } - >>> shard_patients(patients, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE + >>> shard_subjects(subjects, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE {'train/0': [5, 7, 4], 'train/1': [1, 2], 'tuning/0': [3], 'held_out/0': [6], 'taskA/held_out/0': [8, 9, 10], 'taskB/held_out/0': [10, 8, 9]} - >>> shard_patients(patients, n_patients_per_shard=3, split_fracs_dict={'train': 0.5}) + >>> shard_subjects(subjects, n_subjects_per_shard=3, split_fracs_dict={'train': 0.5}) Traceback (most recent call last): ... ValueError: The sum of the split fractions must be equal to 1. Got 0.5 through {'train': 0.5}. - >>> shard_patients([1, 2], n_patients_per_shard=3) + >>> shard_subjects([1, 2], n_subjects_per_shard=3) Traceback (most recent call last): ... - ValueError: Unable to adjust splits to ensure all splits have at least 1 patient. + ValueError: Unable to adjust splits to ensure all splits have at least 1 subject. >>> external_splits = { ... 'train': np.array([1, 2, 3, 4, 5, 6], dtype=int), ... 'test': np.array([7, 8, 9, 10], dtype=int), ... } - >>> shard_patients(patients, 6, external_splits, split_fracs_dict=None) + >>> shard_subjects(subjects, 6, external_splits, split_fracs_dict=None) {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} - >>> shard_patients(patients, 3, external_splits) + >>> shard_subjects(subjects, 3, external_splits) {'train/0': [5, 1, 3], 'train/1': [2, 6, 4], 'test/0': [10, 7], 'test/1': [8, 9]} """ @@ -93,23 +93,23 @@ def shard_patients( if not isinstance(external_splits[k], np.ndarray): logger.warning( f"External split {k} is not a numpy array and thus type safety is not guaranteed. " - f"Attempting to convert to numpy array of dtype {patients.dtype}." + f"Attempting to convert to numpy array of dtype {subjects.dtype}." 
) - external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) + external_splits[k] = np.array(external_splits[k], dtype=subjects.dtype) - patients = np.unique(patients) + subjects = np.unique(subjects) # Splitting all_external_splits = set().union(*external_splits.values()) - is_in_external_split = np.isin(patients, list(all_external_splits)) - patient_ids_to_split = patients[~is_in_external_split] + is_in_external_split = np.isin(subjects, list(all_external_splits)) + subject_ids_to_split = subjects[~is_in_external_split] splits = external_splits splits_cover = sum(split_fracs_dict.values()) if split_fracs_dict else 0 rng = np.random.default_rng(seed) - if n_patients := len(patient_ids_to_split): + if n_subjects := len(subject_ids_to_split): if not math.isclose(splits_cover, 1): raise ValueError( f"The sum of the split fractions must be equal to 1. Got {splits_cover} " @@ -118,42 +118,42 @@ def shard_patients( split_names_idx = rng.permutation(len(split_fracs_dict)) split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] split_fracs = np.array([split_fracs_dict[k] for k in split_names]) - split_lens = np.round(split_fracs[:-1] * n_patients).astype(int) - split_lens = np.append(split_lens, n_patients - split_lens.sum()) + split_lens = np.round(split_fracs[:-1] * n_subjects).astype(int) + split_lens = np.append(split_lens, n_subjects - split_lens.sum()) if split_lens.min() == 0: logger.warning( - "Some splits are empty. Adjusting splits to ensure all splits have at least 1 patient." + "Some splits are empty. Adjusting splits to ensure all splits have at least 1 subject." ) max_split = split_lens.argmax() split_lens[max_split] -= 1 split_lens[split_lens.argmin()] += 1 if split_lens.min() == 0: - raise ValueError("Unable to adjust splits to ensure all splits have at least 1 patient.") + raise ValueError("Unable to adjust splits to ensure all splits have at least 1 subject.") - patients = rng.permutation(patient_ids_to_split) - patients_per_split = np.split(patients, split_lens.cumsum()) + subjects = rng.permutation(subject_ids_to_split) + subjects_per_split = np.split(subjects, split_lens.cumsum()) - splits = {**{k: v for k, v in zip(split_names, patients_per_split)}, **splits} + splits = {**{k: v for k, v in zip(split_names, subjects_per_split)}, **splits} else: if split_fracs_dict: logger.warning( - "External splits were provided covering all patients, but split_fracs_dict was not empty. " + "External splits were provided covering all subjects, but split_fracs_dict was not empty. " "Ignoring the split_fracs_dict." 
            )
         else:
-            logger.info("External splits were provided covering all patients.")
+            logger.info("External splits were provided covering all subjects.")

     # Sharding
     final_shards = {}
     for sp, pts in splits.items():
-        if len(pts) <= n_patients_per_shard:
+        if len(pts) <= n_subjects_per_shard:
             final_shards[f"{sp}/0"] = pts.tolist()
         else:
             pts = rng.permutation(pts)
             n_pts = len(pts)
-            n_shards = int(np.ceil(n_pts / n_patients_per_shard))
+            n_shards = int(np.ceil(n_pts / n_subjects_per_shard))
             shards = np.array_split(pts, n_shards)
             for i, shard in enumerate(shards):
                 final_shards[f"{sp}/{i}"] = shard.tolist()
@@ -161,12 +161,12 @@ def shard_patients(
     seen = {}
     for k, pts in final_shards.items():
-        logger.info(f"Split {k} has {len(pts)} patients.")
+        logger.info(f"Split {k} has {len(pts)} subjects.")
         for kk, v in seen.items():
             shared = set(pts).intersection(v)
             if shared:
-                logger.info(f" - intersects {kk} on {len(shared)} patients.")
+                logger.info(f" - intersects {kk} on {len(shared)} subjects.")
         seen[k] = set(pts)
@@ -175,9 +175,9 @@
 @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem)
 def main(cfg: DictConfig):
-    """Extracts the set of unique patients from the raw data and splits/shards them and saves the result.
+    """Extracts the set of unique subjects from the raw data and splits/shards them and saves the result.

-    This stage splits the patients into training, tuning, and held-out sets, and further splits those sets
+    This stage splits the subjects into training, tuning, and held-out sets, and further splits those sets
     into shards.

     All arguments are specified through the command line into the `cfg` object through Hydra.
@@ -185,19 +185,19 @@ def main(cfg: DictConfig):
     The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific
     configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten
     directly on the command line, but can be overwritten implicitly by overwriting components of the
-    `stage_configs.split_and_shard_patients` key.
+    `stage_configs.split_and_shard_subjects` key.

     Args:
-        stage_configs.split_and_shard_patients.n_patients_per_shard: The maximum number of patients to include
-            in any shard. Realized shards will not necessarily have this many patients, though they will never
-            exceed this number. Instead, the number of shards necessary to include all patients in a split
-            such that no shard exceeds this number will be calculated, then the patients will be evenly,
+        stage_configs.split_and_shard_subjects.n_subjects_per_shard: The maximum number of subjects to include
+            in any shard. Realized shards will not necessarily have this many subjects, though they will never
+            exceed this number. Instead, the number of shards necessary to include all subjects in a split
+            such that no shard exceeds this number will be calculated, then the subjects will be evenly,
             randomly split amongst those shards so that all shards within a split have approximately the same
             number of subjects.
-        stage_configs.split_and_shard_patients.external_splits_json_fp: The path to a json file containing any
+        stage_configs.split_and_shard_subjects.external_splits_json_fp: The path to a json file containing any
             pre-defined splits for specialty held-out test sets beyond the IID held out set that will be
             produced (e.g., for prospective datasets, etc.).
- stage_configs.split_and_shard_patients.split_fracs: The fraction of patients to include in the IID + stage_configs.split_and_shard_subjects.split_fracs: The fraction of subjects to include in the IID training, tuning, and held-out sets. Split fractions can be changed for the default names by adding a hydra-syntax command line argument for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1 split_fracs.held_out=0.2`. A split can be removed with the `~` override @@ -213,38 +213,38 @@ def main(cfg: DictConfig): raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") logger.info( - f"Reading event conversion config from {event_conversion_cfg_fp} (needed for patient ID columns)" + f"Reading event conversion config from {event_conversion_cfg_fp} (needed for subject ID columns)" ) event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") dfs = [] - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.get("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.get("subject_id_col", default_subject_id_col) input_fps = list((subsharded_dir / input_prefix).glob("**/*.parquet")) input_fps_strs = "\n".join(f" - {str(fp.resolve())}" for fp in input_fps) - logger.info(f"Reading patient IDs from {input_prefix} files:\n{input_fps_strs}") + logger.info(f"Reading subject IDs from {input_prefix} files:\n{input_fps_strs}") for input_fp in input_fps: dfs.append( pl.scan_parquet(input_fp, glob=False) - .select(pl.col(input_patient_id_column).alias("patient_id")) + .select(pl.col(input_subject_id_column).alias("subject_id")) .unique() ) - logger.info(f"Joining all patient IDs from {len(dfs)} dataframes") - patient_ids = ( + logger.info(f"Joining all subject IDs from {len(dfs)} dataframes") + subject_ids = ( pl.concat(dfs) - .select(pl.col("patient_id").drop_nulls().drop_nans().unique()) - .collect(streaming=True)["patient_id"] + .select(pl.col("subject_id").drop_nulls().drop_nans().unique()) + .collect(streaming=True)["subject_id"] .to_numpy(use_pyarrow=True) ) - logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}") + logger.info(f"Found {len(subject_ids)} unique subject IDs of type {subject_ids.dtype}") if cfg.stage_cfg.external_splits_json_fp: external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp) @@ -259,22 +259,22 @@ def main(cfg: DictConfig): else: external_splits = None - logger.info("Sharding and splitting patients") + logger.info("Sharding and splitting subjects") - sharded_patients = shard_patients( - patients=patient_ids, + sharded_subjects = shard_subjects( + subjects=subject_ids, external_splits=external_splits, split_fracs_dict=cfg.stage_cfg.split_fracs, - n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + n_subjects_per_shard=cfg.stage_cfg.n_subjects_per_shard, seed=cfg.seed, ) shards_map_fp = Path(cfg.shards_map_fp) - logger.info(f"Writing sharded patients to {str(shards_map_fp.resolve())}") + logger.info(f"Writing sharded subjects to {str(shards_map_fp.resolve())}") shards_map_fp.parent.mkdir(parents=True, exist_ok=True) - shards_map_fp.write_text(json.dumps(sharded_patients)) - logger.info("Done writing sharded patients") + 
shards_map_fp.write_text(json.dumps(sharded_subjects)) + logger.info("Done writing sharded subjects") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/README.md b/src/MEDS_transforms/filters/README.md index 9d582f0..22baa4b 100644 --- a/src/MEDS_transforms/filters/README.md +++ b/src/MEDS_transforms/filters/README.md @@ -1,5 +1,5 @@ # Filters -Filters remove wholesale events within the data, either at the patient or event level. For transformations +Filters remove wholesale events within the data, either at the subject or event level. For transformations that simply _occlude_ aspects of the data (e.g., by setting a code variable to `UNK`), see the `transforms` library section. diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 36a6938..979dc0a 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def filter_measurements_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Returns a function that filters patient events to only encompass those with a set of permissible codes. + """Returns a function that filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -26,44 +26,44 @@ def filter_measurements_fntr( >>> code_metadata_df = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_subjects": [2, 1, 3, 2], ... "code/n_occurrences": [4, 5, 3, 2], ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... 
}).lazy() - >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 1 ┆ B ┆ 1 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1, "min_occurrences_per_code": 4}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1, "min_occurrences_per_code": 4}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -72,12 +72,12 @@ def filter_measurements_fntr( │ 2 ┆ A ┆ 2 │ │ 2 ┆ C ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": None, "min_occurrences_per_code": None}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": None, "min_occurrences_per_code": None}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -91,20 +91,46 @@ def filter_measurements_fntr( >>> fn(data).collect() shape: (1, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ + + This stage works even if the default row index column exists: + >>> code_metadata_df = pl.DataFrame({ + ... "code": ["A", "A", "B", "C"], + ... "modifier1": [1, 2, 1, 2], + ... "code/n_subjects": [2, 1, 3, 2], + ... "code/n_occurrences": [4, 5, 3, 2], + ... }) + >>> data = pl.DataFrame({ + ... "subject_id": [1, 1, 2, 2], + ... "code": ["A", "B", "A", "C"], + ... "modifier1": [1, 1, 2, 2], + ... "_row_idx": [1, 1, 1, 1], + ... 
}).lazy() + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) + >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (2, 4) + ┌────────────┬──────┬───────────┬──────────┐ + │ subject_id ┆ code ┆ modifier1 ┆ _row_idx │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ i64 ┆ i64 │ + ╞════════════╪══════╪═══════════╪══════════╡ + │ 1 ┆ A ┆ 1 ┆ 1 │ + │ 1 ┆ B ┆ 1 ┆ 1 │ + └────────────┴──────┴───────────┴──────────┘ """ - min_patients_per_code = stage_cfg.get("min_patients_per_code", None) + min_subjects_per_code = stage_cfg.get("min_subjects_per_code", None) min_occurrences_per_code = stage_cfg.get("min_occurrences_per_code", None) filter_exprs = [] - if min_patients_per_code is not None: - filter_exprs.append(pl.col("code/n_patients") >= min_patients_per_code) + if min_subjects_per_code is not None: + filter_exprs.append(pl.col("code/n_subjects") >= min_subjects_per_code) if min_occurrences_per_code is not None: filter_exprs.append(pl.col("code/n_occurrences") >= min_occurrences_per_code) @@ -118,10 +144,10 @@ def filter_measurements_fntr( allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy() def filter_measurements_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters patient events to only encompass those with a set of permissible codes. + f"""Filters subject events to only encompass those with a set of permissible codes. In particular, this function filters the DataFrame to only include (code, modifier) pairs that have - at least {min_patients_per_code} patients and {min_occurrences_per_code} occurrences. + at least {min_subjects_per_code} subjects and {min_occurrences_per_code} occurrences. """ idx_col = "_row_idx" @@ -147,5 +173,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=filter_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_subjects.py similarity index 56% rename from src/MEDS_transforms/filters/filter_patients.py rename to src/MEDS_transforms/filters/filter_subjects.py index 36dc398..88d1fe6 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_subjects.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable from functools import partial @@ -12,25 +12,25 @@ from MEDS_transforms.mapreduce.mapper import map_over -def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of measurements they have. +def filter_subjects_by_num_measurements(df: pl.LazyFrame, min_measurements_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of dynamic (timestamp non-null) measurements they have. Args: df: The input DataFrame. - min_measurements_per_patient: The minimum number of measurements a patient must have to be included. + min_measurements_per_subject: The minimum number of measurements a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 3], - ... "time": [1, 2, 1, 1, 2, 1], + ... "subject_id": [1, 1, 1, 2, 2, 3, 3, 4], + ... 
"time": [1, 2, 1, 1, 2, 1, None, None], ... }) - >>> filter_patients_by_num_measurements(df, 1) - shape: (6, 2) + >>> filter_subjects_by_num_measurements(df, 1) + shape: (7, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -40,11 +40,12 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ │ 3 ┆ 1 │ + │ 3 ┆ null │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2) + >>> filter_subjects_by_num_measurements(df, 2) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -54,10 +55,10 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 3) + >>> filter_subjects_by_num_measurements(df, 3) shape: (3, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -65,47 +66,48 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 1 ┆ 2 │ │ 1 ┆ 1 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 4) + >>> filter_subjects_by_num_measurements(df, 4) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2.2) + >>> filter_subjects_by_num_measurements(df, 2.2) Traceback (most recent call last): ... - TypeError: min_measurements_per_patient must be an integer; got 2.2 + TypeError: min_measurements_per_subject must be an integer; got 2.2 """ - if not isinstance(min_measurements_per_patient, int): + if not isinstance(min_measurements_per_subject, int): raise TypeError( - f"min_measurements_per_patient must be an integer; got {type(min_measurements_per_patient)} " - f"{min_measurements_per_patient}" + f"min_measurements_per_subject must be an integer; got {type(min_measurements_per_subject)} " + f"{min_measurements_per_subject}" ) - return df.filter(pl.col("time").count().over("patient_id") >= min_measurements_per_patient) + return df.filter(pl.col("time").count().over("subject_id") >= min_measurements_per_subject) -def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of events (unique timepoints) they have. +def filter_subjects_by_num_events(df: pl.LazyFrame, min_events_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of events (unique timepoints) they have. Args: df: The input DataFrame. - min_events_per_patient: The minimum number of events a patient must have to be included. + min_events_per_subject: The minimum number of events a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], ... "time": [1, 1, 1, 1, 2, 1, 1, 2, 3, None, None, 1, 2, 3], ... }) - >>> filter_patients_by_num_events(df, 1) + >>> with pl.Config(tbl_rows=15): + ... 
filter_subjects_by_num_events(df, 1) shape: (14, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -124,10 +126,11 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2) + >>> with pl.Config(tbl_rows=15): + ... filter_subjects_by_num_events(df, 2) shape: (11, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -143,10 +146,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 3) + >>> filter_subjects_by_num_events(df, 3) shape: (8, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -159,10 +162,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 4) + >>> filter_subjects_by_num_events(df, 4) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -172,48 +175,77 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 5) + >>> filter_subjects_by_num_events(df, 5) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2.2) + >>> filter_subjects_by_num_events(df, 2.2) Traceback (most recent call last): ... - TypeError: min_events_per_patient must be an integer; got 2.2 + TypeError: min_events_per_subject must be an integer; got 2.2 """ - if not isinstance(min_events_per_patient, int): + if not isinstance(min_events_per_subject, int): raise TypeError( - f"min_events_per_patient must be an integer; got {type(min_events_per_patient)} " - f"{min_events_per_patient}" + f"min_events_per_subject must be an integer; got {type(min_events_per_subject)} " + f"{min_events_per_subject}" ) - return df.filter(pl.col("time").n_unique().over("patient_id") >= min_events_per_patient) + return df.filter(pl.col("time").n_unique().over("subject_id") >= min_events_per_subject) -def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: +def filter_subjects_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Returns a function that filters subjects by the number of measurements and events they have. + + Args: + stage_cfg: The stage configuration. Arguments include: min_measurements_per_subject, + min_events_per_subject, both of which should be integers or None which specify the minimum number + of measurements and events a subject must have to be included, respectively. + + Returns: + The function that filters subjects by the number of measurements and/or events they have. + + Examples: + >>> df = pl.DataFrame({ + ... "subject_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + ... "time": [1, 1, 1, 1, 1, 1, 2, 3, None, None, 1, 2, 2, None, 1, 2, 3, 1], + ... 
}) + >>> stage_cfg = DictConfig({"min_measurements_per_subject": 4, "min_events_per_subject": 2}) + >>> filter_subjects_fntr(stage_cfg)(df) + shape: (4, 2) + ┌────────────┬──────┐ + │ subject_id ┆ time │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞════════════╪══════╡ + │ 5 ┆ 1 │ + │ 5 ┆ 2 │ + │ 5 ┆ 3 │ + │ 5 ┆ 1 │ + └────────────┴──────┘ + """ compute_fns = [] - if stage_cfg.min_measurements_per_patient: + if stage_cfg.min_measurements_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_measurements_per_patient} measurements " + f"Filtering subjects with fewer than {stage_cfg.min_measurements_per_subject} measurements " "(observations of any kind)." ) compute_fns.append( partial( - filter_patients_by_num_measurements, - min_measurements_per_patient=stage_cfg.min_measurements_per_patient, + filter_subjects_by_num_measurements, + min_measurements_per_subject=stage_cfg.min_measurements_per_subject, ) ) - if stage_cfg.min_events_per_patient: + if stage_cfg.min_events_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_events_per_patient} events " + f"Filtering subjects with fewer than {stage_cfg.min_events_per_subject} events " "(unique timepoints)." ) compute_fns.append( - partial(filter_patients_by_num_events, min_events_per_patient=stage_cfg.min_events_per_patient) + partial(filter_subjects_by_num_events, min_events_per_subject=stage_cfg.min_events_per_subject) ) def fn(data: pl.LazyFrame) -> pl.LazyFrame: @@ -230,8 +262,8 @@ def fn(data: pl.LazyFrame) -> pl.LazyFrame: def main(cfg: DictConfig): """TODO.""" - map_over(cfg, compute_fn=filter_patients_fntr) + map_over(cfg, compute_fn=filter_subjects_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/fit_vocabulary_indices.py b/src/MEDS_transforms/fit_vocabulary_indices.py index 0fb249b..b5e4327 100644 --- a/src/MEDS_transforms/fit_vocabulary_indices.py +++ b/src/MEDS_transforms/fit_vocabulary_indices.py @@ -236,5 +236,5 @@ def main(cfg: DictConfig): logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index cdc60e0..6cc44e8 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -11,6 +11,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, ListConfig from ..parser import is_matcher, matcher_to_expr @@ -336,7 +337,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO compute function with the ``local_arg_1=baz`` parameter. Both of these local compute functions will be applied to the input DataFrame in sequence, and the resulting DataFrames will be concatenated alongside any of the dataframe that matches no matcher (which will be left unmodified) and merged in a sorted way - that respects the ``patient_id``, ``time`` ordering first, then the order of the match & revise blocks + that respects the ``subject_id``, ``time`` ordering first, then the order of the match & revise blocks themselves, then the order of the rows in each match & revise block output. Each local compute function will also use the ``global_arg_1=foo`` parameter. @@ -358,7 +359,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2], + ... 
"subject_id": [1, 1, 1, 2, 2, 2], ... "time": [1, 2, 2, 1, 1, 2], ... "initial_idx": [0, 1, 2, 3, 4, 5], ... "code": ["FINAL", "CODE//TEMP_2", "CODE//TEMP_1", "FINAL", "CODE//TEMP_2", "CODE//TEMP_1"] @@ -380,7 +381,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪════════════════╡ @@ -403,7 +404,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬─────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪═════════════════╡ @@ -424,7 +425,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO ... ValueError: Missing needed columns {'missing'} for local matcher 0: [(col("missing")) == (String(CODE//TEMP_2))].all_horizontal() - Columns available: 'code', 'initial_idx', 'patient_id', 'time' + Columns available: 'code', 'initial_idx', 'subject_id', 'time' >>> stage_cfg = DictConfig({"global_code_end": "foo"}) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) @@ -479,7 +480,7 @@ def match_revise_fn(df: DF_T) -> DF_T: revision_parts.append(matchable_df.filter(pl.all_horizontal(final_part_filters))) else: revision_parts.append(matchable_df) - return pl.concat(revision_parts, how="vertical").sort(["patient_id", "time"], maintain_order=True) + return pl.concat(revision_parts, how="vertical").sort([subject_id_field, "time"], maintain_order=True) return match_revise_fn @@ -620,7 +621,7 @@ def map_over( start = datetime.now() train_only = cfg.stage_cfg.get("train_only", False) - split_fp = Path(cfg.input_dir) / "metadata" / "patient_split.parquet" + split_fp = Path(cfg.input_dir) / "metadata" / "subject_split.parquet" shards, includes_only_train = shard_iterator_fntr(cfg) @@ -631,18 +632,18 @@ def map_over( ) elif split_fp.exists(): logger.info(f"Processing train split only by filtering read dfs via {str(split_fp.resolve())}") - train_patients = ( + train_subjects = ( pl.scan_parquet(split_fp) .filter(pl.col("split") == "train") - .select(pl.col("patient_id")) + .select(subject_id_field) .collect() .to_list() ) - read_fn = read_and_filter_fntr(train_patients, read_fn) + read_fn = read_and_filter_fntr(train_subjects, read_fn) else: raise FileNotFoundError( f"Train split requested, but shard prefixes can't be used and " - f"patient split file not found at {str(split_fp.resolve())}." + f"subject split file not found at {str(split_fp.resolve())}." ) elif includes_only_train: raise ValueError("All splits should be used, but shard iterator is returning only train splits?!?") diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 300e4c3..a653d20 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -315,14 +315,14 @@ def shard_iterator( >>> from tempfile import TemporaryDirectory >>> import polars as pl >>> df = pl.DataFrame({ - ... "patient_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + ... "subject_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], ... 
"code": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], ... "time": [1, 2, 3, 4, 5, 6, 1, 2, 3], ... }) >>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]} >>> def write_dfs(input_dir: Path, df: pl.DataFrame=df, shards: dict=shards, sfx: str=".parquet"): - ... for shard_name, patient_ids in shards.items(): - ... df = df.filter(pl.col("patient_id").is_in(patient_ids)) + ... for shard_name, subject_ids in shards.items(): + ... df = df.filter(pl.col("subject_id").is_in(subject_ids)) ... shard_fp = input_dir / f"{shard_name}{sfx}" ... shard_fp.parent.mkdir(exist_ok=True, parents=True) ... if sfx == ".parquet": df.write_parquet(shard_fp) @@ -483,9 +483,9 @@ def shard_iterator( shards = train_shards includes_only_train = True elif train_only: - logger.info( + logger.warning( f"train_only={train_only} requested but no dedicated train shards found; processing all shards " - "and relying on `patient_splits.parquet` for filtering." + "and relying on `subject_splits.parquet` for filtering." ) shards = shuffle_shards(shards, cfg) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index f936196..0fc06fe 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -13,13 +13,37 @@ from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.extract.split_and_shard_patients import shard_patients +from MEDS_transforms.extract.split_and_shard_subjects import shard_subjects from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator, shuffle_shards from MEDS_transforms.utils import stage_init, write_lazyframe def valid_json_file(fp: Path) -> bool: - """Check if a file is a valid JSON file.""" + """Check if a file is a valid JSON file. + + Args: + fp: Path to the file. + + Returns: + True if the file is a valid JSON file, False otherwise. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text("foobar not a json file.\tHello, world!") + ... valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text('{"foo": "bar"}') + ... valid_json_file(fp) + True + """ if not fp.is_file(): return False try: @@ -30,13 +54,14 @@ def valid_json_file(fp: Path) -> bool: def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) -> dict[str, list[str]]: + """This function creates a new sharding scheme for the MEDS cohort.""" splits_map = defaultdict(list) for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) - return shard_patients( - patients=df["patient_id"].to_numpy(), - n_patients_per_shard=stage_cfg.n_patients_per_shard, + return shard_subjects( + subjects=df["subject_id"].to_numpy(), + n_subjects_per_shard=stage_cfg.n_subjects_per_shard, external_splits=splits_map, split_fracs_dict=None, seed=cfg.get("seed", 1), @@ -44,6 +69,20 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) def write_json(d: dict, fp: Path) -> None: + """Write a dictionary to a JSON file. + + Args: + d: Dictionary to write. + fp: Path to the file. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... 
write_json({"foo": "bar"}, fp) + ... fp.read_text() + '{"foo": "bar"}' + """ fp.write_text(json.dumps(d)) @@ -51,13 +90,13 @@ def write_json(d: dict, fp: Path) -> None: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" + """Re-shard a MEDS cohort to in a manner that subdivides subject splits.""" stage_init(cfg) output_dir = Path(cfg.stage_cfg.output_dir) - splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" + splits_file = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" shards_fp = output_dir / ".shards.json" rwlock_wrap( @@ -79,9 +118,10 @@ def main(cfg: DictConfig): new_sharded_splits = json.loads(shards_fp.read_text()) - orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") - if include_only_train: - raise ValueError("This stage does not support include_only_train=True") + if cfg.stage_cfg.get("train_only", False): + raise ValueError("This stage does not support train_only=True") + + orig_shards_iter, _ = shard_iterator(cfg, out_suffix="") orig_shards_iter = [(in_fp, out_fp.relative_to(output_dir)) for in_fp, out_fp in orig_shards_iter] @@ -92,22 +132,22 @@ def main(cfg: DictConfig): logger.info("Starting sub-sharding") for subshard_name, out_fp in new_shards_iter: - patients = new_sharded_splits[subshard_name] + subjects = new_sharded_splits[subshard_name] def read_fn(input_dir: Path) -> pl.LazyFrame: df = None logger.info(f"Reading shards for {subshard_name} (file names are in the input sharding scheme):") for in_fp, _ in orig_shards_iter: logger.info(f" - {str(in_fp.relative_to(input_dir).resolve())}") - new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("patient_id").is_in(patients)) + new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("subject_id").is_in(subjects)) if df is None: df = new_df else: - df = df.merge_sorted(new_df, key="patient_id") + df = df.merge_sorted(new_df, key="subject_id") return df def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: - return df.sort(by=["patient_id", "time"], maintain_order=True, multithreaded=False) + return df.sort(by=["subject_id", "time"], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: write_lazyframe(df, out_fp) @@ -125,5 +165,5 @@ def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index 01ec7f9..d1d3d1e 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""Transformations for adding time-derived measurements (e.g., a patient's age) to a MEDS dataset.""" +"""Transformations for adding time-derived measurements (e.g., a subject's age) to a MEDS dataset.""" from collections.abc import Callable import hydra @@ -25,7 +25,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... 
datetime(1990, 1, 1), @@ -38,12 +38,12 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -62,7 +62,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> age_fn(df) shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -74,7 +74,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> add_age_fn(df) shape: (10, 4) ┌────────────┬─────────────────────┬────────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪════════╪═══════════════╡ @@ -96,7 +96,7 @@ def out_fn(df: pl.DataFrame) -> pl.DataFrame: df = df.with_row_index("__idx") new_events = new_events.with_columns(pl.lit(0, dtype=df.schema["__idx"]).alias("__idx")) return ( - pl.concat([df, new_events], how="diagonal").sort(by=["patient_id", "time", "__idx"]).drop("__idx") + pl.concat([df, new_events], how="diagonal").sort(by=["subject_id", "time", "__idx"]).drop("__idx") ) return out_fn @@ -170,7 +170,7 @@ def normalize_time_unit(unit: str) -> tuple[str, float]: def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - """Create a function that adds a patient's age to a DataFrame. + """Create a function that adds a subject's age to a DataFrame. Args: cfg: The configuration for the age function. This must contain the following mandatory keys: @@ -179,8 +179,8 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - "age_unit": The unit for the age event when converted to a numeric value in the output data. Returns: - A function that returns the to-be-added "age" events with the patient's age for all input events with - unique, non-null times in the data, for all patients who have an observed date of birth. It does + A function that returns the to-be-added "age" events with the subject's age for all input events with + unique, non-null times in the data, for all subjects who have an observed date of birth. It does not add an event for times that are equal to the date of birth. Raises: @@ -190,7 +190,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1), @@ -204,12 +204,12 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "rx", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... 
schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (9, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -228,7 +228,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> age_fn(df) shape: (3, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -248,15 +248,15 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: microseconds_in_unit = int(1e6) * seconds_in_unit def fn(df: pl.LazyFrame) -> pl.LazyFrame: - dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("patient_id") + dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("subject_id") age_expr = (pl.col("time") - dob_expr).dt.total_microseconds() / microseconds_in_unit age_expr = age_expr.cast(pl.Float32, strict=False) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) + .unique(subset=["subject_id", "time"], maintain_order=True) .select( - "patient_id", + "subject_id", "time", pl.lit(cfg.age_code, dtype=df.schema["code"]).alias("code"), age_expr.alias("numeric_value"), @@ -283,7 +283,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1, 1, 0), @@ -296,12 +296,12 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... 
) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -319,7 +319,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> time_of_day_fn(df) shape: (6, 3) ┌────────────┬─────────────────────┬──────────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪══════════════════════╡ @@ -357,8 +357,8 @@ def tod_code(start: int, end: int) -> str: time_of_day = time_of_day.when(hour >= end).then(tod_code(end, 24)) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) - .select("patient_id", "time", time_of_day.alias("code")) + .unique(subset=["subject_id", "time"], maintain_order=True) + .select("subject_id", "time", time_of_day.alias("code")) ) return fn @@ -398,5 +398,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=add_time_derived_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/extract_values.py b/src/MEDS_transforms/transforms/extract_values.py index a1d42b6..f233550 100644 --- a/src/MEDS_transforms/transforms/extract_values.py +++ b/src/MEDS_transforms/transforms/extract_values.py @@ -5,6 +5,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig from MEDS_transforms import DEPRECATED_NAMES, INFERRED_STAGE_KEYS, MANDATORY_TYPES, PREPROCESS_CONFIG_YAML @@ -31,13 +32,13 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La >>> stage_cfg = {"numeric_value": "foo", "categorical_value": "bar"} >>> fn = extract_values_fntr(stage_cfg) >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1], "time": [1, 2, 3], + ... "subject_id": [1, 1, 1], "time": [1, 2, 3], ... "foo": ["1", "2", "3"], "bar": [1.0, 2.0, 4.0], ... }) >>> fn(df) shape: (3, 6) ┌────────────┬──────┬─────┬─────┬───────────────┬───────────────────┐ - │ patient_id ┆ time ┆ foo ┆ bar ┆ numeric_value ┆ categorical_value │ + │ subject_id ┆ time ┆ foo ┆ bar ┆ numeric_value ┆ categorical_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 ┆ f32 ┆ str │ ╞════════════╪══════╪═════╪═════╪═══════════════╪═══════════════════╡ @@ -57,7 +58,7 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La ValueError: Error building expression for numeric_value... >>> stage_cfg = {"numeric_value": "foo", "categorical_value": "bar"} >>> fn = extract_values_fntr(stage_cfg) - >>> df = pl.DataFrame({"patient_id": [1, 1, 1], "time": [1, 2, 3]}) + >>> df = pl.DataFrame({"subject_id": [1, 1, 1], "time": [1, 2, 3]}) >>> fn(df) Traceback (most recent call last): ... @@ -66,11 +67,11 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La Note that deprecated column names like "numerical_value" or "timestamp" won't be re-typed. 
>>> stage_cfg = {"numerical_value": "foo"} >>> fn = extract_values_fntr(stage_cfg) - >>> df = pl.DataFrame({"patient_id": [1, 1, 1], "time": [1, 2, 3], "foo": ["1", "2", "3"]}) + >>> df = pl.DataFrame({"subject_id": [1, 1, 1], "time": [1, 2, 3], "foo": ["1", "2", "3"]}) >>> fn(df) shape: (3, 4) ┌────────────┬──────┬─────┬─────────────────┐ - │ patient_id ┆ time ┆ foo ┆ numerical_value │ + │ subject_id ┆ time ┆ foo ┆ numerical_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ str │ ╞════════════╪══════╪═════╪═════════════════╡ @@ -94,8 +95,10 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La match out_col_n: case str() if out_col_n in MANDATORY_TYPES: expr = expr.cast(MANDATORY_TYPES[out_col_n]) - if out_col_n == "patient_id": - logger.warning("You should almost CERTAINLY not be extracting patient_id as a value.") + if out_col_n == subject_id_field: + logger.warning( + f"You should almost CERTAINLY not be extracting {subject_id_field} as a value." + ) if out_col_n == "time": logger.warning("Warning: `time` is being extracted post-hoc!") case str() if out_col_n in DEPRECATED_NAMES: @@ -116,7 +119,7 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: if not need_cols.issubset(in_cols): raise ValueError(f"Missing columns: {sorted(list(need_cols - in_cols))}") - return df.with_columns(new_cols).sort("patient_id", "time", maintain_order=True) + return df.with_columns(new_cols).sort(subject_id_field, "time", maintain_order=True) return compute_fn @@ -130,5 +133,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=extract_values_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index fbad9ac..5109861 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -16,7 +16,7 @@ def normalize( """Normalize a MEDS dataset across both categorical and continuous dimensions. This function expects a MEDS dataset in flattened form, with columns for: - - `patient_id` + - `subject_id` - `time` - `code` - `numeric_value` @@ -61,7 +61,7 @@ def normalize( >>> from datetime import datetime >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -76,7 +76,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... "time": pl.Datetime, ... "code": pl.Utf8, ... "numeric_value": pl.Float64, @@ -100,7 +100,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -113,7 +113,7 @@ def normalize( └────────────┴─────────────────────┴──────┴───────────────┘ >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -128,7 +128,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... 
"time": pl.Datetime, ... "code": pl.Utf8, ... "numeric_value": pl.Float64, @@ -154,7 +154,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata, ["unit"]).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -201,7 +201,7 @@ def normalize( ) .select( idx_col, - "patient_id", + "subject_id", "time", pl.col("code/vocab_index").alias("code"), ((pl.col("numeric_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numeric_value"), @@ -220,5 +220,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=normalize) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index 107407d..d65ecd5 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def occlude_outliers_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Filters patient events to only encompass those with a set of permissible codes. + """Filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -33,7 +33,7 @@ def occlude_outliers_fntr( ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] @@ -45,7 +45,7 @@ def occlude_outliers_fntr( >>> fn(data).collect() shape: (4, 5) ┌────────────┬──────┬───────────┬───────────────┬─────────────────────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ + │ subject_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 ┆ f64 ┆ bool │ ╞════════════╪══════╪═══════════╪═══════════════╪═════════════════════════╡ @@ -78,7 +78,7 @@ def occlude_outliers_fntr( code_metadata = code_metadata.lazy().select(cols_to_select) def occlude_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters out outlier numeric values from patient events. + f"""Filters out outlier numeric values from subject events. In particular, this function filters the DataFrame to only include numeric values that are within {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. 
@@ -110,5 +110,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=occlude_outliers_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/reorder_measurements.py b/src/MEDS_transforms/transforms/reorder_measurements.py index 1205f77..0ad3747 100644 --- a/src/MEDS_transforms/transforms/reorder_measurements.py +++ b/src/MEDS_transforms/transforms/reorder_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -19,7 +19,7 @@ def reorder_by_code_fntr( Args: stage_cfg: The stage-specific configuration object which contains the `ordered_code_patterns` field - that defines the order of the codes within each patient event (unique timepoint). Each element of + that defines the order of the codes within each subject event (unique timepoint). Each element of this list should be a regex pattern that matches codes that should be re-ordered at the index of the regex pattern in the list. Codes are matched in the order of the list, and if a code matches multiple regex patterns, it will be ordered by the first regex pattern that matches it. @@ -34,7 +34,7 @@ def reorder_by_code_fntr( Examples: >>> code_metadata_df = pl.DataFrame({"code": ["A", "A", "B", "C"], "modifier1": [1, 2, 1, 2]}) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], + ... "subject_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], ... "code": ["A", "B", "A", "C"], "modifier1": [1, 2, 1, 2] ... }) >>> stage_cfg = DictConfig({"ordered_code_patterns": ["B", "A"]}) @@ -42,7 +42,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (4, 4) ┌────────────┬──────┬──────┬───────────┐ - │ patient_id ┆ time ┆ code ┆ modifier1 │ + │ subject_id ┆ time ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪══════╪═══════════╡ @@ -55,7 +55,7 @@ def reorder_by_code_fntr( ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE"], ... }) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 1, 2, 2, 2], + ... "subject_id":[1, 1, 1, 2, 2, 2], ... "time": [1, 1, 1, 1, 2, 3], ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE", "LAB//baz"], ... 
}) @@ -66,7 +66,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -81,7 +81,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -141,7 +141,7 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: return ( df.join(code_indices, on=join_cols, how="left", coalesce=True) - .sort("patient_id", "time", "code_order_idx", maintain_order=True) + .sort("subject_id", "time", "code_order_idx", maintain_order=True) .drop("code_order_idx") ) @@ -152,13 +152,13 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Reorders measurements within each patient event (unique timepoint) by the specified code order. + """Reorders measurements within each subject event (unique timepoint) by the specified code order. In particular, given a set of [regex crate compatible](https://docs.rs/regex/latest/regex/) regexes in the `stage_cfg.ordered_code_patterns` list, this script will re-order the measurements within each event (unique timepoint) such that the measurements are sorted by the index of the first regex that matches their code in the `ordered_code_patterns` list. So, if the `ordered_code_patterns` list is - `["foo$", "bar", "foo.*"]`, and a single patient event has measurements with codes + `["foo$", "bar", "foo.*"]`, and a single subject event has measurements with codes `["foobar", "barbaz", "foo", "quat"]`, the measurements will be re-ordered to the order: `["foo", "foobar", "barbaz", "quat"]`, because: - "foo" matches the first regex in the list (the `foo$` matches any string with "foo" at the end). @@ -176,7 +176,7 @@ def main(cfg: DictConfig): Args: stage_configs.reorder_measurements.ordered_code_patterns: A list of regex patterns that specify the - order of the codes within each patient event (unique timepoint). To specify this on the command + order of the codes within each subject event (unique timepoint). To specify this on the command line, use the hydra list syntax by enclosing the entire key-value string argument in single quotes: ``'stage_configs.reorder_measurements.ordered_code_patterns=["foo$", "bar", "foo.*"]'``. """ @@ -184,5 +184,5 @@ def main(cfg: DictConfig): map_over(cfg, reorder_by_code_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 0266ce2..5fd7389 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -32,7 +32,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 2], + ... "subject_id": [1, 2], ... "time_delta_days": [[float("nan"), 12.0], [float("nan")]], ... "code": [[[101.0, 102.0], [103.0]], [[201.0, 202.0]]], ... 
"numeric_value": [[[2.0, 3.0], [4.0]], [[6.0, 7.0]]] @@ -40,7 +40,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: >>> df shape: (2, 4) ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ @@ -115,5 +115,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index d6f5003..8d60dcb 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -6,7 +6,7 @@ All these functions take in _normalized_ data -- meaning data where there are _no longer_ any code modifiers, as those have been normalized alongside codes into integer indices (in the output code column). The only -columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`. +columns of concern here thus are `subject_id`, `time`, `code`, `numeric_value`. """ from pathlib import Path @@ -69,7 +69,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "time": [None, datetime(2021, 1, 1), None, datetime(2021, 1, 2)], ... "code": [100, 101, 200, 201], ... "numeric_value": [1.0, 2.0, 3.0, 4.0] @@ -78,7 +78,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> static.collect() shape: (2, 3) ┌────────────┬──────┬───────────────┐ - │ patient_id ┆ code ┆ numeric_value │ + │ subject_id ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ f64 │ ╞════════════╪══════╪═══════════════╡ @@ -88,7 +88,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> dynamic.collect() shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ datetime[μs] ┆ i64 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -103,19 +103,19 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts static data and schema information (sequence of patient unique times). + """This function extracts static data and schema information (sequence of subject unique times). Args: df: The input data. Returns: - A `pl.LazyFrame` object containing the static data and the unique times of the patient, grouped - by patient as lists, in the same order as the patient IDs occurred in the original file. + A `pl.LazyFrame` object containing the static data and the unique times of the subject, grouped + by subject as lists, in the same order as the subject IDs occurred in the original file. Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... 
None, datetime(2021, 1, 2), datetime(2021, 1, 2)], @@ -126,17 +126,17 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: >>> df.drop("time") shape: (2, 4) ┌────────────┬───────────┬───────────────┬─────────────────────┐ - │ patient_id ┆ code ┆ numeric_value ┆ start_time │ + │ subject_id ┆ code ┆ numeric_value ┆ start_time │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[i64] ┆ list[f64] ┆ datetime[μs] │ ╞════════════╪═══════════╪═══════════════╪═════════════════════╡ │ 1 ┆ [100] ┆ [1.0] ┆ 2021-01-01 00:00:00 │ │ 2 ┆ [200] ┆ [5.0] ┆ 2021-01-02 00:00:00 │ └────────────┴───────────┴───────────────┴─────────────────────┘ - >>> df.select("patient_id", "time").explode("time") + >>> df.select("subject_id", "time").explode("time") shape: (3, 2) ┌────────────┬─────────────────────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ datetime[μs] │ ╞════════════╪═════════════════════╡ @@ -148,21 +148,21 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: static, dynamic = split_static_and_dynamic(df) - # This collects static data by patient ID and stores only (as a list) the codes and numeric values. - static_by_patient = static.group_by("patient_id", maintain_order=True).agg("code", "numeric_value") + # This collects static data by subject ID and stores only (as a list) the codes and numeric values. + static_by_subject = static.group_by("subject_id", maintain_order=True).agg("code", "numeric_value") - # This collects the unique times for each patient. - schema_by_patient = dynamic.group_by("patient_id", maintain_order=True).agg( + # This collects the unique times for each subject. + schema_by_subject = dynamic.group_by("subject_id", maintain_order=True).agg( pl.col("time").min().alias("start_time"), pl.col("time").unique(maintain_order=True) ) - # TODO(mmd): Consider tracking patient offset explicitly here. + # TODO(mmd): Consider tracking subject offset explicitly here. - return static_by_patient.join(schema_by_patient, on="patient_id", how="inner") + return static_by_subject.join(schema_by_subject, on="subject_id", how="inner") -def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts sequences of patient events, which are sequences of measurements. +def extract_seq_of_subject_events(df: pl.LazyFrame) -> pl.LazyFrame: + """This function extracts sequences of subject events, which are sequences of measurements. The result of this can be naturally tensorized into a `JointNestedRaggedTensorDict` object. @@ -170,8 +170,8 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: df: The input data. Returns: - A `pl.LazyFrame` object containing the sequences of patient events, with the following columns: - - `patient_id`: The patient ID. + A `pl.LazyFrame` object containing the sequences of subject events, with the following columns: + - `subject_id`: The subject ID. - `time_delta_days`: The time delta in days, as a list of floats (ragged). - `code`: The code, as a list of lists of ints (ragged in both levels). - `numeric_value`: The numeric value as a list of lists of floats (ragged in both levels). @@ -179,17 +179,17 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], ... 
"code": [100, 101, 102, 103, 200, 201, 202], ... "numeric_value": pl.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=pl.Float32) ... }).lazy() - >>> extract_seq_of_patient_events(df).collect() + >>> extract_seq_of_subject_events(df).collect() shape: (2, 4) ┌────────────┬─────────────────┬─────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f32] ┆ list[list[i64]] ┆ list[list[f32]] │ ╞════════════╪═════════════════╪═════════════════════╪═════════════════════╡ @@ -203,9 +203,9 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: time_delta_days_expr = (pl.col("time").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float32) return ( - dynamic.group_by("patient_id", "time", maintain_order=True) + dynamic.group_by("subject_id", "time", maintain_order=True) .agg(pl.col("code").name.keep(), fill_to_nans("numeric_value").name.keep()) - .group_by("patient_id", maintain_order=True) + .group_by("subject_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta_days"), "code", @@ -229,11 +229,10 @@ def main(cfg: DictConfig): ) output_dir = Path(cfg.stage_cfg.output_dir) + if train_only := cfg.stage_cfg.get("train_only", False): + raise ValueError(f"train_only={train_only} is not supported for this stage.") shards_single_output, include_only_train = shard_iterator(cfg) - if include_only_train: - raise ValueError("Not supported for this stage.") - for in_fp, out_fp in shards_single_output: sharded_path = out_fp.relative_to(output_dir) @@ -258,12 +257,12 @@ def main(cfg: DictConfig): event_seq_out_fp, pl.scan_parquet, write_lazyframe, - extract_seq_of_patient_events, + extract_seq_of_subject_events, do_overwrite=cfg.do_overwrite, ) logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 59a7cd6..b62f7d1 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -412,15 +412,15 @@ def is_col_field(field: str | None) -> bool: bool: True if the field is formatted as "col(column_name)", False otherwise. Examples: - >>> is_col_field("col(patient_id)") + >>> is_col_field("col(subject_id)") True - >>> is_col_field("col(patient_id") + >>> is_col_field("col(subject_id") False - >>> is_col_field("patient_id)") + >>> is_col_field("subject_id)") False - >>> is_col_field("column(patient_id)") + >>> is_col_field("column(subject_id)") False - >>> is_col_field("patient_id") + >>> is_col_field("subject_id") False >>> is_col_field(None) False @@ -440,16 +440,16 @@ def parse_col_field(field: str) -> str: ValueError: If the input string does not match the expected format. Examples: - >>> parse_col_field("col(patient_id)") - 'patient_id' - >>> parse_col_field("col(patient_id") + >>> parse_col_field("col(subject_id)") + 'subject_id' + >>> parse_col_field("col(subject_id") Traceback (most recent call last): ... - ValueError: Invalid column field: col(patient_id - >>> parse_col_field("column(patient_id)") + ValueError: Invalid column field: col(subject_id + >>> parse_col_field("column(subject_id)") Traceback (most recent call last): ... 
- ValueError: Invalid column field: column(patient_id) + ValueError: Invalid column field: column(subject_id) """ if not is_col_field(field): raise ValueError(f"Invalid column field: {field}") diff --git a/tests/MEDS_Extract/__init__.py b/tests/MEDS_Extract/__init__.py new file mode 100644 index 0000000..14ddbce --- /dev/null +++ b/tests/MEDS_Extract/__init__.py @@ -0,0 +1,24 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +extraction_root = root / "src" / "MEDS_transforms" / "extract" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" + MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" + EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" + FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" + FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" +else: + SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" + MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" + EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" + FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" + FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py new file mode 100644 index 0000000..074e897 --- /dev/null +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -0,0 +1,229 @@ +"""Tests the convert to sharded events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import CONVERT_TO_SHARDED_EVENTS_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 
239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 
15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, + stage_name="convert_to_sharded_events", + stage_kwargs=None, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, + ) diff --git a/tests/test_extract.py b/tests/MEDS_Extract/test_extract.py similarity index 86% rename from tests/test_extract.py rename to tests/MEDS_Extract/test_extract.py index d8a3c3e..954ca21 100644 --- a/tests/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -4,32 +4,6 @@ scripts. """ -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,7 +12,16 @@ import polars as pl from meds import __version__ as MEDS_VERSION -from .utils import assert_df_equal, run_command +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) +from tests.utils import assert_df_equal, run_command # Test data (inputs) @@ -53,7 +36,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +71,7 @@ EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,9 +127,9 @@ 
"held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -156,17 +139,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -176,7 +159,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -204,7 +187,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -214,7 +197,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -226,14 +209,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -241,14 +224,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -338,19 +321,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. 
extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -405,11 +388,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -435,12 +418,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." ) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -449,8 +432,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -459,7 +442,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -511,12 +494,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." 
except AssertionError as e: print(f"Failed on split {split}") @@ -542,14 +525,9 @@ def test_extraction(): got_df = pl.read_parquet(output_file, glob=False) want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( - pl.col("code"), pl.col("parent_codes").cast(pl.List(pl.Utf8)), ) - # We collapse the list type as it throws an error in the assert_df_equal otherwise - got_df = got_df.with_columns(pl.col("parent_codes").list.join("||")) - want_df = want_df.with_columns(pl.col("parent_codes").list.join("||")) - assert_df_equal( want=want_df, got=got_df, @@ -593,12 +571,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -651,14 +629,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py new file mode 100644 index 0000000..74a3d02 --- /dev/null +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -0,0 +1,199 @@ +"""Tests the extract code metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
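+
+For example, a hypothetical invocation to exercise just this stage against the local source files
+(adjust the test path to your checkout) might look like:
+
+    DO_USE_LOCAL_SCRIPTS=1 pytest tests/MEDS_Extract/test_extract_code_metadata.py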
+""" + + +import polars as pl + +from tests.MEDS_Extract import EXTRACT_CODE_METADATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +INPUT_METADATA_FILE = """ +lab_code,title,loinc +HR,Heart Rate,8867-4 +temp,Body Temperature,8310-5 +""" + +DEMO_METADATA_FILE = """ +eye_color,description +BROWN,"Brown Eyes. The most common eye color." +BLUE,"Blue Eyes. Less common than brown." +HAZEL,"Hazel eyes. These are uncommon" +GREEN,"Green eyes. These are rare." 
+""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=EXTRACT_CODE_METADATA_SCRIPT, + stage_name="extract_code_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "demo_metadata.csv": DEMO_METADATA_FILE, + "input_metadata.csv": INPUT_METADATA_FILE, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": True}, + assert_no_other_outputs=False, + ) diff --git a/tests/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py similarity index 86% rename from tests/test_extract_no_metadata.py rename to tests/MEDS_Extract/test_extract_no_metadata.py index f1945af..0fa8eec 100644 --- a/tests/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -4,32 +4,6 @@ scripts. 
""" -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,7 +12,16 @@ import polars as pl from meds import __version__ as MEDS_VERSION -from .utils import assert_df_equal, run_command +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) +from tests.utils import assert_df_equal, run_command # Test data (inputs) @@ -53,7 +36,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +71,7 @@ EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -133,9 +116,9 @@ "held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -145,17 +128,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -165,7 +148,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -193,7 +176,7 @@ def 
get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -203,7 +186,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -215,14 +198,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -230,14 +213,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -322,19 +305,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -389,11 +372,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -419,12 +402,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." 
) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -433,8 +416,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -443,7 +426,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -495,12 +478,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -560,12 +543,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." 
except AssertionError as e: print(f"Failed on split {split}") @@ -618,14 +601,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/MEDS_Extract/test_finalize_MEDS_data.py b/tests/MEDS_Extract/test_finalize_MEDS_data.py new file mode 100644 index 0000000..d9a3e0a --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_data.py @@ -0,0 +1,111 @@ +"""Tests the finalize MEDS data process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + +import polars as pl + +from tests.MEDS_Extract import FINALIZE_DATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 
06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = { + k: v.with_columns( + pl.col("subject_id").cast(pl.Int64), + pl.col("time").cast(pl.Datetime("us")), + pl.col("code").cast(pl.String), + pl.col("numeric_value").cast(pl.Float32), + ) + for k, v in INPUT_SHARDS.items() +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_DATA_SCRIPT, + stage_name="finalize_MEDS_data", + stage_kwargs=None, + config_name="extract", + input_files=INPUT_SHARDS, + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": True, "check_dtypes": True, "check_row_order": True}, + ) diff --git a/tests/MEDS_Extract/test_finalize_MEDS_metadata.py b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py new file mode 100644 index 0000000..274997f --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py @@ -0,0 +1,91 @@ +"""Tests the finalize MEDS metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + + +import polars as pl +from meds import __version__ as MEDS_VERSION + +from MEDS_transforms.utils import get_package_version as get_meds_transform_version +from tests.MEDS_Extract import FINALIZE_METADATA_SCRIPT +from tests.utils import single_stage_tester + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + +METADATA_DF = pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. 
These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } +) + +WANT_OUTPUTS = { + "metadata/codes": ( + METADATA_DF.with_columns( + pl.col("code").cast(pl.String), + pl.col("description").cast(pl.String), + pl.col("parent_codes").cast(pl.List(pl.String)), + ).select(["code", "description", "parent_codes"]) + ), + "metadata/subject_splits": pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } + ), + "metadata/dataset.json": { + "dataset_name": "TEST", + "dataset_version": "1.0", + "etl_name": "MEDS_transforms", + "etl_version": get_meds_transform_version(), + "meds_version": MEDS_VERSION, + }, +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_METADATA_SCRIPT, + stage_name="finalize_MEDS_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + "metadata/codes": METADATA_DF, + "metadata/.shards.json": SHARDS_JSON, + }, + **{"etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0"}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": True, "check_dtypes": True}, + ) diff --git a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py new file mode 100644 index 0000000..b9a21ff --- /dev/null +++ b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py @@ -0,0 +1,268 @@ +"""Tests the merge to MEDS events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from tests.MEDS_Extract import MERGE_TO_MEDS_COHORT_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10): |-2 + 
subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 
00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=MERGE_TO_MEDS_COHORT_SCRIPT, + stage_name="merge_to_MEDS_cohort", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": False}, + ) diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py new file mode 100644 index 0000000..f19746e --- /dev/null +++ b/tests/MEDS_Extract/test_shard_events.py @@ -0,0 +1,114 @@ +"""Tests the shard events stage in isolation. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SHARD_EVENTS_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 
19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + + +def test_shard_events(): + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + config_name="extract", + input_files={ + "subjects.csv": SUBJECTS_CSV, + "admit_vitals.csv": ADMIT_VITALS_CSV, + "admit_vitals.parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + want_outputs={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], + }, + df_check_kwargs={"check_column_order": False}, + ) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py new file mode 100644 index 0000000..db74896 --- /dev/null +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -0,0 +1,133 @@ +"""Tests the full end-to-end extraction process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SPLIT_AND_SHARD_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +# Test data (expected outputs) -- ALL OF THIS MAY CHANGE IF THE SEED OR DATA CHANGES +EXPECTED_SPLITS = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +SUBJECT_SPLITS_DF = pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 
754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } +) + + +def test_split_and_shard(): + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + "split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + }, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, + ) diff --git a/tests/MEDS_Transforms/__init__.py b/tests/MEDS_Transforms/__init__.py new file mode 100644 index 0000000..a2d3d56 --- /dev/null +++ b/tests/MEDS_Transforms/__init__.py @@ -0,0 +1,46 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +code_root = root / "src" / "MEDS_transforms" +transforms_root = code_root / "transforms" +filters_root = code_root / "filters" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" + FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" + RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" + FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" + REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" + EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" + NORMALIZATION_SCRIPT = transforms_root / "normalization.py" + OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" + TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" + TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" +else: + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" + FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" + RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" + FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" + REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" + EXTRACT_VALUES_SCRIPT = "MEDS_transform-extract_values" + NORMALIZATION_SCRIPT = "MEDS_transform-normalization" + OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers" + TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" + TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" diff --git a/tests/test_add_time_derived_measurements.py b/tests/MEDS_Transforms/test_add_time_derived_measurements.py similarity index 93% rename from tests/test_add_time_derived_measurements.py rename to tests/MEDS_Transforms/test_add_time_derived_measurements.py index e5653a1..ff2131d 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/MEDS_Transforms/test_add_time_derived_measurements.py @@ -4,9 +4,11 @@ scripts. 
""" +from meds import subject_id_field -from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import parse_meds_csvs AGE_CALCULATION_STR = """ See `add_time_derived_measurements.py` for the source of the constant value. @@ -96,8 +98,8 @@ ``` """ -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -156,9 +158,9 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00","TIME_OF_DAY//[00,06)", @@ -185,8 +187,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00","TIME_OF_DAY//[00,06)", @@ -201,8 +203,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", diff --git a/tests/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py similarity index 91% rename from tests/test_aggregate_code_metadata.py rename to tests/MEDS_Transforms/test_aggregate_code_metadata.py index 21698cb..a2abce5 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -6,14 +6,14 @@ import polars as pl -from .transform_tester_base import ( - AGGREGATE_CODE_METADATA_SCRIPT, +from tests.MEDS_Transforms import AGGREGATE_CODE_METADATA_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import ( MEDS_CODE_METADATA_SCHEMA, single_stage_transform_tester, ) WANT_OUTPUT_CODE_METADATA_FILE = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/n_patients,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/n_subjects,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes ,44,4,28,4,3198.8389005974336,382968.28937288234,6,86.0,175.271118,, ADMISSION//CARDIAC,2,2,0,0,0,0,0,,,, ADMISSION//ORTHOPEDIC,1,1,0,0,0,0,0,,,, @@ -45,9 +45,9 @@ "TEMP", ], "code/n_occurrences": [44, 2, 1, 1, 4, 4, 1, 1, 2, 4, 12, 12], - "code/n_patients": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], + "code/n_subjects": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], "values/n_occurrences": [28, 0, 0, 0, 0, 0, 0, 0, 0, 4, 12, 12], - "values/n_patients": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], + "values/n_subjects": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], "values/sum": [ 3198.8389005974336, 0, @@ -163,9 +163,9 @@ AGGREGATIONS = [ "code/n_occurrences", - "code/n_patients", + "code/n_subjects", "values/n_occurrences", - "values/n_patients", + "values/n_subjects", 
"values/sum", "values/sum_sqd", "values/n_ints", @@ -183,4 +183,6 @@ def test_aggregate_code_metadata(): want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, + assert_no_other_outputs=False, + df_check_kwargs={"check_column_order": False}, ) diff --git a/tests/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py similarity index 87% rename from tests/test_extract_values.py rename to tests/MEDS_Transforms/test_extract_values.py index 8a3a65b..4114b3b 100644 --- a/tests/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -4,12 +4,13 @@ scripts. """ -from .transform_tester_base import EXTRACT_VALUES_SCRIPT, parse_shards_yaml, single_stage_transform_tester +from tests.MEDS_Transforms import EXTRACT_VALUES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_shards_yaml, single_stage_transform_tester INPUT_SHARDS = parse_shards_yaml( """ train/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 239684,,EYE_COLOR//BROWN,, 239684,"12/28/1980, 00:00:00",DOB,, 239684,"05/11/2010, 17:41:51",BP,,"120/80" @@ -19,19 +20,19 @@ 1195293,"06/20/2010, 19:23:52",HR,80, 1195293,"06/20/2010, 19:23:52",TEMP,,"100F" train/1: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 68729,,EYE_COLOR//HAZEL,, 68729,"03/09/1978, 00:00:00",DOB,, 814703,"02/05/2010, 05:55:39",HR,170.2, tuning/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 754281,,EYE_COLOR//BROWN,, 754281,"12/19/1988, 00:00:00",DOB,, 754281,"01/03/2010, 06:27:59",HR,142.0, 754281,"06/20/2010, 20:23:50",BP,,"134/76" 754281,"06/20/2010, 21:00:02",TEMP,,"36.2C" held_out/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 1500733,,EYE_COLOR//BROWN,, 1500733,"07/20/1986, 00:00:00",DOB,, 1500733,"06/03/2010, 14:54:38",HR,91.4 @@ -42,7 +43,7 @@ WANT_SHARDS = parse_shards_yaml( """ train/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 239684,,EYE_COLOR//BROWN,, 239684,"12/28/1980, 00:00:00",DOB,, 239684,"05/11/2010, 17:41:51",BP//SYSTOLIC,120, @@ -54,12 +55,12 @@ 1195293,"06/20/2010, 19:23:52",TEMP//F,100, 1195293,"06/20/2010, 19:23:52",HR,80, train/1: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 68729,,EYE_COLOR//HAZEL,, 68729,"03/09/1978, 00:00:00",DOB,, 814703,"02/05/2010, 05:55:39",HR,170.2, tuning/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 754281,,EYE_COLOR//BROWN,, 754281,"12/19/1988, 00:00:00",DOB,, 754281,"01/03/2010, 06:27:59",HR,142.0, @@ -67,7 +68,7 @@ 754281,"06/20/2010, 20:23:50",BP//DIASTOLIC,76, 754281,"06/20/2010, 21:00:02",TEMP//C,36.2, held_out/0: |-2 - patient_id,time,code,numeric_value,text_value + subject_id,time,code,numeric_value,text_value 1500733,,EYE_COLOR//BROWN,, 1500733,"07/20/1986, 00:00:00",DOB,, 1500733,"06/03/2010, 14:54:38",BP//SYSTOLIC,123, diff --git a/tests/test_filter_measurements.py b/tests/MEDS_Transforms/test_filter_measurements.py similarity index 88% rename from tests/test_filter_measurements.py rename to tests/MEDS_Transforms/test_filter_measurements.py index 3bc807f..9991a26 100644 --- a/tests/test_filter_measurements.py +++ b/tests/MEDS_Transforms/test_filter_measurements.py @@ -4,13 +4,13 @@ scripts. 
""" - -from .transform_tester_base import FILTER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms import FILTER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -25,11 +25,11 @@ # TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 # """ # -# We'll keep only the codes that occur for at least 2 patients, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, +# We'll keep only the codes that occur for at least 2 subjects, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, # EYE_COLOR//HAZEL, HEIGHT, HR, TEMP WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -61,7 +61,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -77,7 +77,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",HR,142.0 @@ -86,7 +86,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",HR,91.4 @@ -112,14 +112,14 @@ def test_filter_measurements(): single_stage_transform_tester( transform_script=FILTER_MEASUREMENTS_SCRIPT, stage_name="filter_measurements", - transform_stage_kwargs={"min_patients_per_code": 2}, + transform_stage_kwargs={"min_subjects_per_code": 2}, want_data=WANT_SHARDS, ) # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -144,7 +144,7 @@ def test_filter_measurements(): # - Other codes won't be filtered, so we will retain HEIGHT, DISCHARGE, DOB, TEMP MR_WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -166,7 +166,7 @@ def test_filter_measurements(): """ MR_WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -178,7 +178,7 @@ def test_filter_measurements(): """ MR_WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -186,7 +186,7 @@ def 
test_filter_measurements(): """ MR_WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -214,11 +214,11 @@ def test_match_revise_filter_measurements(): stage_name="filter_measurements", transform_stage_kwargs={ "_match_revise": [ - {"_matcher": {"code": {"regex": "ADMISSION//.*"}}, "min_patients_per_code": 2}, - {"_matcher": {"code": "HR"}, "min_patients_per_code": 15}, - {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4}, + {"_matcher": {"code": {"regex": "ADMISSION//.*"}}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "HR"}, "min_subjects_per_code": 15}, + {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_subjects_per_code": 4}, ], }, want_data=MR_WANT_SHARDS, diff --git a/tests/test_filter_patients.py b/tests/MEDS_Transforms/test_filter_subjects.py similarity index 72% rename from tests/test_filter_patients.py rename to tests/MEDS_Transforms/test_filter_subjects.py index 0b07836..83f4068 100644 --- a/tests/test_filter_patients.py +++ b/tests/MEDS_Transforms/test_filter_subjects.py @@ -1,15 +1,17 @@ -"""Tests the filter patients script. +"""Tests the filter subjects script. Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +from meds import subject_id_field -from .transform_tester_base import FILTER_PATIENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms import FILTER_SUBJECTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import parse_meds_csvs -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -42,18 +44,18 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value """ -# All patients in this shard had only 4 events. -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. 
+WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -77,10 +79,10 @@ ) -def test_filter_patients(): +def test_filter_subjects(): single_stage_transform_tester( - transform_script=FILTER_PATIENTS_SCRIPT, - stage_name="filter_patients", - transform_stage_kwargs={"min_events_per_patient": 5}, + transform_script=FILTER_SUBJECTS_SCRIPT, + stage_name="filter_subjects", + transform_stage_kwargs={"min_events_per_subject": 5}, want_data=WANT_SHARDS, ) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py similarity index 68% rename from tests/test_fit_vocabulary_indices.py rename to tests/MEDS_Transforms/test_fit_vocabulary_indices.py index ce7c40a..78f637a 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py @@ -5,14 +5,11 @@ """ -from .transform_tester_base import ( - FIT_VOCABULARY_INDICES_SCRIPT, - parse_code_metadata_csv, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import FIT_VOCABULARY_INDICES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_code_metadata_csv, single_stage_transform_tester WANT_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index ,44,4,28,3198.8389005974336,382968.28937288234,,,1 ADMISSION//CARDIAC,2,2,0,,,,,2 ADMISSION//ORTHOPEDIC,1,1,0,,,,,3 @@ -35,3 +32,11 @@ def test_fit_vocabulary_indices_with_default_stage_config(): transform_stage_kwargs=None, want_metadata=parse_code_metadata_csv(WANT_CSV), ) + + single_stage_transform_tester( + transform_script=FIT_VOCABULARY_INDICES_SCRIPT, + stage_name="fit_vocabulary_indices", + transform_stage_kwargs={"ordering_method": "file"}, + want_metadata=parse_code_metadata_csv(WANT_CSV), + should_error=True, + ) diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py similarity index 94% rename from tests/test_multi_stage_preprocess_pipeline.py rename to tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index eda9060..6667313 100644 --- a/tests/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -4,7 +4,7 @@ scripts. 
In this test, the following stages are run: - - filter_patients + - filter_subjects - add_time_derived_measurements - fit_outlier_detection - occlude_outliers @@ -17,23 +17,24 @@ The stage configuration arguments will be as given in the yaml block below: """ + from datetime import datetime import polars as pl +from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .transform_tester_base import ( +from tests.MEDS_Transforms import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, FIT_VOCABULARY_INDICES_SCRIPT, NORMALIZATION_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, TENSORIZATION_SCRIPT, TOKENIZATION_SCRIPT, - multi_stage_transform_tester, - parse_shards_yaml, ) +from tests.MEDS_Transforms.transform_tester_base import multi_stage_transform_tester, parse_shards_yaml MEDS_CODE_METADATA = pl.DataFrame( { @@ -51,8 +52,8 @@ ) STAGE_CONFIG_YAML = """ -filter_patients: - min_events_per_patient: 5 +filter_subjects: + min_events_per_subject: 5 add_time_derived_measurements: age: DOB_code: "DOB" # This is the MEDS official code for BIRTH @@ -71,17 +72,17 @@ fit_normalization: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" """ -# After filtering out patients with fewer than 5 events: +# After filtering out subjects with fewer than 5 events: WANT_FILTER = parse_shards_yaml( - """ - "filter_patients/train/0": |-2 - patient_id,time,code,numeric_value + f""" + "filter_subjects/train/0": |-2 + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -113,14 +114,14 @@ 1195293,"06/20/2010, 20:41:33",TEMP,100.4 1195293,"06/20/2010, 20:50:04",DISCHARGE, - "filter_patients/train/1": |-2 - patient_id,time,code,numeric_value + "filter_subjects/train/1": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/tuning/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/held_out/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/held_out/0": |-2 + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -136,9 +137,9 @@ ) WANT_TIME_DERIVED = parse_shards_yaml( - """ + f""" "add_time_derived_measurements/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -197,13 +198,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, "add_time_derived_measurements/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", @@ -387,9 +388,9 @@ """ WANT_OCCLUDE_OUTLIERS = parse_shards_yaml( - """ + f""" "occlude_outliers/train/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 
239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -448,13 +449,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE,, "occlude_outliers/train/1": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/tuning/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/held_out/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,,false 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -488,7 +489,7 @@ ... .group_by("code") ... .agg( ... pl.len().alias("code/n_occurrences"), -... pl.col("patient_id").n_unique().alias("code/n_patients"), +... pl.col("subject_id").n_unique().alias("code/n_subjects"), ... VALS.len().alias("values/n_occurrences"), ... VALS.sum().alias("values/sum"), ... (VALS**2).sum().alias("values/sum_sqd") @@ -497,7 +498,7 @@ >>> post_transform.filter(pl.col("values/n_occurrences") > 0) shape: (3, 6) ┌──────┬────────────────────┬─────────────────┬──────────────────────┬────────────┬────────────────┐ -│ code ┆ code/n_occurrences ┆ code/n_patients ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ +│ code ┆ code/n_occurrences ┆ code/n_subjects ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 ┆ u32 ┆ f32 ┆ f32 │ ╞══════╪════════════════════╪═════════════════╪══════════════════════╪════════════╪════════════════╡ @@ -508,7 +509,7 @@ >>> print(post_transform.filter(pl.col("values/n_occurrences") > 0).to_dict(as_series=False)) {'code': ['HR', 'TEMP', 'AGE'], 'code/n_occurrences': [10, 10, 12], - 'code/n_patients': [2, 2, 2], + 'code/n_subjects': [2, 2, 2], 'values/n_occurrences': [7, 6, 7], 'values/sum': [776.7999877929688, 600.1000366210938, 224.02084350585938], 'values/sum_sqd': [86249.921875, 60020.21484375, 7169.33349609375]} @@ -533,7 +534,7 @@ "DOB", ], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -597,7 +598,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex. "values/sum": pl.Float32, "values/sum_sqd": pl.Float32, @@ -625,7 +626,7 @@ ], "code/vocab_index": [5, 6, 8, 9, 2, 7, 12, 11, 10, 1, 3, 4], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -689,7 +690,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "code/vocab_index": pl.UInt8, "values/n_occurrences": pl.UInt8, "values/sum": pl.Float32, @@ -772,9 +773,9 @@ # Note we have dropped the row in the held out shard that doesn't have a code in the vocabulary! 
WANT_NORMALIZATION = parse_shards_yaml( - """ + f""" "normalization/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,6, 239684,,7, 239684,"12/28/1980, 00:00:00",10, @@ -833,13 +834,13 @@ 1195293,"06/20/2010, 20:50:04",3, "normalization/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,6, 1500733,,7, 1500733,"07/20/1986, 00:00:00",10, @@ -864,7 +865,7 @@ ) TOKENIZATION_SCHEMA_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.UInt8), "numeric_value": pl.List(pl.Float32), "start_time": pl.Datetime("us"), @@ -873,7 +874,7 @@ WANT_TOKENIZATION_SCHEMAS = { "tokenization/schemas/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [[6, 7], [5, 7]], "numeric_value": [[None, None], [None, None]], "start_time": [datetime(1980, 12, 28), datetime(1978, 6, 20)], @@ -901,16 +902,16 @@ schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [[6, 7]], "numeric_value": [[None, None]], "start_time": [datetime(1986, 7, 20)], @@ -928,18 +929,9 @@ ), } -TOKENIZATION_CODE = """ -```python - ->>> import polars as pl ->>> from tests.test_multi_stage_preprocess_pipeline import WANT_NORMALIZATION as dfs ->>> - -``` -""" TOKENIZATION_EVENT_SEQS_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(pl.Float32)), "time_delta_days": pl.List(pl.Float32), @@ -948,7 +940,7 @@ WANT_TOKENIZATION_EVENT_SEQS = { "tokenization/event_seqs/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [ [[10, 4], [11, 2, 1, 8, 9], [11, 2, 8, 9], [12, 2, 8, 9], [12, 2, 8, 9], [12, 2, 3]], [ @@ -995,16 +987,16 @@ schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [ [ [10, 4], @@ -1057,7 +1049,7 @@ def test_pipeline(): multi_stage_transform_tester( transform_scripts=[ - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, 
AGGREGATE_CODE_METADATA_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, @@ -1068,7 +1060,7 @@ def test_pipeline(): TENSORIZATION_SCRIPT, ], stage_names=[ - "filter_patients", + "filter_subjects", "add_time_derived_measurements", "fit_outlier_detection", "occlude_outliers", @@ -1093,6 +1085,5 @@ def test_pipeline(): **WANT_TOKENIZATION_EVENT_SEQS, **WANT_NRTs, }, - outputs_from_cohort_dir=True, input_code_metadata=MEDS_CODE_METADATA, ) diff --git a/tests/test_normalization.py b/tests/MEDS_Transforms/test_normalization.py similarity index 94% rename from tests/test_normalization.py rename to tests/MEDS_Transforms/test_normalization.py index 46992ed..4cc21ae 100644 --- a/tests/test_normalization.py +++ b/tests/MEDS_Transforms/test_normalization.py @@ -6,13 +6,14 @@ import polars as pl -from .transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +from tests.MEDS_Transforms import NORMALIZATION_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata file we'll use in this transform test. It is different than the default as we need # a code/vocab_index MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index ADMISSION//CARDIAC,2,2,0,,,1 ADMISSION//ORTHOPEDIC,1,1,0,,,2 ADMISSION//PULMONARY,1,1,0,,,3 @@ -129,7 +130,7 @@ # TEMP: 11 WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,7, 239684,,9,1.5770268440246582 239684,"12/28/1980, 00:00:00",5, @@ -163,7 +164,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,8, 68729,,9,-0.5438239574432373 68729,"03/09/1978, 00:00:00",5, @@ -181,7 +182,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,7, 754281,,9,0.28697699308395386 754281,"12/19/1988, 00:00:00",5, @@ -192,7 +193,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,7, 1500733,,9,-0.7995940446853638 1500733,"07/20/1986, 00:00:00",5, diff --git a/tests/test_occlude_outliers.py b/tests/MEDS_Transforms/test_occlude_outliers.py similarity index 91% rename from tests/test_occlude_outliers.py rename to tests/MEDS_Transforms/test_occlude_outliers.py index 63e9376..8e30db6 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/MEDS_Transforms/test_occlude_outliers.py @@ -7,12 +7,13 @@ import polars as pl -from .transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +from tests.MEDS_Transforms import OCCLUDE_OUTLIERS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -75,7 +76,7 @@ """ # noqa: E501 WANT_TRAIN_0 = """ 
-patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00",DOB,, @@ -109,7 +110,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 68729,,EYE_COLOR//HAZEL,, 68729,,HEIGHT,160.3953106166676,true 68729,"03/09/1978, 00:00:00",DOB,, @@ -127,7 +128,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 754281,,EYE_COLOR//BROWN,, 754281,,HEIGHT,166.22261567137025,true 754281,"12/19/1988, 00:00:00",DOB,, @@ -138,7 +139,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,158.60131573580904,true 1500733,"07/20/1986, 00:00:00",DOB,, diff --git a/tests/test_reorder_measurements.py b/tests/MEDS_Transforms/test_reorder_measurements.py similarity index 90% rename from tests/test_reorder_measurements.py rename to tests/MEDS_Transforms/test_reorder_measurements.py index c90dee4..782c394 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/MEDS_Transforms/test_reorder_measurements.py @@ -5,8 +5,9 @@ """ -from .transform_tester_base import REORDER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms import REORDER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ "ADMISSION.*", @@ -19,7 +20,7 @@ WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -53,7 +54,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,,EYE_COLOR//HAZEL, 68729,"03/09/1978, 00:00:00",DOB, @@ -71,7 +72,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -82,7 +83,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, diff --git a/tests/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py similarity index 84% rename from tests/test_reshard_to_split.py rename to tests/MEDS_Transforms/test_reshard_to_split.py index 65056e5..d0094a9 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -5,8 +5,11 @@ """ -from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from meds import subject_id_field + +from tests.MEDS_Transforms import RESHARD_TO_SPLIT_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester +from tests.utils import parse_meds_csvs IN_SHARDS_MAP = { "0": [68729, 1195293], @@ -14,8 +17,8 @@ "2": [239684, 1500733], } -IN_SHARD_0 = """ -patient_id,time,code,numeric_value +IN_SHARD_0 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 
68729,"03/09/1978, 00:00:00",DOB, @@ -42,8 +45,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -IN_SHARD_1 = """ -patient_id,time,code,numeric_value +IN_SHARD_1 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -60,8 +63,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -IN_SHARD_2 = """ -patient_id,time,code,numeric_value +IN_SHARD_2 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -94,8 +97,8 @@ "held_out": [1500733], } -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -128,8 +131,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -146,8 +149,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -157,8 +160,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -194,9 +197,20 @@ def test_reshard_to_split(): single_stage_transform_tester( transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", - transform_stage_kwargs={"n_patients_per_shard": 2}, + transform_stage_kwargs={"n_subjects_per_shard": 2}, + want_data=WANT_SHARDS, + input_shards=IN_SHARDS, + input_shards_map=IN_SHARDS_MAP, + input_splits_map=SPLITS, + ) + + single_stage_transform_tester( + transform_script=RESHARD_TO_SPLIT_SCRIPT, + stage_name="reshard_to_split", + transform_stage_kwargs={"n_patients_per_shard": 2, "+train_only": True}, want_data=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, input_splits_map=SPLITS, + should_error=True, ) diff --git a/tests/test_tensorization.py b/tests/MEDS_Transforms/test_tensorization.py similarity index 77% rename from tests/test_tensorization.py rename to tests/MEDS_Transforms/test_tensorization.py index 0337155..f56d064 100644 --- a/tests/test_tensorization.py +++ b/tests/MEDS_Transforms/test_tensorization.py @@ -9,8 +9,9 @@ from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS -from .transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import TENSORIZATION_SCRIPT +from tests.MEDS_Transforms.test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester WANT_NRTS = { f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( diff --git a/tests/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py similarity index 86% rename from tests/test_tokenization.py rename to tests/MEDS_Transforms/test_tokenization.py index 693add1..470b425 100644 --- a/tests/test_tokenization.py +++ 
b/tests/MEDS_Transforms/test_tokenization.py @@ -10,9 +10,11 @@ import polars as pl +from tests.MEDS_Transforms import TOKENIZATION_SCRIPT + from .test_normalization import NORMALIZED_MEDS_SCHEMA from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS -from .transform_tester_base import TOKENIZATION_SCRIPT, single_stage_transform_tester +from .transform_tester_base import single_stage_transform_tester SECONDS_PER_DAY = 60 * 60 * 24 @@ -20,17 +22,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: """TODO: Doctests""" out = [] - for patient_ts in ts: + for subject_ts in ts: out.append([float("nan")]) - for i in range(1, len(patient_ts)): - out[-1].append((patient_ts[i] - patient_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) + for i in range(1, len(subject_ts)): + out[-1].append((subject_ts[i] - subject_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) return out # TODO: Make these schemas exportable, maybe??? # TODO: Why is the code getting converted to a float? SCHEMAS_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(NORMALIZED_MEDS_SCHEMA["code"]), "numeric_value": pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"]), "start_time": NORMALIZED_MEDS_SCHEMA["time"], @@ -38,7 +40,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: } SEQ_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"])), "time_delta_days": pl.List(pl.Float32), @@ -66,7 +68,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: ] WANT_SCHEMAS_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "code": [[7, 9], [6, 9]], "numeric_value": [[None, 1.5770268440246582], [None, 0.06802856922149658]], "start_time": [ts[0] for ts in TRAIN_0_TIMES], @@ -77,7 +79,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "time_delta_days": ts_to_time_delta_days(TRAIN_0_TIMES), "code": [ [[5], [1, 10, 11], [10, 11], [10, 11], [10, 11], [4]], @@ -114,7 +116,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "code": [[8, 9], [8, 9]], "numeric_value": [[None, -0.5438239574432373], [None, -1.1012336015701294]], "start_time": [ts[0] for ts in TRAIN_1_TIMES], @@ -125,7 +127,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "time_delta_days": ts_to_time_delta_days(TRAIN_1_TIMES), "code": [[[5], [3, 10, 11], [4]], [[5], [2, 10, 11], [4]]], "numeric_value": [ @@ -140,7 +142,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "code": [[7, 9]], "numeric_value": [[None, 0.28697699308395386]], "start_time": [ts[0] for ts in TUNING_0_TIMES], @@ -151,7 +153,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "time_delta_days": 
ts_to_time_delta_days(TUNING_0_TIMES), "code": [[[5], [3, 10, 11], [4]]], "numeric_value": [ @@ -174,7 +176,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "code": [[7, 9]], "numeric_value": [[None, -0.7995940446853638]], "start_time": [ts[0] for ts in HELD_OUT_0_TIMES], @@ -185,7 +187,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "time_delta_days": ts_to_time_delta_days(HELD_OUT_0_TIMES), "code": [[[5], [2, 10, 11], [10, 11], [10, 11], [4]]], "numeric_value": [ @@ -224,4 +226,14 @@ def test_tokenization(): transform_stage_kwargs=None, input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + df_check_kwargs={"check_column_order": False}, + ) + + single_stage_transform_tester( + transform_script=TOKENIZATION_SCRIPT, + stage_name="tokenization", + transform_stage_kwargs={"train_only": True}, + input_shards=NORMALIZED_SHARDS, + want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + should_error=True, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py new file mode 100644 index 0000000..6a1d4ab --- /dev/null +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -0,0 +1,275 @@ +"""Base helper code and data inputs for all transforms integration tests. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + + +try: + pass +except ImportError: + pass + +from collections import defaultdict +from io import StringIO +from pathlib import Path + +import polars as pl +from meds import subject_id_field + +from tests.utils import FILE_T, multi_stage_tester, parse_meds_csvs, parse_shards_yaml, single_stage_tester + +# So it can be imported from here +parse_shards_yaml = parse_shards_yaml + +# Test MEDS data (inputs) + +SHARDS = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +SPLITS = { + "train": [239684, 1195293, 68729, 814703], + "tuning": [754281], + "held_out": [1500733], +} + +MEDS_TRAIN_0 = """ +subject_id,time,code,numeric_value +239684,,EYE_COLOR//BROWN, +239684,,HEIGHT,175.271115221764 +239684,"12/28/1980, 00:00:00",DOB, +239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, +239684,"05/11/2010, 17:41:51",HR,102.6 +239684,"05/11/2010, 17:41:51",TEMP,96.0 +239684,"05/11/2010, 17:48:48",HR,105.1 +239684,"05/11/2010, 17:48:48",TEMP,96.2 +239684,"05/11/2010, 18:25:35",HR,113.4 +239684,"05/11/2010, 18:25:35",TEMP,95.8 +239684,"05/11/2010, 18:57:18",HR,112.6 +239684,"05/11/2010, 18:57:18",TEMP,95.5 +239684,"05/11/2010, 19:27:19",DISCHARGE, +1195293,,EYE_COLOR//BLUE, +1195293,,HEIGHT,164.6868838269085 +1195293,"06/20/1978, 00:00:00",DOB, +1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, +1195293,"06/20/2010, 19:23:52",HR,109.0 +1195293,"06/20/2010, 19:23:52",TEMP,100.0 +1195293,"06/20/2010, 19:25:32",HR,114.1 +1195293,"06/20/2010, 19:25:32",TEMP,100.0 +1195293,"06/20/2010, 19:45:19",HR,119.8 +1195293,"06/20/2010, 19:45:19",TEMP,99.9 +1195293,"06/20/2010, 20:12:31",HR,112.5 +1195293,"06/20/2010, 20:12:31",TEMP,99.8 +1195293,"06/20/2010, 20:24:44",HR,107.7 +1195293,"06/20/2010, 20:24:44",TEMP,100.0 +1195293,"06/20/2010, 20:41:33",HR,107.5 +1195293,"06/20/2010, 20:41:33",TEMP,100.4 +1195293,"06/20/2010, 20:50:04",DISCHARGE, +""" + 
+MEDS_TRAIN_1 = """ +subject_id,time,code,numeric_value +68729,,EYE_COLOR//HAZEL, +68729,,HEIGHT,160.3953106166676 +68729,"03/09/1978, 00:00:00",DOB, +68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, +68729,"05/26/2010, 02:30:56",HR,86.0 +68729,"05/26/2010, 02:30:56",TEMP,97.8 +68729,"05/26/2010, 04:51:52",DISCHARGE, +814703,,EYE_COLOR//HAZEL, +814703,,HEIGHT,156.48559093209357 +814703,"03/28/1976, 00:00:00",DOB, +814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, +814703,"02/05/2010, 05:55:39",HR,170.2 +814703,"02/05/2010, 05:55:39",TEMP,100.1 +814703,"02/05/2010, 07:02:30",DISCHARGE, +""" + +MEDS_TUNING_0 = """ +subject_id,time,code,numeric_value +754281,,EYE_COLOR//BROWN, +754281,,HEIGHT,166.22261567137025 +754281,"12/19/1988, 00:00:00",DOB, +754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, +754281,"01/03/2010, 06:27:59",HR,142.0 +754281,"01/03/2010, 06:27:59",TEMP,99.8 +754281,"01/03/2010, 08:22:13",DISCHARGE, +""" + +MEDS_HELD_OUT_0 = """ +subject_id,time,code,numeric_value +1500733,,EYE_COLOR//BROWN, +1500733,,HEIGHT,158.60131573580904 +1500733,"07/20/1986, 00:00:00",DOB, +1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, +1500733,"06/03/2010, 14:54:38",HR,91.4 +1500733,"06/03/2010, 14:54:38",TEMP,100.0 +1500733,"06/03/2010, 15:39:49",HR,84.4 +1500733,"06/03/2010, 15:39:49",TEMP,100.3 +1500733,"06/03/2010, 16:20:49",HR,90.1 +1500733,"06/03/2010, 16:20:49",TEMP,100.1 +1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" + +MEDS_SHARDS = parse_meds_csvs( + { + "train/0": MEDS_TRAIN_0, + "train/1": MEDS_TRAIN_1, + "tuning/0": MEDS_TUNING_0, + "held_out/0": MEDS_HELD_OUT_0, + } +) + + +MEDS_CODE_METADATA_CSV = """ +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes +,44,4,28,3198.8389005974336,382968.28937288234,, +ADMISSION//CARDIAC,2,2,0,,,, +ADMISSION//ORTHOPEDIC,1,1,0,,,, +ADMISSION//PULMONARY,1,1,0,,,, +DISCHARGE,4,4,0,,,, +DOB,4,4,0,,,, +EYE_COLOR//BLUE,1,1,0,,,"Blue Eyes. Less common than brown.", +EYE_COLOR//BROWN,1,1,0,,,"Brown Eyes. The most common eye color.", +EYE_COLOR//HAZEL,2,2,0,,,"Hazel eyes. 
These are uncommon", +HEIGHT,4,4,4,656.8389005974336,108056.12937288235,, +HR,12,4,12,1360.5000000000002,158538.77,"Heart Rate",LOINC/8867-4 +TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 +""" + +MEDS_CODE_METADATA_SCHEMA = { + "code": pl.Utf8, + "code/n_occurrences": pl.UInt8, + "code/n_subjects": pl.UInt8, + "values/n_occurrences": pl.UInt8, + "values/n_subjects": pl.UInt8, + "values/sum": pl.Float32, + "values/sum_sqd": pl.Float32, + "values/n_ints": pl.UInt8, + "values/min": pl.Float32, + "values/max": pl.Float32, + "description": pl.Utf8, + "parent_codes": pl.Utf8, + "code/vocab_index": pl.UInt8, +} + + +def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: + cols = csv_str.strip().split("\n")[0].split(",") + schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols} + df = pl.read_csv(StringIO(csv_str), schema=schema) + if "parent_codes" in cols: + df = df.with_columns(pl.col("parent_codes").cast(pl.List(pl.Utf8))) + return df + + +MEDS_CODE_METADATA = parse_code_metadata_csv(MEDS_CODE_METADATA_CSV) + + +def remap_inputs_for_transform( + input_code_metadata: pl.DataFrame | str | None = None, + input_shards: dict[str, pl.DataFrame] | None = None, + input_shards_map: dict[str, list[int]] | None = None, + input_splits_map: dict[str, list[int]] | None = None, +) -> dict[str, FILE_T]: + unified_inputs = {} + + if input_code_metadata is None: + input_code_metadata = MEDS_CODE_METADATA + elif isinstance(input_code_metadata, str): + input_code_metadata = parse_code_metadata_csv(input_code_metadata) + + unified_inputs["metadata/codes.parquet"] = input_code_metadata + + if input_shards is None: + input_shards = MEDS_SHARDS + + for shard_name, df in input_shards.items(): + unified_inputs[f"data/{shard_name}.parquet"] = df + + if input_shards_map is None: + input_shards_map = SHARDS + + unified_inputs["metadata/.shards.json"] = input_shards_map + + if input_splits_map is None: + input_splits_map = SPLITS + + input_splits_as_df = defaultdict(list) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) + + input_splits_df = pl.DataFrame(input_splits_as_df) + + unified_inputs["metadata/subject_splits.parquet"] = input_splits_df + + return unified_inputs + + +def single_stage_transform_tester( + transform_script: str | Path, + stage_name: str, + transform_stage_kwargs: dict[str, str] | None, + do_pass_stage_name: bool = False, + do_use_config_yaml: bool = False, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: pl.DataFrame | None = None, + assert_no_other_outputs: bool = True, + should_error: bool = False, + df_check_kwargs: dict | None = None, + **input_data_kwargs, +): + if df_check_kwargs is None: + df_check_kwargs = {} + + base_kwargs = { + "script": transform_script, + "stage_name": stage_name, + "stage_kwargs": transform_stage_kwargs, + "do_pass_stage_name": do_pass_stage_name, + "do_use_config_yaml": do_use_config_yaml, + "assert_no_other_outputs": assert_no_other_outputs, + "should_error": should_error, + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + "df_check_kwargs": df_check_kwargs, + } + + want_outputs = {} + if want_data: + for data_fn, want in want_data.items(): + want_outputs[f"data/{data_fn}"] = want + if want_metadata is not None: + want_outputs["metadata/codes.parquet"] = want_metadata + + base_kwargs["want_outputs"] 
= want_outputs + + single_stage_tester(**base_kwargs) + + +def multi_stage_transform_tester( + transform_scripts: list[str | Path], + stage_names: list[str], + stage_configs: dict[str, str] | str | None, + do_pass_stage_name: bool | dict[str, bool] = True, + want_data: dict[str, pl.DataFrame] | None = None, + want_metadata: pl.DataFrame | None = None, + **input_data_kwargs, +): + base_kwargs = { + "scripts": transform_scripts, + "stage_names": stage_names, + "stage_configs": stage_configs, + "do_pass_stage_name": do_pass_stage_name, + "assert_no_other_outputs": False, # TODO(mmd): eventually fix + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + "want_outputs": {**want_data, **want_metadata}, + } + + multi_stage_tester(**base_kwargs) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py deleted file mode 100644 index bca36ad..0000000 --- a/tests/transform_tester_base.py +++ /dev/null @@ -1,503 +0,0 @@ -"""Base helper code and data inputs for all transforms integration tests. - -Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed -scripts. -""" - -from yaml import load as load_yaml - -try: - from yaml import CLoader as Loader -except ImportError: - from yaml import Loader - -import json -import os -import tempfile -from collections import defaultdict -from contextlib import contextmanager -from io import StringIO -from pathlib import Path - -import numpy as np -import polars as pl -import rootutils -from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict - -from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -transforms_root = code_root / "transforms" -filters_root = code_root / "filters" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" - FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" - RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" - FILTER_PATIENTS_SCRIPT = filters_root / "filter_patients.py" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" - REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" - EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" - NORMALIZATION_SCRIPT = transforms_root / "normalization.py" - OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" - TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" - TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" -else: - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" - FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" - RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" - FILTER_PATIENTS_SCRIPT = "MEDS_transform-filter_patients" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" - REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" - EXTRACT_VALUES_SCRIPT = "MEDS_transform-extract_values" - NORMALIZATION_SCRIPT = "MEDS_transform-normalization" - OCCLUDE_OUTLIERS_SCRIPT = 
"MEDS_transform-occlude_outliers" - TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" - TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" - -# Test MEDS data (inputs) - -SHARDS = { - "train/0": [239684, 1195293], - "train/1": [68729, 814703], - "tuning/0": [754281], - "held_out/0": [1500733], -} - -SPLITS = { - "train": [239684, 1195293, 68729, 814703], - "tuning": [754281], - "held_out": [1500733], -} - -MEDS_TRAIN_0 = """ -patient_id,time,code,numeric_value -239684,,EYE_COLOR//BROWN, -239684,,HEIGHT,175.271115221764 -239684,"12/28/1980, 00:00:00",DOB, -239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, -239684,"05/11/2010, 17:41:51",HR,102.6 -239684,"05/11/2010, 17:41:51",TEMP,96.0 -239684,"05/11/2010, 17:48:48",HR,105.1 -239684,"05/11/2010, 17:48:48",TEMP,96.2 -239684,"05/11/2010, 18:25:35",HR,113.4 -239684,"05/11/2010, 18:25:35",TEMP,95.8 -239684,"05/11/2010, 18:57:18",HR,112.6 -239684,"05/11/2010, 18:57:18",TEMP,95.5 -239684,"05/11/2010, 19:27:19",DISCHARGE, -1195293,,EYE_COLOR//BLUE, -1195293,,HEIGHT,164.6868838269085 -1195293,"06/20/1978, 00:00:00",DOB, -1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, -1195293,"06/20/2010, 19:23:52",HR,109.0 -1195293,"06/20/2010, 19:23:52",TEMP,100.0 -1195293,"06/20/2010, 19:25:32",HR,114.1 -1195293,"06/20/2010, 19:25:32",TEMP,100.0 -1195293,"06/20/2010, 19:45:19",HR,119.8 -1195293,"06/20/2010, 19:45:19",TEMP,99.9 -1195293,"06/20/2010, 20:12:31",HR,112.5 -1195293,"06/20/2010, 20:12:31",TEMP,99.8 -1195293,"06/20/2010, 20:24:44",HR,107.7 -1195293,"06/20/2010, 20:24:44",TEMP,100.0 -1195293,"06/20/2010, 20:41:33",HR,107.5 -1195293,"06/20/2010, 20:41:33",TEMP,100.4 -1195293,"06/20/2010, 20:50:04",DISCHARGE, -""" - -MEDS_TRAIN_1 = """ -patient_id,time,code,numeric_value -68729,,EYE_COLOR//HAZEL, -68729,,HEIGHT,160.3953106166676 -68729,"03/09/1978, 00:00:00",DOB, -68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, -68729,"05/26/2010, 02:30:56",HR,86.0 -68729,"05/26/2010, 02:30:56",TEMP,97.8 -68729,"05/26/2010, 04:51:52",DISCHARGE, -814703,,EYE_COLOR//HAZEL, -814703,,HEIGHT,156.48559093209357 -814703,"03/28/1976, 00:00:00",DOB, -814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, -814703,"02/05/2010, 05:55:39",HR,170.2 -814703,"02/05/2010, 05:55:39",TEMP,100.1 -814703,"02/05/2010, 07:02:30",DISCHARGE, -""" - -MEDS_TUNING_0 = """ -patient_id,time,code,numeric_value -754281,,EYE_COLOR//BROWN, -754281,,HEIGHT,166.22261567137025 -754281,"12/19/1988, 00:00:00",DOB, -754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, -754281,"01/03/2010, 06:27:59",HR,142.0 -754281,"01/03/2010, 06:27:59",TEMP,99.8 -754281,"01/03/2010, 08:22:13",DISCHARGE, -""" - -MEDS_HELD_OUT_0 = """ -patient_id,time,code,numeric_value -1500733,,EYE_COLOR//BROWN, -1500733,,HEIGHT,158.60131573580904 -1500733,"07/20/1986, 00:00:00",DOB, -1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, -1500733,"06/03/2010, 14:54:38",HR,91.4 -1500733,"06/03/2010, 14:54:38",TEMP,100.0 -1500733,"06/03/2010, 15:39:49",HR,84.4 -1500733,"06/03/2010, 15:39:49",TEMP,100.3 -1500733,"06/03/2010, 16:20:49",HR,90.1 -1500733,"06/03/2010, 16:20:49",TEMP,100.1 -1500733,"06/03/2010, 16:44:26",DISCHARGE, -""" - -MEDS_SHARDS = parse_meds_csvs( - { - "train/0": MEDS_TRAIN_0, - "train/1": MEDS_TRAIN_1, - "tuning/0": MEDS_TUNING_0, - "held_out/0": MEDS_HELD_OUT_0, - } -) - - -MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes -,44,4,28,3198.8389005974336,382968.28937288234,, -ADMISSION//CARDIAC,2,2,0,,,, 
-ADMISSION//ORTHOPEDIC,1,1,0,,,, -ADMISSION//PULMONARY,1,1,0,,,, -DISCHARGE,4,4,0,,,, -DOB,4,4,0,,,, -EYE_COLOR//BLUE,1,1,0,,,"Blue Eyes. Less common than brown.", -EYE_COLOR//BROWN,1,1,0,,,"Brown Eyes. The most common eye color.", -EYE_COLOR//HAZEL,2,2,0,,,"Hazel eyes. These are uncommon", -HEIGHT,4,4,4,656.8389005974336,108056.12937288235,, -HR,12,4,12,1360.5000000000002,158538.77,"Heart Rate",LOINC/8867-4 -TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 -""" - -MEDS_CODE_METADATA_SCHEMA = { - "code": pl.Utf8, - "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, - "values/n_occurrences": pl.UInt8, - "values/n_patients": pl.UInt8, - "values/sum": pl.Float32, - "values/sum_sqd": pl.Float32, - "values/n_ints": pl.UInt8, - "values/min": pl.Float32, - "values/max": pl.Float32, - "description": pl.Utf8, - "parent_codes": pl.Utf8, - "code/vocab_index": pl.UInt8, -} - - -def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: - schema = {**MEDS_PL_SCHEMA, **schema_updates} - return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) - - -def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: - cols = csv_str.strip().split("\n")[0].split(",") - schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols} - df = pl.read_csv(StringIO(csv_str), schema=schema) - if "parent_codes" in cols: - df = df.with_columns(pl.col("parent_codes").cast(pl.List(pl.Utf8))) - return df - - -MEDS_CODE_METADATA = parse_code_metadata_csv(MEDS_CODE_METADATA_CSV) - - -def check_NRT_output( - output_fp: Path, - want_nrt: JointNestedRaggedTensorDict, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." - - got_nrt = JointNestedRaggedTensorDict.load(output_fp) - - # assert got_nrt.schema == want_nrt.schema, ( - # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" - # f"Wanted:\n{want_nrt.schema}\n" - # f"Got:\n{got_nrt.schema}" - # ) - - want_tensors = want_nrt.tensors - got_tensors = got_nrt.tensors - - assert got_tensors.keys() == want_tensors.keys(), ( - f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{list(want_tensors.keys())}\n" - f"Got:\n{list(got_tensors.keys())}" - ) - - for k in want_tensors.keys(): - want_v = want_tensors[k] - got_v = got_tensors[k] - - assert type(want_v) is type(got_v), ( - f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Wanted:\n{type(want_v)}\n" - f"Got:\n{type(got_v)}" - ) - - if isinstance(want_v, list): - assert len(want_v) == len(got_v), ( - f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" - f"Wanted:\n{len(want_v)}\n" - f"Got:\n{len(got_v)}" - ) - for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): - assert np.array_equal(want_i, got_i, equal_nan=True), ( - f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_i}\n" - f"Got:\n{got_i}" - ) - else: - assert np.array_equal(want_v, got_v, equal_nan=True), ( - f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_v}\n" - f"Got:\n{got_v}" - ) - - -def check_df_output( - output_fp: Path, - want_df: pl.DataFrame, - check_column_order: bool = False, - check_row_order: bool = True, - **kwargs, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." 
- - got_df = pl.read_parquet(output_fp, glob=False) - assert_df_equal( - want_df, - got_df, - (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), - check_column_order=check_column_order, - check_row_order=check_row_order, - **kwargs, - ) - - -@contextmanager -def input_MEDS_dataset( - input_code_metadata: pl.DataFrame | str | None = None, - input_shards: dict[str, pl.DataFrame] | None = None, - input_shards_map: dict[str, list[int]] | None = None, - input_splits_map: dict[str, list[int]] | None = None, -): - with tempfile.TemporaryDirectory() as d: - MEDS_dir = Path(d) / "MEDS_cohort" - cohort_dir = Path(d) / "output_cohort" - - MEDS_data_dir = MEDS_dir / "data" - MEDS_metadata_dir = MEDS_dir / "metadata" - - # Create the directories - MEDS_data_dir.mkdir(parents=True) - MEDS_metadata_dir.mkdir(parents=True) - cohort_dir.mkdir(parents=True) - - # Write the shards map - if input_shards_map is None: - input_shards_map = SHARDS - - shards_fp = MEDS_metadata_dir / ".shards.json" - shards_fp.write_text(json.dumps(input_shards_map)) - - # Write the splits parquet file - if input_splits_map is None: - input_splits_map = SPLITS - input_splits_as_df = defaultdict(list) - for split_name, patient_ids in input_splits_map.items(): - input_splits_as_df["patient_id"].extend(patient_ids) - input_splits_as_df["split"].extend([split_name] * len(patient_ids)) - input_splits_df = pl.DataFrame(input_splits_as_df) - input_splits_fp = MEDS_metadata_dir / "patient_splits.parquet" - input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) - - if input_shards is None: - input_shards = MEDS_SHARDS - - # Write the shards - for shard_name, df in input_shards.items(): - fp = MEDS_data_dir / f"{shard_name}.parquet" - fp.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(fp, use_pyarrow=True) - - code_metadata_fp = MEDS_metadata_dir / "codes.parquet" - if input_code_metadata is None: - input_code_metadata = MEDS_CODE_METADATA - elif isinstance(input_code_metadata, str): - input_code_metadata = parse_code_metadata_csv(input_code_metadata) - input_code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True) - - yield MEDS_dir, cohort_dir - - -def check_outputs( - cohort_dir: Path, - want_data: dict[str, pl.DataFrame] | None = None, - want_metadata: dict[str, pl.DataFrame] | pl.DataFrame | None = None, - assert_no_other_outputs: bool = True, - outputs_from_cohort_dir: bool = False, -): - if want_metadata is not None: - if isinstance(want_metadata, pl.DataFrame): - want_metadata = {"codes.parquet": want_metadata} - metadata_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "metadata" - for shard_name, want in want_metadata.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - check_df_output(metadata_root / shard_name, want) - - if want_data: - data_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "data" - all_file_suffixes = set() - for shard_name, want in want_data.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - - file_suffix = Path(shard_name).suffix - all_file_suffixes.add(file_suffix) - - output_fp = data_root / f"{shard_name}" - if file_suffix == ".parquet": - check_df_output(output_fp, want) - elif file_suffix == ".nrt": - check_NRT_output(output_fp, want) - else: - raise ValueError(f"Unknown file suffix: {file_suffix}") - - if assert_no_other_outputs: - all_outputs = [] - for suffix in all_file_suffixes: - all_outputs.extend(list((data_root).glob(f"**/*{suffix}"))) - assert 
len(want_data) == len(all_outputs), ( - f"Want {len(want_data)} outputs, but found {len(all_outputs)}.\n" - f"Found outputs: {[fp.relative_to(data_root) for fp in all_outputs]}\n" - ) - - -def single_stage_transform_tester( - transform_script: str | Path, - stage_name: str, - transform_stage_kwargs: dict[str, str] | None, - do_pass_stage_name: bool = False, - do_use_config_yaml: bool = False, - want_data: dict[str, pl.DataFrame] | None = None, - want_metadata: pl.DataFrame | None = None, - assert_no_other_outputs: bool = True, - **input_data_kwargs, -): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": [stage_name], - "hydra.verbose": True, - } - - if transform_stage_kwargs: - pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs} - - run_command_kwargs = { - "script": transform_script, - "hydra_kwargs": pipeline_config_kwargs, - "test_name": f"Single stage transform: {stage_name}", - } - if do_use_config_yaml: - run_command_kwargs["do_use_config_yaml"] = True - run_command_kwargs["config_name"] = "preprocess" - if do_pass_stage_name: - run_command_kwargs["stage"] = stage_name - run_command_kwargs["do_pass_stage_name"] = True - - # Run the transform - stderr, stdout = run_command(**run_command_kwargs) - - try: - check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) - except Exception as e: - raise AssertionError( - f"Single stage transform {stage_name} failed.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" - ) from e - - -def multi_stage_transform_tester( - transform_scripts: list[str | Path], - stage_names: list[str], - stage_configs: dict[str, str] | str | None, - do_pass_stage_name: bool | dict[str, bool] = True, - want_data: dict[str, pl.DataFrame] | None = None, - want_metadata: pl.DataFrame | None = None, - outputs_from_cohort_dir: bool = True, - **input_data_kwargs, -): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - match stage_configs: - case None: - stage_configs = {} - case str(): - stage_configs = load_yaml(stage_configs, Loader=Loader) - case dict(): - pass - case _: - raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") - - match do_pass_stage_name: - case True: - do_pass_stage_name = {stage_name: True for stage_name in stage_names} - case False: - do_pass_stage_name = {stage_name: False for stage_name in stage_names} - case dict(): - pass - case _: - raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") - - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": stage_names, - "stage_configs": stage_configs, - "hydra.verbose": True, - } - - script_outputs = {} - n_stages = len(stage_names) - for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)): - script_outputs[stage] = run_command( - script=script, - hydra_kwargs=pipeline_config_kwargs, - do_use_config_yaml=True, - config_name="preprocess", - test_name=f"Multi stage transform {i}/{n_stages}: {stage}", - stage_name=stage, - do_pass_stage_name=do_pass_stage_name[stage], - ) - - check_outputs( - cohort_dir, - want_data=want_data, - want_metadata=want_metadata, - outputs_from_cohort_dir=outputs_from_cohort_dir, - assert_no_other_outputs=False, # this currently doesn't work due to metadata / data confusions. 
-        )
diff --git a/tests/utils.py b/tests/utils.py
index e7220c9..450a86b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,17 +1,28 @@
+import json
 import subprocess
 import tempfile
+from contextlib import contextmanager
 from io import StringIO
 from pathlib import Path
+from typing import Any
 
+import numpy as np
 import polars as pl
+from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
 from omegaconf import OmegaConf
 from polars.testing import assert_frame_equal
+from yaml import load as load_yaml
+
+try:
+    from yaml import CLoader as Loader
+except ImportError:
+    from yaml import Loader
 
 DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S"
 
 # TODO: Make use meds library
 MEDS_PL_SCHEMA = {
-    "patient_id": pl.UInt32,
+    "subject_id": pl.Int64,
     "time": pl.Datetime("us"),
    "code": pl.String,
    "numeric_value": pl.Float32,
@@ -44,6 +55,11 @@ def reader(csv_str: str) -> pl.DataFrame:
     return {k: reader(v) for k, v in csvs.items()}
 
 
+def parse_shards_yaml(yaml_str: str, **schema_updates) -> dict[str, pl.DataFrame]:
+    schema = {**MEDS_PL_SCHEMA, **schema_updates}
+    return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema)
+
+
 def dict_to_hydra_kwargs(d: dict[str, str]) -> str:
     """Converts a dictionary to a hydra kwargs string for testing purposes.
 
@@ -70,6 +86,8 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str:
         ValueError: Unexpected type for value for key a: <class 'datetime.datetime'>: 2021-11-01 00:00:00
     """
+    modifier_chars = ["~", "'", "++", "+"]
+
     out = []
     for k, v in d.items():
         if not isinstance(k, str):
@@ -86,11 +104,13 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str:
             case dict():
                 inner_kwargs = dict_to_hydra_kwargs(v)
                 for inner_kv in inner_kwargs:
-                    if inner_kv.startswith("~"):
-                        out.append(f"~{k}.{inner_kv[1:]}")
-                    elif inner_kv.startswith("'"):
-                        out.append(f"'{k}.{inner_kv[1:]}")
-                    else:
+                    handled = False
+                    for mod in modifier_chars:
+                        if inner_kv.startswith(mod):
+                            out.append(f"{mod}{k}.{inner_kv[len(mod):]}")
+                            handled = True
+                            break
+                    if not handled:
                         out.append(f"{k}.{inner_kv}")
             case list() | tuple():
                 v = list(v)
@@ -180,11 +200,286 @@ def run_command(
 
 def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kwargs):
     try:
+        update_exprs = {}
+        for k, v in want.schema.items():
+            assert k in got.schema, f"missing column {k}."
+            if kwargs.get("check_dtypes", False):
+                assert v == got.schema[k], f"column {k} has different types."
+            if v == pl.List(pl.String) and got.schema[k] == pl.List(pl.String):
+                # Collate list-of-string columns into joined strings so the frames compare cleanly.
+                update_exprs[k] = pl.col(k).list.join("||")
+        if update_exprs:
+            want_cols = want.columns
+            got_cols = got.columns
+
+            want = want.with_columns(**update_exprs).select(want_cols)
+            got = got.with_columns(**update_exprs).select(got_cols)
+
         assert_frame_equal(want, got, **kwargs)
     except AssertionError as e:
         pl.Config.set_tbl_rows(-1)
-        print(f"DFs are not equal: {msg}\nwant:")
-        print(want)
-        print("got:")
-        print(got)
-        raise AssertionError(f"{msg}\n{e}") from e
+        raise AssertionError(f"{msg}:\nWant:\n{want}\nGot:\n{got}\n{e}") from e
+
+
+def check_NRT_output(
+    output_fp: Path,
+    want_nrt: JointNestedRaggedTensorDict,
+    msg: str,
+):
+    got_nrt = JointNestedRaggedTensorDict.load(output_fp)
+
+    # assert got_nrt.schema == want_nrt.schema, (
+    #     f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n"
+    #     f"Wanted:\n{want_nrt.schema}\n"
+    #     f"Got:\n{got_nrt.schema}"
+    # )
+
+    want_tensors = want_nrt.tensors
+    got_tensors = got_nrt.tensors
+
+    assert got_tensors.keys() == want_tensors.keys(), (
+        f"{msg}:\n" f"Wanted:\n{list(want_tensors.keys())}\n" f"Got:\n{list(got_tensors.keys())}"
+    )
+
+    for k in want_tensors.keys():
+        want_v = want_tensors[k]
+        got_v = got_tensors[k]
+
+        assert type(want_v) is type(
+            got_v
+        ), f"{msg}: Wanted {k} to be of type {type(want_v)}, got {type(got_v)}."
+
+        if isinstance(want_v, list):
+            assert len(want_v) == len(got_v), (
+                f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n"
+                f"Wanted:\n{len(want_v)}\n"
+                f"Got:\n{len(got_v)}"
+            )
+            for i, (want_i, got_i) in enumerate(zip(want_v, got_v)):
+                assert np.array_equal(want_i, got_i, equal_nan=True), (
+                    f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n"
+                    f"Wanted:\n{want_i}\n"
+                    f"Got:\n{got_i}"
+                )
+        else:
+            assert np.array_equal(want_v, got_v, equal_nan=True), (
+                f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n"
+                f"Wanted:\n{want_v}\n"
+                f"Got:\n{got_v}"
+            )
+
+
+FILE_T = pl.DataFrame | dict[str, Any] | str
+
+
+@contextmanager
+def input_dataset(input_files: dict[str, FILE_T] | None = None):
+    with tempfile.TemporaryDirectory() as d:
+        input_dir = Path(d) / "input_cohort"
+        cohort_dir = Path(d) / "output_cohort"
+
+        # Materialize each input file in the temporary input directory based on its type and suffix.
+        for filename, data in input_files.items():
+            fp = input_dir / filename
+            fp.parent.mkdir(parents=True, exist_ok=True)
+
+            match data:
+                case pl.DataFrame() if fp.suffix == "":
+                    data.write_parquet(fp.with_suffix(".parquet"), use_pyarrow=True)
+                case pl.DataFrame() if fp.suffix == ".parquet":
+                    data.write_parquet(fp, use_pyarrow=True)
+                case pl.DataFrame() if fp.suffix == ".csv":
+                    data.write_csv(fp)
+                case dict() if fp.suffix == "":
+                    fp.with_suffix(".json").write_text(json.dumps(data))
+                case dict() if fp.suffix.endswith(".json"):
+                    fp.write_text(json.dumps(data))
+                case str():
+                    fp.write_text(data.strip())
+                case _:
+                    raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}")
+
+        yield input_dir, cohort_dir
+
+
+def check_outputs(
+    cohort_dir: Path,
+    want_outputs: dict[str, pl.DataFrame],
+    assert_no_other_outputs: bool = True,
+    **df_check_kwargs,
+):
+    all_file_suffixes = set()
+
+    for output_name, want in want_outputs.items():
+        if Path(output_name).suffix == "":
+            output_name = f"{output_name}.parquet"
+
+        file_suffix = Path(output_name).suffix
+        all_file_suffixes.add(file_suffix)
+
+        output_fp = cohort_dir / output_name
+
+        files_found = [str(fp.relative_to(cohort_dir)) for fp in cohort_dir.glob(f"**/*{file_suffix}")]
+
+        if not output_fp.is_file():
+            raise AssertionError(
+                f"Wanted {output_fp.relative_to(cohort_dir)} to exist. "
+                f"{len(files_found)} {file_suffix} files found: {', '.join(files_found)}"
+            )
+
+        msg = f"Expected {output_fp.relative_to(cohort_dir)} to be equal to the target"
+
+        match file_suffix:
+            case ".parquet":
+                got_df = pl.read_parquet(output_fp, glob=False)
+                assert_df_equal(want, got_df, msg=msg, **df_check_kwargs)
+            case ".nrt":
+                check_NRT_output(output_fp, want, msg=msg)
+            case ".json":
+                with open(output_fp) as f:
+                    got = json.load(f)
+                assert got == want, (
+                    f"Expected JSON at {output_fp.relative_to(cohort_dir)} to be equal to the target.\n"
+                    f"Wanted:\n{want}\n"
+                    f"Got:\n{got}"
+                )
+            case _:
+                raise ValueError(f"Unknown file suffix: {file_suffix}")
+
+    if assert_no_other_outputs:
+        all_outputs = []
+        for suffix in all_file_suffixes:
+            all_outputs.extend(list(cohort_dir.glob(f"**/*{suffix}")))
+        assert len(want_outputs) == len(all_outputs), (
+            f"Want {len(want_outputs)} outputs, but found {len(all_outputs)}.\n"
+            f"Found outputs: {[fp.relative_to(cohort_dir) for fp in all_outputs]}\n"
+        )
+
+
+def single_stage_tester(
+    script: str | Path,
+    stage_name: str,
+    stage_kwargs: dict[str, str] | None,
+    do_pass_stage_name: bool = False,
+    do_use_config_yaml: bool = False,
+    want_outputs: dict[str, pl.DataFrame] | None = None,
+    assert_no_other_outputs: bool = True,
+    should_error: bool = False,
+    config_name: str = "preprocess",
+    input_files: dict[str, FILE_T] | None = None,
+    df_check_kwargs: dict | None = None,
+    **pipeline_kwargs,
+):
+    if df_check_kwargs is None:
+        df_check_kwargs = {}
+
+    with input_dataset(input_files) as (input_dir, cohort_dir):
+        for k, v in pipeline_kwargs.items():
+            if type(v) is str and "{input_dir}" in v:
+                pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve()))
+
+        pipeline_config_kwargs = {
+            "input_dir": str(input_dir.resolve()),
+            "cohort_dir": str(cohort_dir.resolve()),
+            "stages": [stage_name],
+            "hydra.verbose": True,
+            **pipeline_kwargs,
+        }
+
+        if stage_kwargs:
+            pipeline_config_kwargs["stage_configs"] = {stage_name: stage_kwargs}
+
+        run_command_kwargs = {
+            "script": script,
+            "hydra_kwargs": pipeline_config_kwargs,
+            "test_name": f"Single stage transform: {stage_name}",
+            "should_error": should_error,
+            "config_name": config_name,
+            "do_use_config_yaml": do_use_config_yaml,
+        }
+
+        if do_pass_stage_name:
+            run_command_kwargs["stage"] = stage_name
+            run_command_kwargs["do_pass_stage_name"] = True
+
+        # Run the transform
+        stderr, stdout = run_command(**run_command_kwargs)
+        if should_error:
+            return
+
+        try:
+            check_outputs(
+                cohort_dir,
+                want_outputs=want_outputs,
+                assert_no_other_outputs=assert_no_other_outputs,
+                **df_check_kwargs,
+            )
+        except Exception as e:
+            raise AssertionError(
+                f"Single stage transform {stage_name} failed -- {e}:\n"
+                f"Script stdout:\n{stdout}\n"
+                f"Script stderr:\n{stderr}\n"
+            ) from e
+
+
+def multi_stage_tester(
+    scripts: list[str | Path],
+    stage_names: list[str],
+    stage_configs: dict[str, str] | str | None,
+    do_pass_stage_name: bool | dict[str, bool] = True,
+    want_outputs: dict[str, pl.DataFrame] | None = None,
+    assert_no_other_outputs: bool = False,
+    config_name: str = "preprocess",
+    input_files: dict[str, FILE_T] | None = None,
+    **pipeline_kwargs,
+):
+    with input_dataset(input_files) as (input_dir, cohort_dir):
+        match stage_configs:
+            case None:
+                stage_configs = {}
+            case str():
+                stage_configs = load_yaml(stage_configs, Loader=Loader)
+            case dict():
+                pass
+            case _:
+                raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}")
+
+        match do_pass_stage_name:
+            case True:
+                do_pass_stage_name = {stage_name: True for stage_name in stage_names}
+            case False:
+                do_pass_stage_name = {stage_name: False for stage_name in stage_names}
+            case dict():
+                pass
+            case _:
+                raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}")
+
+        pipeline_config_kwargs = {
+            "input_dir": str(input_dir.resolve()),
+            "cohort_dir": str(cohort_dir.resolve()),
+            "stages": stage_names,
+            "stage_configs": stage_configs,
+            "hydra.verbose": True,
+            **pipeline_kwargs,
+        }
+
+        script_outputs = {}
+        n_stages = len(stage_names)
+        for i, (stage, script) in enumerate(zip(stage_names, scripts)):
+            script_outputs[stage] = run_command(
+                script=script,
+                hydra_kwargs=pipeline_config_kwargs,
+                do_use_config_yaml=True,
+                config_name=config_name,
+                test_name=f"Multi stage transform {i}/{n_stages}: {stage}",
+                stage_name=stage,
+                do_pass_stage_name=do_pass_stage_name[stage],
+            )
+
+        try:
+            check_outputs(
+                cohort_dir,
+                want_outputs=want_outputs,
+                assert_no_other_outputs=assert_no_other_outputs,
+                check_column_order=False,
+            )
+        except Exception as e:
+            raise AssertionError(f"{n_stages}-stage pipeline ({stage_names}) failed--{e}") from e