From b899a8913868afcc469f9e9c72e8ff30fef6b5ad Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 21 Aug 2024 17:35:32 -0400 Subject: [PATCH 01/23] Updating to MEDS v0.3.2 by correcting the subject ID field name. --- MIMIC-IV_Example/README.md | 16 +-- MIMIC-IV_Example/configs/event_configs.yaml | 14 +- MIMIC-IV_Example/joint_script.sh | 8 +- MIMIC-IV_Example/joint_script_slurm.sh | 6 +- README.md | 72 +++++----- eICU_Example/README.md | 32 ++--- eICU_Example/configs/event_configs.yaml | 6 +- eICU_Example/joint_script.sh | 10 +- eICU_Example/joint_script_slurm.sh | 6 +- pyproject.toml | 6 +- src/MEDS_transforms/__init__.py | 11 +- .../aggregate_code_metadata.py | 75 +++++----- src/MEDS_transforms/configs/extract.yaml | 4 +- src/MEDS_transforms/configs/preprocess.yaml | 4 +- .../stage_configs/count_code_occurrences.yaml | 2 +- .../stage_configs/filter_measurements.yaml | 2 +- .../stage_configs/filter_patients.yaml | 3 - .../stage_configs/filter_subjects.yaml | 3 + .../stage_configs/fit_normalization.yaml | 2 +- .../stage_configs/reshard_to_split.yaml | 2 +- ...nts.yaml => split_and_shard_subjects.yaml} | 4 +- src/MEDS_transforms/extract/README.md | 22 +-- .../extract/convert_to_sharded_events.py | 80 +++++------ .../extract/extract_code_metadata.py | 8 +- .../extract/finalize_MEDS_data.py | 14 +- .../extract/finalize_MEDS_metadata.py | 29 ++-- .../extract/merge_to_MEDS_cohort.py | 36 ++--- src/MEDS_transforms/extract/shard_events.py | 23 +-- ...atients.py => split_and_shard_subjects.py} | 134 +++++++++--------- src/MEDS_transforms/filters/README.md | 2 +- .../filters/filter_measurements.py | 36 ++--- ...{filter_patients.py => filter_subjects.py} | 96 ++++++------- src/MEDS_transforms/mapreduce/mapper.py | 22 +-- src/MEDS_transforms/mapreduce/utils.py | 8 +- src/MEDS_transforms/reshard_to_split.py | 20 +-- .../add_time_derived_measurements.py | 46 +++--- .../transforms/normalization.py | 16 +-- .../transforms/occlude_outliers.py | 10 +- .../transforms/reorder_measurements.py | 22 +-- .../transforms/tensorization.py | 4 +- .../transforms/tokenization.py | 54 +++---- src/MEDS_transforms/utils.py | 22 +-- tests/test_add_time_derived_measurements.py | 19 +-- tests/test_aggregate_code_metadata.py | 10 +- tests/test_extract.py | 76 +++++----- tests/test_extract_no_metadata.py | 76 +++++----- tests/test_filter_measurements.py | 38 ++--- ...er_patients.py => test_filter_subjects.py} | 33 ++--- tests/test_fit_vocabulary_indices.py | 2 +- tests/test_multi_stage_preprocess_pipeline.py | 108 +++++++------- tests/test_normalization.py | 10 +- tests/test_occlude_outliers.py | 10 +- tests/test_reorder_measurements.py | 8 +- tests/test_reshard_to_split.py | 32 +++-- tests/test_tokenization.py | 26 ++-- tests/transform_tester_base.py | 27 ++-- tests/utils.py | 2 +- 57 files changed, 735 insertions(+), 734 deletions(-) delete mode 100644 src/MEDS_transforms/configs/stage_configs/filter_patients.yaml create mode 100644 src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml rename src/MEDS_transforms/configs/stage_configs/{split_and_shard_patients.yaml => split_and_shard_subjects.yaml} (73%) rename src/MEDS_transforms/extract/{split_and_shard_patients.py => split_and_shard_subjects.py} (71%) rename src/MEDS_transforms/filters/{filter_patients.py => filter_subjects.py} (67%) rename tests/{test_filter_patients.py => test_filter_subjects.py} (75%) diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index c003860..6bf348d 100644 --- a/MIMIC-IV_Example/README.md +++ 
b/MIMIC-IV_Example/README.md @@ -76,10 +76,10 @@ This is a step in a few parts: - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. -2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and +2. Convert the subject's static data to a more parseable form. This entails: + - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Merge the subject's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this @@ -109,14 +109,14 @@ This is a step in 4 parts: This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected format of the command. -2. Extract and form the patient splits and sub-shards. The `./scripts/extraction/split_and_shard_patients.py` +2. Extract and form the subject splits and sub-shards. The `./scripts/extraction/split_and_shard_subjects.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. -3. Extract patient sub-shards and convert to MEDS events. The +3. Extract subject sub-shards and convert to MEDS events. The `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. -4. Merge the MEDS events into a single file per patient sub-shard. The +4. Merge the MEDS events into a single file per subject sub-shard. The `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the expected format of the command. @@ -139,7 +139,7 @@ timeline which is otherwise stored at the _datetime_ resolution? Other questions: -1. How to handle merging the deathtimes between the hosp table and the patients table? +1. How to handle merging the deathtimes between the hosp table and the subjects table? 2. How to handle the dob nonsense MIMIC has? ## Notes @@ -153,4 +153,4 @@ may need to run `unset SLURM_CPU_BIND` in your terminal first to avoid errors. If you wanted, some other processing could also be done here, such as: -1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. +1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field. diff --git a/MIMIC-IV_Example/configs/event_configs.yaml b/MIMIC-IV_Example/configs/event_configs.yaml index 0cd0381..619d5a2 100644 --- a/MIMIC-IV_Example/configs/event_configs.yaml +++ b/MIMIC-IV_Example/configs/event_configs.yaml @@ -1,4 +1,4 @@ -patient_id_col: subject_id +subject_id_col: subject_id hosp/admissions: ed_registration: code: ED_REGISTRATION @@ -27,7 +27,7 @@ hosp/admissions: time: col(dischtime) time_format: "%Y-%m-%d %H:%M:%S" hadm_id: hadm_id - # We omit the death event here as it is joined to the data in the patients table in the pre-MEDS step. + # We omit the death event here as it is joined to the data in the subjects table in the pre-MEDS step. 
hosp/diagnoses_icd: diagnosis: @@ -108,7 +108,7 @@ hosp/omr: time: col(chartdate) time_format: "%Y-%m-%d" -hosp/patients: +hosp/subjects: gender: code: - GENDER @@ -295,18 +295,18 @@ icu/inputevents: description: ["omop_concept_name", "label"] # List of strings are columns to be collated itemid: "itemid (omop_source_code)" parent_codes: "{omop_vocabulary_id}/{omop_concept_code}" - patient_weight: + subject_weight: code: - - PATIENT_WEIGHT_AT_INFUSION + - SUBJECT_WEIGHT_AT_INFUSION - KG time: col(starttime) time_format: "%Y-%m-%d %H:%M:%S" - numeric_value: patientweight + numeric_value: subjectweight icu/outputevents: output: code: - - PATIENT_FLUID_OUTPUT + - SUBJECT_FLUID_OUTPUT - col(itemid) - col(valueuom) time: col(charttime) diff --git a/MIMIC-IV_Example/joint_script.sh b/MIMIC-IV_Example/joint_script.sh index a98fee7..dd1459c 100755 --- a/MIMIC-IV_Example/joint_script.sh +++ b/MIMIC-IV_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo echo "Arguments:" echo " MIMICIV_RAW_DIR Directory containing raw MIMIC-IV data files." @@ -88,11 +88,11 @@ MEDS_extract-shard_events \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects in serial" +MEDS_extract-split_and_shard_subjects \ input_dir="$MIMICIV_PREMEDS_DIR" \ cohort_dir="$MIMICIV_MEDS_DIR" \ - stage="split_and_shard_patients" \ + stage="split_and_shard_subjects" \ etl_metadata.dataset_name="MIMIC-IV" \ etl_metadata.dataset_version="2.2" \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@" diff --git a/MIMIC-IV_Example/joint_script_slurm.sh b/MIMIC-IV_Example/joint_script_slurm.sh index 3ff9684..e13fb7e 100755 --- a/MIMIC-IV_Example/joint_script_slurm.sh +++ b/MIMIC-IV_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes MIMIC-IV data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -72,8 +72,8 @@ MEDS_extract-shard_events \ event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml \ stage=shard_events -echo "Splitting patients on one worker" -MEDS_extract-split_and_shard_patients \ +echo "Splitting subjects on one worker" +MEDS_extract-split_and_shard_subjects \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/README.md b/README.md index e797d3f..a9e057d 100644 --- a/README.md +++ b/README.md @@ -45,12 +45,12 @@ directories. The fundamental design philosophy of this repository can be summarized as follows: 1. 
_(The MEDS Assumption)_: All structured electronic health record (EHR) data can be represented as a - series of events, each of which is associated with a patient, a time, and a set of codes and + series of events, each of which is associated with a subject, a time, and a set of codes and numeric values. This representation is the Medical Event Data Standard (MEDS) format, and in this - repository we use it in the "flat" format, where data is organized as rows of `patient_id`, + repository we use it in the "flat" format, where data is organized as rows of `subject_id`, `time`, `code`, `numeric_value` columns. 2. _Easy Efficiency through Sharding_: MEDS datasets in this repository are sharded into smaller, more - manageable pieces (organized as separate files) at the patient level (and, during the raw-data extraction + manageable pieces (organized as separate files) at the subject level (and, during the raw-data extraction process, the event level). This enables users to scale up their processing capabilities ad nauseum by leveraging more workers to process these shards in parallel. This parallelization is seamlessly enabled with the configuration schema used in the scripts in this repository. This style of parallelization @@ -62,7 +62,7 @@ The fundamental design philosophy of this repository can be summarized as follow the others, and each stage is designed to do a small amount of work and be easily testable in isolation. This design philosophy ensures that the pipeline is robust to changes, easy to debug, and easy to extend. In particular, to add new operations specific to a given model or dataset, the user need only write - simple functions that take in a flat MEDS dataframe (representing a single patient level shard) and + simple functions that take in a flat MEDS dataframe (representing a single subject level shard) and return a new flat MEDS dataframe, and then wrap that function in a script by following the examples provided in this repository. These individual functions can use the same configuration schema as other stages in the pipeline or include a separate, stage-specific configuration, and can use whatever @@ -198,9 +198,9 @@ To use this repository as a template, the user should follow these steps: Assumptions: 1. Your data is organized in a set of parquet files on disk such that each row of each file corresponds to - one or more measurements per patient and has all necessary information in that row to extract said - measurement, organized in a simple, columnar format. Each of these parquet files stores the patient's ID in - a column called `patient_id` in the same type. + one or more measurements per subject and has all necessary information in that row to extract said + measurement, organized in a simple, columnar format. Each of these parquet files stores the subject's ID in + a column called `subject_id` in the same type. 2. You have a pre-defined or can externally define the requisite MEDS base `code_metadata` file that describes the codes in your data as necessary. This file is not used in the provided pre-processing pipeline in this package, but is necessary for other uses of the MEDS data. @@ -221,16 +221,16 @@ The provided ETL consists of the following steps, which can be performed as need degree of parallelism is desired per step. 1. It re-shards the input data into a set of smaller, event-level shards to facilitate parallel processing. 
- This can be skipped if your input data is already suitably sharded at either a per-patient or per-event + This can be skipped if your input data is already suitably sharded at either a per-subject or per-event level. 2. It extracts the subject IDs from the sharded data and computes the set of ML splits and (per split) the - patient shards. These are stored in a JSON file in the output cohort directory. + subject shards. These are stored in a JSON file in the output cohort directory. 3. It converts the input, event level shards into the MEDS flat format and joins and shards these data into - patient-level shards for MEDS use and stores them in a nested format in the output cohort directory, + subject-level shards for MEDS use and stores them in a nested format in the output cohort directory, again in the flat format. This step can be broken down into two sub-steps: - - First, each input shard is converted to the MEDS flat format and split into sub patient-level shards. - - Second, the appropriate sub patient-level shards are joined and re-organized into the final - patient-level shards. This method ensures that we minimize the amount of read contention on the input + - First, each input shard is converted to the MEDS flat format and split into sub subject-level shards. + - Second, the appropriate sub subject-level shards are joined and re-organized into the final + subject-level shards. This method ensures that we minimize the amount of read contention on the input shards during the join process and can maximize parallel throughput, as (theoretically, with sufficient workers) all input shards can be sub-sharded in parallel and then all output shards can be joined in parallel. @@ -239,7 +239,7 @@ The ETL scripts all use [Hydra](https://hydra.cc/) for configuration management, `configs/extraction.yaml` file for configuration. The user can override any of these settings in the normal way for Hydra configurations. -If desired, appropriate scripts can be written and run at a per-patient shard level to convert between the +If desired, appropriate scripts can be written and run at a per-subject shard level to convert between the flat format and any of the other valid nested MEDS format, but for now we leave that up to the user. #### Input Event Extraction @@ -250,11 +250,11 @@ dataframes should be parsed into different event formats. The YAML file stores a following structure: ```yaml -patient_id: $GLOBAL_PATIENT_ID_OVERWRITE # Optional, if you want to overwrite the patient ID column name for - # all inputs. If not specified, defaults to "patient_id". +subject_id: $GLOBAL_SUBJECT_ID_OVERWRITE # Optional, if you want to overwrite the subject ID column name for + # all inputs. If not specified, defaults to "subject_id". $INPUT_FILE_STEM: - patient_id: $INPUT_FILE_PATIENT_ID # Optional, if you want to overwrite the patient ID column name for - # this input. IF not specified, defaults to the global patient ID. + subject_id: $INPUT_FILE_SUBJECT_ID # Optional, if you want to overwrite the subject ID column name for + # this input. IF not specified, defaults to the global subject ID. $EVENT_NAME: code: - $CODE_PART_1 @@ -287,18 +287,18 @@ script is a functional test that is also run with `pytest` to verify correctness 1. `scripts/extraction/shard_events.py` shards the input data into smaller, event-level shards by splitting raw files into chunks of a configurable number of rows. Files are split sequentially, with no regard for - data content or patient boundaries. 
The resulting files are stored in the `subsharded_events` + data content or subject boundaries. The resulting files are stored in the `subsharded_events` subdirectory of the output directory. -2. `scripts/extraction/split_and_shard_patients.py` splits the patient population into ML splits and shards - these splits into patient-level shards. The result of this process is only a simple `JSON` file - containing the patient IDs belonging to individual splits and shards. This file is stored in the +2. `scripts/extraction/split_and_shard_subjects.py` splits the subject population into ML splits and shards + these splits into subject-level shards. The result of this process is only a simple `JSON` file + containing the subject IDs belonging to individual splits and shards. This file is stored in the `output_directory/splits.json` file. 3. `scripts/extraction/convert_to_sharded_events.py` converts the input, event-level shards into the MEDS - event format and splits them into patient-level sub-shards. So, the resulting files are sharded into - patient-level, then event-level groups and are not merged into full patient-level shards or appropriately + event format and splits them into subject-level sub-shards. So, the resulting files are sharded into + subject-level, then event-level groups and are not merged into full subject-level shards or appropriately sorted for downstream use. -4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the patient-level, event-level shards into full - patient-level shards and sorts them appropriately for downstream use. The resulting files are stored in +4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the subject-level, event-level shards into full + subject-level shards and sorts them appropriately for downstream use. The resulting files are stored in the `output_directory/final_cohort` directory. ## MEDS Pre-processing Transformations @@ -308,9 +308,9 @@ contains a variety of pre-processing transformations and scripts that can be app in various ways to prepare them for downstream modeling. Broadly speaking, the pre-processing pipeline can be broken down into the following steps: -1. Filtering the dataset by criteria that do not require cross-patient analyses, e.g., +1. Filtering the dataset by criteria that do not require cross-subject analyses, e.g., - - Filtering patients by the number of events or unique times they have. + - Filtering subjects by the number of events or unique times they have. - Removing numeric values that fall outside of pre-specified, per-code ranges (e.g., for outlier removal). @@ -318,9 +318,9 @@ broken down into the following steps: - Adding time-derived measurements, e.g., - The time since the last event of a certain type. - - The patient's age as of each unique timepoint. + - The subject's age as of each unique timepoint. - The time-of-day of each event. - - Adding a "dummy" event to the dataset for each patient that occurs at the end of the observation + - Adding a "dummy" event to the dataset for each subject that occurs at the end of the observation period. 3. Iteratively (a) grouping the dataset by `code` and associated code modifier columns and collecting @@ -344,11 +344,11 @@ broken down into the following steps: 5. Normalizing the data to convert codes to indices and numeric values to the desired form (either categorical indices or normalized numeric values). -6. 
Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between patients, - patient sequence elements, and measurements per sequence element (note that various of these delineations +6. Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between subjects, + subject sequence elements, and measurements per sequence element (note that various of these delineations may be fully flat/trivial for unnested formats). -7. Tensorizing the data to permit efficient retrieval from disk of patient data for deep-learning modeling +7. Tensorizing the data to permit efficient retrieval from disk of subject data for deep-learning modeling via PyTorch. Much like how the entire MEDS ETL pipeline is controlled by a single configuration file, the pre-processing @@ -363,7 +363,7 @@ be a bottleneck. Tokenization is the process of producing dataframes that are arranged into the sequences that will eventually be processed by deep-learning methods. Generally, these dataframes will be arranged such that each row -corresponds to a unique patient, with nested list-type columns corresponding either to _events_ (unique +corresponds to a unique subject, with nested list-type columns corresponding either to _events_ (unique timepoints), themselves with nested, list-type measurements, or to _measurements_ (unique measurements within a timepoint) directly. Importantly, _tokenized files are generally not ideally suited to direct ingestion by PyTorch datasets_. Instead, they should undergo a _tensorization_ process to be converted into a format that @@ -379,7 +379,7 @@ does not inhibit rapid training, and (3) be organized such that CPU and GPU reso during training. Similarly, by _scalability_, we mean that the three desiderata above should hold true even as the dataset size grows much larger---while total training time can increase, time to begin training, to process the data per-item, and CPU/GPU resources required should remain constant, or only grow negligibly, -such as the cost of maintaining a larger index of patient IDs to file offsets or paths (though disk space will +such as the cost of maintaining a larger index of subject IDs to file offsets or paths (though disk space will of course increase). Depending on one's performance needs and dataset sizes, there are 3 modes of deep learning training that can @@ -398,7 +398,7 @@ on an as-needed basis. This mode is extremely scalable, because the entire datas loaded or stored in memory in its entirety. When done properly, retrieving data from disk can be done in a manner that is independent of the total dataset size as well, thereby rendering the load time similarly unconstrained by total dataset size. This mode is also extremely flexible, because different cohorts can be -loaded from the same base dataset simply by changing which patients and what offsets within patient data are +loaded from the same base dataset simply by changing which subjects and what offsets within subject data are read on any given cohort, all without changing the base files or underlying code. However, this mode does require ragged dataset collation which can be more resource intensive than pre-batched iteration, so it is slower than the "Fixed-batch retrieval" approach. This mode is what is currently supported by this repository. diff --git a/eICU_Example/README.md b/eICU_Example/README.md index c0494c9..37eb9d0 100644 --- a/eICU_Example/README.md +++ b/eICU_Example/README.md @@ -19,7 +19,7 @@ up from this one). 
- [ ] Testing the MEDS extraction ETL runs on eICU-CRD (this should be expected to work, but needs live testing). - [ ] Sub-sharding - - [ ] Patient split gathering + - [ ] Subject split gathering - [ ] Event extraction - [ ] Merging - [ ] Validating the output MEDS cohort @@ -58,10 +58,10 @@ This is a step in a few parts: 1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In particular, we need to join: - TODO -2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and +2. Convert the subject's static data to a more parseable form. This entails: + - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Merge the subject's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this @@ -78,12 +78,12 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less ## Step 3: Run the MEDS extraction ETL -Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable memory +Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable memory burden (e.g., \< 150GB per worker), you will want a smaller shard size, as well as to turn off the final unique check (which should not be necessary given the structure of eICU and is expensive) in the merge stage. You can do this by setting the following parameters at the end of the mandatory args when running this script: -- `stage_configs.split_and_shard_patients.n_patients_per_shard=10000` +- `stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000` - `stage_configs.merge_to_MEDS_cohort.unique_by=null` ### Running locally, serially @@ -106,10 +106,10 @@ This is a step in 4 parts: In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -117,7 +117,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -132,7 +132,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. 
```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -172,10 +172,10 @@ to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the n In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -1. Extract and form the patient splits and sub-shards. +1. Extract and form the subject splits and sub-shards. ```bash -./scripts/extraction/split_and_shard_patients.py \ +./scripts/extraction/split_and_shard_subjects.py \ input_dir=$EICU_PREMEDS_DIR \ cohort_dir=$EICU_MEDS_DIR \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml @@ -183,7 +183,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -1. Extract patient sub-shards and convert to MEDS events. +1. Extract subject sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -198,7 +198,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -1. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per subject sub-shard. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -221,7 +221,7 @@ timeline which is otherwise stored at the _datetime_ resolution? Other questions: -1. How to handle merging the deathtimes between the hosp table and the patients table? +1. How to handle merging the deathtimes between the hosp table and the subjects table? 2. How to handle the dob nonsense MIMIC has? ## Future Work @@ -230,4 +230,4 @@ Other questions: If you wanted, some other processing could also be done here, such as: -1. Converting the patient's dynamically recorded race into a static, most commonly recorded race field. +1. Converting the subject's dynamically recorded race into a static, most commonly recorded race field. diff --git a/eICU_Example/configs/event_configs.yaml b/eICU_Example/configs/event_configs.yaml index e6f2e7a..fb7901c 100644 --- a/eICU_Example/configs/event_configs.yaml +++ b/eICU_Example/configs/event_configs.yaml @@ -1,7 +1,7 @@ -# Note that there is no "patient_id" for eICU -- patients are only differentiable during the course of a +# Note that there is no "subject_id" for eICU -- patients are only differentiable during the course of a # single health system stay. Accordingly, we set the "patient" id here as the "patientHealthSystemStayID" -patient_id_col: patienthealthsystemstayid +subject_id_col: patienthealthsystemstayid patient: dob: @@ -131,7 +131,7 @@ infusionDrug: volume_of_fluid: "volumeoffluid" patient_weight: code: - - "INFUSION_PATIENT_WEIGHT" + - "INFUSION_SUBJECT_WEIGHT" time: col(infusionEnteredTimestamp) numeric_value: "patientweight" diff --git a/eICU_Example/joint_script.sh b/eICU_Example/joint_script.sh index fd76ee2..0b3ad6c 100755 --- a/eICU_Example/joint_script.sh +++ b/eICU_Example/joint_script.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." 
echo echo "Arguments:" echo " EICU_RAW_DIR Directory containing raw eICU data files." @@ -39,12 +39,12 @@ N_PARALLEL_WORKERS="$4" shift 4 -echo "Note that eICU has a lot more observations per patient than does MIMIC-IV, so to keep to a reasonable " +echo "Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable " echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " echo "in the merge stage. You can do this by setting the following parameters at the end of the mandatory " echo "args when running this script:" -echo " * stage_configs.split_and_shard_patients.n_patients_per_shard=10000" +echo " * stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000" echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" echo "Running pre-MEDS conversion." @@ -59,8 +59,8 @@ echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" -echo "Splitting patients in serial" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects in serial" +./scripts/extraction/split_and_shard_subjects.py \ input_dir="$EICU_PREMEDS_DIR" \ cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/joint_script_slurm.sh b/eICU_Example/joint_script_slurm.sh index 7880286..bdd7abe 100755 --- a/eICU_Example/joint_script_slurm.sh +++ b/eICU_Example/joint_script_slurm.sh @@ -8,7 +8,7 @@ function display_help() { echo "Usage: $0 " echo echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting patients, converting to sharded events, and merging into a MEDS cohort." + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." echo "This script uses slurm to process the data in parallel via the 'submitit' Hydra launcher." echo echo "Arguments:" @@ -71,8 +71,8 @@ echo "Trying submitit launching with $N_PARALLEL_WORKERS jobs." 
cohort_dir="$EICU_MEDS_DIR" \ event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml -echo "Splitting patients on one worker" -./scripts/extraction/split_and_shard_patients.py \ +echo "Splitting subjects on one worker" +./scripts/extraction/split_and_shard_subjects.py \ --multirun \ worker="range(0,1)" \ hydra/launcher=submitit_slurm \ diff --git a/pyproject.toml b/pyproject.toml index d4f6450..8073111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3", + "polars~=1.1.0", "pyarrow", "nested_ragged_tensors", "loguru", "hydra-core", "numpy", "meds==0.3.2", ] [tool.setuptools_scm] @@ -35,7 +35,7 @@ docs = [ [project.scripts] # MEDS_extract -MEDS_extract-split_and_shard_patients = "MEDS_transforms.extract.split_and_shard_patients:main" +MEDS_extract-split_and_shard_subjects = "MEDS_transforms.extract.split_and_shard_subjects:main" MEDS_extract-shard_events = "MEDS_transforms.extract.shard_events:main" MEDS_extract-convert_to_sharded_events = "MEDS_transforms.extract.convert_to_sharded_events:main" MEDS_extract-merge_to_MEDS_cohort = "MEDS_transforms.extract.merge_to_MEDS_cohort:main" @@ -50,7 +50,7 @@ MEDS_transform-fit_vocabulary_indices = "MEDS_transforms.fit_vocabulary_indices: MEDS_transform-reshard_to_split = "MEDS_transforms.reshard_to_split:main" ## Filters MEDS_transform-filter_measurements = "MEDS_transforms.filters.filter_measurements:main" -MEDS_transform-filter_patients = "MEDS_transforms.filters.filter_patients:main" +MEDS_transform-filter_subjects = "MEDS_transforms.filters.filter_subjects:main" ## Transforms MEDS_transform-reorder_measurements = "MEDS_transforms.transforms.reorder_measurements:main" MEDS_transform-add_time_derived_measurements = "MEDS_transforms.transforms.add_time_derived_measurements:main" diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index e0aaaf3..c40e2d9 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -2,6 +2,7 @@ from importlib.resources import files import polars as pl +from meds import code_field, subject_id_field, time_field __package_name__ = "MEDS_transforms" try: @@ -12,12 +13,12 @@ PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") -MANDATORY_COLUMNS = ["patient_id", "time", "code", "numeric_value"] +MANDATORY_COLUMNS = [subject_id_field, time_field, code_field, "numeric_value"] MANDATORY_TYPES = { - "patient_id": pl.Int64, - "time": pl.Datetime("us"), - "code": pl.String, + subject_id_field: pl.Int64, + time_field: pl.Datetime("us"), + code_field: pl.String, "numeric_value": pl.Float32, "categorical_value": pl.String, "text_value": pl.String, @@ -29,5 +30,5 @@ "category_value": "categoric_value", "textual_value": "text_value", "timestamp": "time", - "subject_id": "patient_id", + "patient_id": subject_id_field, } diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 13e9b34..d5c2f81 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -12,6 +12,7 @@ import polars as pl import polars.selectors as cs from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, ListConfig, OmegaConf from MEDS_transforms import 
PREPROCESS_CONFIG_YAML @@ -26,7 +27,7 @@ class METADATA_FN(StrEnum): This enumeration contains the supported code-metadata collection and aggregation function names that can be applied to codes (or, rather, unique code & modifier units) in a MEDS cohort. Each function name is mapped, in the below `CODE_METADATA_AGGREGATIONS` dictionary, to mapper and reducer functions that (a) - collect the raw data at a per code-modifier level from MEDS patient-level shards and (b) aggregates two or + collect the raw data at a per code-modifier level from MEDS subject-level shards and (b) aggregates two or more per-shard metadata files into a single metadata file, which can be used to merge metadata across all shards into a single file. @@ -45,14 +46,14 @@ class METADATA_FN(StrEnum): or on the command line. Args: - "code/n_patients": Collects the number of unique patients who have (anywhere in their record) the code + "code/n_subjects": Collects the number of unique subjects who have (anywhere in their record) the code & modifiers group. "code/n_occurrences": Collects the total number of occurrences of the code & modifiers group across - all observations for all patients. - "values/n_patients": Collects the number of unique patients who have a non-null, non-nan + all observations for all subjects. + "values/n_subjects": Collects the number of unique subjects who have a non-null, non-nan numeric_value field for the code & modifiers group. "values/n_occurrences": Collects the total number of non-null, non-nan numeric_value occurrences for - the code & modifiers group across all observations for all patients. + the code & modifiers group across all observations for all subjects. "values/n_ints": Collects the number of times the observed, non-null numeric_value for the code & modifiers group is an integral value (i.e., a whole number, not an integral type). "values/sum": Collects the sum of the non-null, non-nan numeric_value values for the code & @@ -67,9 +68,9 @@ class METADATA_FN(StrEnum): the configuration file using the dictionary syntax for the aggregation. """ - CODE_N_PATIENTS = "code/n_patients" + CODE_N_PATIENTS = "code/n_subjects" CODE_N_OCCURRENCES = "code/n_occurrences" - VALUES_N_PATIENTS = "values/n_patients" + VALUES_N_PATIENTS = "values/n_subjects" VALUES_N_OCCURRENCES = "values/n_occurrences" VALUES_N_INTS = "values/n_ints" VALUES_SUM = "values/sum" @@ -157,10 +158,10 @@ def quantile_reducer(cols: cs._selector_proxy_, quantiles: list[float]) -> pl.Ex PRESENT_VALS = VAL.filter(VAL_PRESENT) CODE_METADATA_AGGREGATIONS: dict[METADATA_FN, MapReducePair] = { - METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col("patient_id").n_unique(), pl.sum_horizontal), + METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col(subject_id_field).n_unique(), pl.sum_horizontal), METADATA_FN.CODE_N_OCCURRENCES: MapReducePair(pl.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_PATIENTS: MapReducePair( - pl.col("patient_id").filter(VAL_PRESENT).n_unique(), pl.sum_horizontal + pl.col(subject_id_field).filter(VAL_PRESENT).n_unique(), pl.sum_horizontal ), METADATA_FN.VALUES_N_OCCURRENCES: MapReducePair(PRESENT_VALS.len(), pl.sum_horizontal), METADATA_FN.VALUES_N_INTS: MapReducePair(VAL.filter(VAL_PRESENT & IS_INT).len(), pl.sum_horizontal), @@ -203,9 +204,9 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[ Traceback (most recent call last): ... ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. 
Values are: - code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, + code/n_subjects, code/n_occurrences, values/n_subjects, values/n_occurrences, values/n_ints, values/sum, values/sum_sqd, values/min, values/max, values/quantiles - >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", {"name": "values/n_ints"}]}) + >>> valid_cfg = DictConfig({"aggregations": ["code/n_subjects", {"name": "values/n_ints"}]}) >>> validate_args_and_get_code_cols(valid_cfg, 33) Traceback (most recent call last): ... @@ -264,7 +265,7 @@ def mapper_fntr( A function that extracts the specified metadata from a MEDS cohort shard after grouping by the specified code & modifier columns. **Note**: The output of this function will, if ``stage_cfg.do_summarize_over_all_codes`` is True, contain the metadata summarizing all observations - across all codes and patients in the shard, with both ``code`` and all ``code_modifiers`` set + across all codes and subjects in the shard, with both ``code`` and all ``code_modifiers`` set to `None` in the output dataframe, in the same format as the code/modifier specific rows with non-null values. @@ -274,13 +275,13 @@ def mapper_fntr( ... "code": ["A", "B", "A", "B", "C", "A", "C", "B", "D"], ... "modifier1": [1, 2, 1, 2, 1, 2, 1, 2, None], ... "modifier_ignored": [3, 3, 4, 4, 5, 5, 6, 6, 7], - ... "patient_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], + ... "subject_id": [1, 2, 1, 3, 1, 2, 2, 2, 1], ... "numeric_value": [1.1, 2., 1.1, 4., 5., 6., 7.5, float('nan'), None], ... }) >>> df shape: (9, 5) ┌──────┬───────────┬──────────────────┬────────────┬───────────────┐ - │ code ┆ modifier1 ┆ modifier_ignored ┆ patient_id ┆ numeric_value │ + │ code ┆ modifier1 ┆ modifier_ignored ┆ subject_id ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 ┆ f64 │ ╞══════╪═══════════╪══════════════════╪════════════╪═══════════════╡ @@ -295,14 +296,14 @@ def mapper_fntr( │ D ┆ null ┆ 7 ┆ 1 ┆ null │ └──────┴───────────┴──────────────────┴────────────┴───────────────┘ >>> stage_cfg = DictConfig({ - ... "aggregations": ["code/n_patients", "values/n_ints"], + ... "aggregations": ["code/n_subjects", "values/n_ints"], ... "do_summarize_over_all_codes": True ... 
}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (5, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -312,12 +313,12 @@ def mapper_fntr( │ C ┆ 2 ┆ 1 │ │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, None) >>> mapper(df.lazy()).collect() shape: (4, 3) ┌──────┬─────────────────┬───────────────┐ - │ code ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 │ ╞══════╪═════════════════╪═══════════════╡ @@ -327,12 +328,12 @@ def mapper_fntr( │ D ┆ 1 ┆ 0 │ └──────┴─────────────────┴───────────────┘ >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> mapper = mapper_fntr(stage_cfg, ListConfig(code_modifiers)) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -376,12 +377,12 @@ def mapper_fntr( │ C ┆ 1 ┆ 2 ┆ 12.5 │ │ D ┆ null ┆ 1 ┆ 0.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> mapper = mapper_fntr(stage_cfg, code_modifiers) >>> mapper(df.lazy()).collect() shape: (5, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ u32 ┆ u32 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ @@ -455,7 +456,7 @@ def mapper_fntr( def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame: return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns) - def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: + def all_subjects_mapper(df: pl.LazyFrame) -> pl.LazyFrame: local_agg_operations = agg_operations.copy() if METADATA_FN.VALUES_QUANTILES in agg_operations: local_agg_operations[METADATA_FN.VALUES_QUANTILES] = agg_operations[ @@ -467,8 +468,8 @@ def all_patients_mapper(df: pl.LazyFrame) -> pl.LazyFrame: def mapper(df: pl.LazyFrame) -> pl.LazyFrame: by_code = by_code_mapper(df) - all_patients = all_patients_mapper(df) - return pl.concat([all_patients, by_code], how="diagonal_relaxed").select( + all_subjects = all_subjects_mapper(df) + return pl.concat([all_subjects, by_code], how="diagonal_relaxed").select( *code_key_columns, *agg_operations.keys() ) @@ -502,9 +503,9 @@ def reducer_fntr( >>> df_1 = pl.DataFrame({ ... "code": [None, "A", "A", "B", "C"], ... "modifier1": [None, 1, 2, 1, 2], - ... "code/n_patients": [10, 1, 1, 2, 2], + ... "code/n_subjects": [10, 1, 1, 2, 2], ... "code/n_occurrences": [13, 2, 1, 3, 2], - ... "values/n_patients": [8, 1, 1, 2, 2], + ... 
"values/n_subjects": [8, 1, 1, 2, 2], ... "values/n_occurrences": [12, 2, 1, 3, 2], ... "values/n_ints": [4, 0, 1, 3, 1], ... "values/sum": [13.2, 2.2, 6.0, 14.0, 12.5], @@ -516,9 +517,9 @@ def reducer_fntr( >>> df_2 = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, None], - ... "code/n_patients": [3, 3, 4, 4], + ... "code/n_subjects": [3, 3, 4, 4], ... "code/n_occurrences": [10, 11, 8, 11], - ... "values/n_patients": [0, 1, 2, 2], + ... "values/n_subjects": [0, 1, 2, 2], ... "values/n_occurrences": [0, 4, 3, 2], ... "values/n_ints": [0, 1, 3, 1], ... "values/sum": [0., 7.0, 14.0, 12.5], @@ -530,9 +531,9 @@ def reducer_fntr( >>> df_3 = pl.DataFrame({ ... "code": ["D"], ... "modifier1": [1], - ... "code/n_patients": [2], + ... "code/n_subjects": [2], ... "code/n_occurrences": [2], - ... "values/n_patients": [1], + ... "values/n_subjects": [1], ... "values/n_occurrences": [3], ... "values/n_ints": [3], ... "values/sum": [2], @@ -542,12 +543,12 @@ def reducer_fntr( ... "values/quantiles": [[]], ... }) >>> code_modifiers = ["modifier1"] - >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> stage_cfg = DictConfig({"aggregations": ["code/n_subjects", "values/n_ints"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬─────────────────┬───────────────┐ - │ code ┆ modifier1 ┆ code/n_patients ┆ values/n_ints │ + │ code ┆ modifier1 ┆ code/n_subjects ┆ values/n_ints │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═════════════════╪═══════════════╡ @@ -562,9 +563,9 @@ def reducer_fntr( >>> cfg = DictConfig({ ... "code_modifiers": ["modifier1"], ... "code_processing_stages": { - ... "stage1": ["code/n_patients", "values/n_ints"], + ... "stage1": ["code/n_subjects", "values/n_ints"], ... "stage2": ["code/n_occurrences", "values/sum"], - ... "stage3.A": ["values/n_patients", "values/n_occurrences"], + ... "stage3.A": ["values/n_subjects", "values/n_occurrences"], ... "stage3.B": ["values/sum_sqd", "values/min", "values/max"], ... "stage4": ["INVALID"], ... 
} @@ -586,12 +587,12 @@ def reducer_fntr( │ C ┆ 2 ┆ 2 ┆ 12.5 │ │ D ┆ 1 ┆ 2 ┆ 2.0 │ └──────┴───────────┴────────────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/n_patients", "values/n_occurrences"]}) + >>> stage_cfg = DictConfig({"aggregations": ["values/n_subjects", "values/n_occurrences"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3) shape: (7, 4) ┌──────┬───────────┬───────────────────┬──────────────────────┐ - │ code ┆ modifier1 ┆ values/n_patients ┆ values/n_occurrences │ + │ code ┆ modifier1 ┆ values/n_subjects ┆ values/n_occurrences │ │ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ i64 ┆ i64 ┆ i64 │ ╞══════╪═══════════╪═══════════════════╪══════════════════════╡ diff --git a/src/MEDS_transforms/configs/extract.yaml b/src/MEDS_transforms/configs/extract.yaml index ae3bf07..3abd498 100644 --- a/src/MEDS_transforms/configs/extract.yaml +++ b/src/MEDS_transforms/configs/extract.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - merge_to_MEDS_cohort - extract_code_metadata - finalize_MEDS_metadata @@ -32,7 +32,7 @@ shards_map_fp: "${cohort_dir}/metadata/.shards.json" stages: - shard_events - - split_and_shard_patients + - split_and_shard_subjects - convert_to_sharded_events - merge_to_MEDS_cohort - extract_code_metadata diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/preprocess.yaml index ea509cd..dab87a9 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/preprocess.yaml @@ -2,7 +2,7 @@ defaults: - pipeline - stage_configs: - reshard_to_split - - filter_patients + - filter_subjects - add_time_derived_measurements - count_code_occurrences - filter_measurements @@ -24,7 +24,7 @@ code_modifiers: ??? 
# Pipeline Structure stages: - - filter_patients + - filter_subjects - add_time_derived_measurements - preliminary_counts - filter_measurements diff --git a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml index 076a1a0..b17b74e 100644 --- a/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml +++ b/src/MEDS_transforms/configs/stage_configs/count_code_occurrences.yaml @@ -1,5 +1,5 @@ count_code_occurrences: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" do_summarize_over_all_codes: true # This indicates we should include overall, code-independent counts diff --git a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml index 12ff62f..0d0a5bd 100644 --- a/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml +++ b/src/MEDS_transforms/configs/stage_configs/filter_measurements.yaml @@ -1,3 +1,3 @@ filter_measurements: - min_patients_per_code: null + min_subjects_per_code: null min_occurrences_per_code: null diff --git a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml b/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml deleted file mode 100644 index 70332b1..0000000 --- a/src/MEDS_transforms/configs/stage_configs/filter_patients.yaml +++ /dev/null @@ -1,3 +0,0 @@ -filter_patients: - min_events_per_patient: null - min_measurements_per_patient: null diff --git a/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml new file mode 100644 index 0000000..2706ffc --- /dev/null +++ b/src/MEDS_transforms/configs/stage_configs/filter_subjects.yaml @@ -0,0 +1,3 @@ +filter_subjects: + min_events_per_subject: null + min_measurements_per_subject: null diff --git a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml index e522470..6bd90cb 100644 --- a/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml +++ b/src/MEDS_transforms/configs/stage_configs/fit_normalization.yaml @@ -1,7 +1,7 @@ fit_normalization: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" diff --git a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml index 16dc505..fd0dc8a 100644 --- a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml +++ b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml @@ -1,2 +1,2 @@ reshard_to_split: - n_patients_per_shard: 50000 + n_subjects_per_shard: 50000 diff --git a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml similarity index 73% rename from src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml rename to src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml index c4015bd..7dbed11 100644 --- a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml +++ b/src/MEDS_transforms/configs/stage_configs/split_and_shard_subjects.yaml @@ -1,7 +1,7 @@ -split_and_shard_patients: +split_and_shard_subjects: is_metadata: True output_dir: ${cohort_dir}/metadata - n_patients_per_shard: 50000 + n_subjects_per_shard: 50000 external_splits_json_fp: null 
split_fracs: train: 0.8 diff --git a/src/MEDS_transforms/extract/README.md b/src/MEDS_transforms/extract/README.md index 47f7e7d..a60a8c8 100644 --- a/src/MEDS_transforms/extract/README.md +++ b/src/MEDS_transforms/extract/README.md @@ -4,8 +4,8 @@ This directory contains the scripts and functions used to extract raw data into dataset is: 1. Arranged in a series of files on disk of an allowed format (e.g., `.csv`, `.csv.gz`, `.parquet`)... -2. Such that each file stores a dataframe containing data about patients such that each row of any given - table corresponds to zero or more observations about a patient at a given time... +2. Such that each file stores a dataframe containing data about subjects such that each row of any given + table corresponds to zero or more observations about a subject at a given time... 3. And you can configure how to extract those observations in the time, code, and numeric value format of MEDS in the event conversion `yaml` file format specified below, then... this tool can automatically extract your raw data into a MEDS dataset for you in an efficient, reproducible, @@ -53,7 +53,7 @@ step](#step-0-pre-meds) and the [Data Cleaning step](#step-3-data-cleanup), for ### Event Conversion Configuration The event conversion configuration file tells MEDS Extract how to convert each row of a file among your raw -data files into one or more MEDS measurements (meaning a tuple of a patient ID, a time, a categorical +data files into one or more MEDS measurements (meaning a tuple of a subject ID, a time, a categorical code, and/or various other value or properties columns, most commonly a numeric value). This file is written in yaml and has the following format: @@ -93,11 +93,11 @@ each row of the file will be converted into a MEDS event according to the logic here, as string literals _cannot_ be used for these columns. There are several more nuanced aspects to the configuration file that have not yet been discussed. First, the -configuration file also specifies how to identify the patient ID from the raw data. This can be done either by -specifying a global `patient_id_col` field at the top level of the configuration file, or by specifying a -`patient_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously, -with the most local taking precedent. If no patient ID column is specified, the patient ID will be assumed to -be stored in a `patient_id` column. If the patient ID column is not found, an error will be raised. +configuration file also specifies how to identify the subject ID from the raw data. This can be done either by +specifying a global `subject_id_col` field at the top level of the configuration file, or by specifying a +`subject_id_col` field at the per-file or per-event level. Multiple specifications can be used simultaneously, +with the most local taking precedent. If no subject ID column is specified, the subject ID will be assumed to +be stored in a `subject_id` column. If the subject ID column is not found, an error will be raised. Second, you can also specify how to link the codes constructed for each event block to code-specific metadata in these blocks. This is done by specifying a `_metadata` block in the event block. 
The format of this block @@ -117,7 +117,7 @@ the [Partial MIMIC-IV Example](#partial-mimic-iv-example) below for an example o ```yaml subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,7 +144,7 @@ admit_vitals: ##### Partial MIMIC-IV Example ```yaml -patient_id_col: subject_id +subject_id_col: subject_id hosp/admissions: admission: code: @@ -259,4 +259,4 @@ Note that this tool is _not_: TODO: Add issues for all of these. 1. Single event blocks for files should be specifiable directly, without an event block name. -2. Time format should be specifiable at the file or global level, like patient ID. +2. Time format should be specifiable at the file or global level, like subject ID. diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index ee4e9d7..c83f934 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -97,7 +97,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy """Extracts a single event dataframe from the raw data. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfg` configuration dictionary. event_cfg: A dictionary containing the configuration for the event. This must contain two critical keys (`"code"` and `"time"`) and may contain additional keys for other columns to include @@ -128,7 +128,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy A DataFrame containing the event data extracted from the raw data, containing only unique rows across all columns. If the raw data has no duplicates when considering the event column space, the output dataframe will have the same number of rows as the raw data and be in the same order. The output - dataframe will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. If the + dataframe will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If the event has additional columns, they will be included in the output dataframe as well. **_Events that would be extracted with a null code or a time that should be specified via a column with or without a formatting option but in practice is null will be dropped._** Note that this dropping logic @@ -145,7 +145,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> _ = pl.Config.set_tbl_rows(20) >>> _ = pl.Config.set_tbl_cols(20) >>> raw_data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "C", "D"], ... "code_modifier": ["1", "2", "3", "4"], ... 
"time": ["2021-01-01", "2021-01-02", "2021-01-03", "2021-01-04"], @@ -160,7 +160,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(raw_data, event_cfg) shape: (4, 4) ┌────────────┬───────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪═══════════╪═════════════════════╪═══════════════╡ @@ -170,7 +170,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy │ 2 ┆ FOO//D//4 ┆ 2021-01-04 00:00:00 ┆ 4 │ └────────────┴───────────┴─────────────────────┴───────────────┘ >>> data_with_nulls = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", None, "C", "D"], ... "code_modifier": ["1", "2", "3", None], ... "time": [None, "2021-01-02", "2021-01-03", "2021-01-04"], @@ -185,7 +185,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(data_with_nulls, event_cfg) shape: (2, 4) ┌────────────┬────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ datetime[μs] ┆ i64 │ ╞════════════╪════════╪═════════════════════╪═══════════════╡ @@ -195,7 +195,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -227,7 +227,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... "admission_type": pl.Utf8, @@ -263,7 +263,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> complex_raw_data shape: (6, 9) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬──────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ discharge_status ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ str ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪══════════════════╪════════════════╪════════════╪═══════════╡ @@ -277,7 +277,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_admission_event_cfg) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -294,7 +294,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... 
) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -311,7 +311,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ... ) shape: (6, 4) ┌────────────┬──────────────┬─────────────────────┬───────────────┐ - │ patient_id ┆ code ┆ time ┆ numeric_value │ + │ subject_id ┆ code ┆ time ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ f64 │ ╞════════════╪══════════════╪═════════════════════╪═══════════════╡ @@ -325,7 +325,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_discharge_event_cfg) shape: (6, 5) ┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐ - │ patient_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ + │ subject_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ str │ ╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡ @@ -339,7 +339,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_death_event_cfg) shape: (3, 3) ┌────────────┬───────┬─────────────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪═══════╪═════════════════════╡ @@ -351,7 +351,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy >>> extract_event(complex_raw_data, valid_static_event_cfg) shape: (3, 3) ┌────────────┬──────────────────┬──────────────┐ - │ patient_id ┆ code ┆ time │ + │ subject_id ┆ code ┆ time │ │ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] │ ╞════════════╪══════════════════╪══════════════╡ @@ -371,10 +371,10 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy Traceback (most recent call last): ... ValueError: Invalid time literal: 12-01-23 - >>> extract_event(complex_raw_data, {"code": "test", "time": None, "patient_id": 3}) + >>> extract_event(complex_raw_data, {"code": "test", "time": None, "subject_id": 3}) Traceback (most recent call last): ... - KeyError: "Event column name 'patient_id' cannot be overridden." + KeyError: "Event column name 'subject_id' cannot be overridden." >>> extract_event(complex_raw_data, {"code": "test", "time": None, "foobar": "fuzz"}) Traceback (most recent call last): ... @@ -389,7 +389,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy ValueError: Source column 'discharge_time' for event column foobar is not numeric, string, or categorical! Cannot be used as an event col. """ # noqa: E501 event_cfg = copy.deepcopy(event_cfg) - event_exprs = {"patient_id": pl.col("patient_id")} + event_exprs = {"subject_id": pl.col("subject_id")} if "code" not in event_cfg: raise KeyError( @@ -401,8 +401,8 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy "Event configuration dictionary must contain 'time' key. " f"Got: [{', '.join(event_cfg.keys())}]." 
) - if "patient_id" in event_cfg: - raise KeyError("Event column name 'patient_id' cannot be overridden.") + if "subject_id" in event_cfg: + raise KeyError("Event column name 'subject_id' cannot be overridden.") code_expr, code_null_filter_expr, needed_cols = get_code_expr(event_cfg.pop("code")) @@ -502,7 +502,7 @@ def convert_to_events( """Converts a DataFrame of raw data into a DataFrame of events. Args: - df: The raw data DataFrame. This must have a `"patient_id"` column containing the patient ID. The + df: The raw data DataFrame. This must have a `"subject_id"` column containing the subject ID. The other columns it must have are determined by the `event_cfgs` configuration dictionary. For the precise mechanism of column determination, see the `extract_event` function. event_cfgs: A dictionary containing the configurations for the events to extract. The keys of this @@ -518,7 +518,7 @@ def convert_to_events( events extracted from the raw data, with the rows from each event DataFrame concatenated together. After concatenation, this dataframe will not be deduplicated, so if the raw data results in duplicates across events of different name, these will be preserved in the output DataFrame. - The output DataFrame will contain at least three columns: `"patient_id"`, `"code"`, and `"time"`. + The output DataFrame will contain at least three columns: `"subject_id"`, `"code"`, and `"time"`. If any events have additional columns, these will be included in the output DataFrame as well. All columns across all event configurations will be included in the output DataFrame, with `null` values filled in for events that do not have a particular column. @@ -533,7 +533,7 @@ def convert_to_events( >>> from datetime import datetime >>> complex_raw_data = pl.DataFrame( ... { - ... "patient_id": [1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 2, 2, 2, 3], ... "admission_time": [ ... "2021-01-01 00:00:00", ... "2021-01-02 00:00:00", @@ -564,7 +564,7 @@ def convert_to_events( ... "eye_color": ["blue", "blue", "green", "green", "green", "brown"], ... }, ... schema={ - ... "patient_id": pl.UInt8, + ... "subject_id": pl.UInt8, ... "admission_time": pl.Utf8, ... "discharge_time": pl.Datetime, ... 
"admission_type": pl.Utf8, @@ -602,7 +602,7 @@ def convert_to_events( >>> complex_raw_data shape: (6, 8) ┌────────────┬─────────────────────┬─────────────────────┬────────────────┬────────────────────┬────────────────┬────────────┬───────────┐ - │ patient_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ + │ subject_id ┆ admission_time ┆ discharge_time ┆ admission_type ┆ discharge_location ┆ severity_score ┆ death_time ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ cat ┆ f64 ┆ str ┆ cat │ ╞════════════╪═════════════════════╪═════════════════════╪════════════════╪════════════════════╪════════════════╪════════════╪═══════════╡ @@ -616,7 +616,7 @@ def convert_to_events( >>> convert_to_events(complex_raw_data, event_cfgs) shape: (18, 7) ┌────────────┬───────────┬─────────────────────┬────────────────┬───────────────────────┬────────────────────┬───────────┐ - │ patient_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ + │ subject_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ f64 ┆ cat ┆ cat │ ╞════════════╪═══════════╪═════════════════════╪════════════════╪═══════════════════════╪════════════════════╪═══════════╡ @@ -666,7 +666,7 @@ def convert_to_events( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Converts the event-sharded raw data into MEDS events and storing them in patient subsharded flat files. + """Converts the event-sharded raw data into MEDS events and storing them in subject subsharded flat files. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -680,7 +680,7 @@ def main(cfg: DictConfig): file. 
""" - input_dir, patient_subsharded_dir, metadata_input_dir = stage_init(cfg) + input_dir, subject_subsharded_dir, metadata_input_dir = stage_init(cfg) shards = json.loads(Path(cfg.shards_map_fp).read_text()) @@ -694,13 +694,13 @@ def main(cfg: DictConfig): event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") - patient_subsharded_dir.mkdir(parents=True, exist_ok=True) - OmegaConf.save(event_conversion_cfg, patient_subsharded_dir / "event_conversion_config.yaml") + subject_subsharded_dir.mkdir(parents=True, exist_ok=True) + OmegaConf.save(event_conversion_cfg, subject_subsharded_dir / "event_conversion_config.yaml") - patient_splits = list(shards.items()) - random.shuffle(patient_splits) + subject_splits = list(shards.items()) + random.shuffle(subject_splits) event_configs = list(event_conversion_cfg.items()) random.shuffle(event_configs) @@ -708,28 +708,28 @@ def main(cfg: DictConfig): # Here, we'll be reading files directly, so we'll turn off globbing read_fn = partial(pl.scan_parquet, glob=False) - for sp, patients in patient_splits: + for sp, subjects in subject_splits: for input_prefix, event_cfgs in event_configs: event_cfgs = copy.deepcopy(event_cfgs) - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) event_shards = list((input_dir / input_prefix).glob("*.parquet")) random.shuffle(event_shards) for shard_fp in event_shards: - out_fp = patient_subsharded_dir / sp / input_prefix / shard_fp.name + out_fp = subject_subsharded_dir / sp / input_prefix / shard_fp.name logger.info(f"Converting {shard_fp} to events and saving to {out_fp}") def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: - typed_patients = pl.Series(patients, dtype=df.schema[input_patient_id_column]) + typed_subjects = pl.Series(subjects, dtype=df.schema[input_subject_id_column]) - if input_patient_id_column != "patient_id": - df = df.rename({input_patient_id_column: "patient_id"}) + if input_subject_id_column != "subject_id": + df = df.rename({input_subject_id_column: "subject_id"}) try: logger.info(f"Extracting events for {input_prefix}/{shard_fp.name}") return convert_to_events( - df.filter(pl.col("patient_id").is_in(typed_patients)), + df.filter(pl.col("subject_id").is_in(typed_subjects)), event_cfgs=copy.deepcopy(event_cfgs), ) except Exception as e: diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index e9133eb..818d88a 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -250,9 +250,9 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> Examples: >>> event_configs = { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "icu/procedureevents": { - ... "patient_id_col": "subject_id", + ... "subject_id_col": "subject_id", ... "start": { ... "code": ["PROCEDURE", "START", "col(itemid)"], ... 
"_metadata": { @@ -304,11 +304,11 @@ def get_events_and_metadata_by_metadata_fp(event_configs: dict | DictConfig) -> out = {} for file_pfx, event_cfgs_for_pfx in event_configs.items(): - if file_pfx == "patient_id_col": + if file_pfx == "subject_id_col": continue for event_key, event_cfg in event_cfgs_for_pfx.items(): - if event_key == "patient_id_col": + if event_key == "subject_id_col": continue for metadata_pfx, metadata_cfg in event_cfg.get("_metadata", {}).items(): diff --git a/src/MEDS_transforms/extract/finalize_MEDS_data.py b/src/MEDS_transforms/extract/finalize_MEDS_data.py index f9d6873..54e1d20 100644 --- a/src/MEDS_transforms/extract/finalize_MEDS_data.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_data.py @@ -38,7 +38,7 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'time' column of type Datetime(time_unit='us', time_zone=None). MEDS Data DataFrame must have a 'code' column of type String. @@ -46,28 +46,28 @@ def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa. >>> get_and_validate_data_schema(df.lazy(), {}) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'patient_id' column of type Int64. + ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. MEDS Data DataFrame must have a 'code' column of type String. >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": pl.Series([1, 2], dtype=pl.UInt32), + ... "subject_id": pl.Series([1, 2], dtype=pl.UInt32), ... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2)], ... "code": ["A", "B"], "text_value": ["1", None], "numeric_value": [None, 34.2] ... }) >>> get_and_validate_data_schema(df.lazy(), dict(do_retype=False)) # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: MEDS Data 'patient_id' column must be of type Int64. Got UInt32. + ValueError: MEDS Data 'subject_id' column must be of type Int64. Got UInt32. MEDS Data 'numeric_value' column must be of type Float32. Got Float64. >>> get_and_validate_data_schema(df.lazy(), {}) pyarrow.Table - patient_id: int64 + subject_id: int64 time: timestamp[us] code: string numeric_value: float text_value: large_string ---- - patient_id: [[1,2]] + subject_id: [[1,2]] time: [[2021-01-01 00:00:00.000000,2021-01-02 00:00:00.000000]] code: [["A","B"]] numeric_value: [[null,34.2]] @@ -111,7 +111,7 @@ def main(cfg: DictConfig): """Writes out schema compliant MEDS data files for the extracted dataset. 
In particular, this script ensures that all shard files are MEDS compliant with the mandatory columns - - `patient_id` (Int64) + - `subject_id` (Int64) - `time` (DateTime) - `code` (String) - `numeric_value` (Float32) diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 366d89a..7a6dcc0 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -15,7 +15,8 @@ code_metadata_schema, dataset_metadata_schema, held_out_split, - patient_split_schema, + subject_id_field, + subject_split_schema, train_split, tuning_split, ) @@ -121,8 +122,8 @@ def main(cfg: DictConfig): - `etl_name` (string) - `etl_version` (string) - `meds_version` (string) - (3) a `metadata/patient_splits.parquet` file exists that has the mandatory columns - - `patient_id` (Int64) + (3) a `metadata/subject_splits.parquet` file exists that has the mandatory columns + - `subject_id` (Int64) - `split` (string) This stage *_should almost always be the last metadata stage in an extraction pipeline._* @@ -151,9 +152,9 @@ def main(cfg: DictConfig): output_code_metadata_fp = output_metadata_dir / "codes.parquet" dataset_metadata_fp = output_metadata_dir / "dataset.json" - patient_splits_fp = output_metadata_dir / "patient_splits.parquet" + subject_splits_fp = output_metadata_dir / "subject_splits.parquet" - for out_fp in [output_code_metadata_fp, dataset_metadata_fp, patient_splits_fp]: + for out_fp in [output_code_metadata_fp, dataset_metadata_fp, subject_splits_fp]: out_fp.parent.mkdir(parents=True, exist_ok=True) if out_fp.exists() and cfg.do_overwrite: out_fp.unlink() @@ -194,28 +195,28 @@ def main(cfg: DictConfig): # Split creation shards_map_fp = Path(cfg.shards_map_fp) - logger.info("Creating patient splits from {str(shards_map_fp.resolve())}") + logger.info("Creating subject splits from {str(shards_map_fp.resolve())}") shards_map = json.loads(shards_map_fp.read_text()) - patient_splits = [] + subject_splits = [] seen_splits = {train_split: 0, tuning_split: 0, held_out_split: 0} - for shard, patient_ids in shards_map.items(): + for shard, subject_ids in shards_map.items(): split = "/".join(shard.split("/")[:-1]) if split not in seen_splits: seen_splits[split] = 0 - seen_splits[split] += len(patient_ids) + seen_splits[split] += len(subject_ids) - patient_splits.extend([{"patient_id": pid, "split": split} for pid in patient_ids]) + subject_splits.extend([{subject_id_field: pid, "split": split} for pid in subject_ids]) for split, cnt in seen_splits.items(): if cnt: - logger.info(f"Split {split} has {cnt} patients") + logger.info(f"Split {split} has {cnt} subjects") else: logger.warning(f"Split {split} not found in shards map") - patient_splits_tbl = pa.Table.from_pylist(patient_splits, schema=patient_split_schema) - logger.info(f"Writing finalized patient splits to {str(patient_splits_fp.resolve())}") - pq.write_table(patient_splits_tbl, patient_splits_fp) + subject_splits_tbl = pa.Table.from_pylist(subject_splits, schema=subject_split_schema) + logger.info(f"Writing finalized subject splits to {str(subject_splits_fp.resolve())}") + pq.write_table(subject_splits_tbl, subject_splits_fp) if __name__ == "__main__": diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 49a45e1..e611c77 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -30,8 +30,8 
@@ def merge_subdirs_and_sort( warning is logged, but an error is *not* raised. Which rows are retained if the uniqeu-by columns are not all columns is not guaranteed, but is also *not* random, so this may have statistical implications. - additional_sort_by: Additional columns to sort by, in addition to the default sorting by patient ID - and time. If `None`, only patient ID and time are used. If a list of strings, these + additional_sort_by: Additional columns to sort by, in addition to the default sorting by subject ID + and time. If `None`, only subject ID and time are used. If a list of strings, these columns are used in addition to the default sorting. If a column is not found in the dataframe, it is omitted from the sort-by, a warning is logged, but an error is *not* raised. This functionality is useful both for deterministic testing and in cases where a data owner wants to impose @@ -41,7 +41,7 @@ def merge_subdirs_and_sort( A single dataframe containing all the data from the parquet files in the subdirs of `sp_dir`. These files will be concatenated diagonally, taking the union of all rows in all dataframes and all unique columns in all dataframes to form the merged output. The returned dataframe will be made unique by the - columns specified in `unique_by` and sorted by first patient ID, then time, then all columns in + columns specified in `unique_by` and sorted by first subject ID, then time, then all columns in `additional_sort_by`, if any. Raises: @@ -50,15 +50,15 @@ def merge_subdirs_and_sort( Examples: >>> from tempfile import TemporaryDirectory - >>> df1 = pl.DataFrame({"patient_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) + >>> df1 = pl.DataFrame({"subject_id": [1, 2], "time": [10, 20], "code": ["A", "B"]}) >>> df2 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 1, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [None, 2.0, None], ... }) >>> df3 = pl.DataFrame({ - ... "patient_id": [1, 1, 3], + ... "subject_id": [1, 1, 3], ... "time": [2, 2, 8], ... "code": ["C", "D", "E"], ... "numeric_value": [6.2, 2.0, None], @@ -84,7 +84,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (8, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -112,7 +112,7 @@ def merge_subdirs_and_sort( ... ).collect() shape: (7, 4) ┌────────────┬──────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ f64 │ ╞════════════╪══════╪══════╪═══════════════╡ @@ -131,18 +131,18 @@ def merge_subdirs_and_sort( ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") ... (sp_dir / "subdir2").mkdir() ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") - ... # We just display the patient ID, time, and code columns as the numeric value column + ... # We just display the subject ID, time, and code columns as the numeric value column ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to ... # the unique-by constraint. ... merge_subdirs_and_sort( ... sp_dir, ... event_subsets=["subdir1", "subdir2"], - ... unique_by=["patient_id", "time", "code"], + ... unique_by=["subject_id", "time", "code"], ... additional_sort_by=["code", "numeric_value"] - ... ).select("patient_id", "time", "code").collect() + ... 
).select("subject_id", "time", "code").collect() shape: (6, 3) ┌────────────┬──────┬──────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪══════╡ @@ -188,7 +188,7 @@ def merge_subdirs_and_sort( case _: raise ValueError(f"Invalid unique_by value: {unique_by}") - sort_by = ["patient_id", "time"] + sort_by = ["subject_id", "time"] if additional_sort_by is not None: for s in additional_sort_by: if s in df_columns: @@ -201,12 +201,12 @@ def merge_subdirs_and_sort( @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Merges the patient sub-sharded events into a single parquet file per patient shard. + """Merges the subject sub-sharded events into a single parquet file per subject shard. This function takes all dataframes (in parquet files) in any subdirs of the `cfg.stage_cfg.input_dir` and merges them into a single dataframe. All dataframes in the subdirs are assumed to be in the unnested, MEDS - format, and cover the same group of patients (specific to the shard being processed). The merged dataframe - will also be sorted by patient ID and time. + format, and cover the same group of subjects (specific to the shard being processed). The merged dataframe + will also be sorted by subject ID and time. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -219,14 +219,14 @@ def main(cfg: DictConfig): stage_configs.merge_to_MEDS_cohort.unique_by: The list of columns that should be ensured to be unique after the dataframes are merged. Defaults to `"*"`, which means all columns are used. stage_configs.merge_to_MEDS_cohort.additional_sort_by: Additional columns to sort by, in addition to - the default sorting by patient ID and time. Defaults to `None`, which means only patient ID + the default sorting by subject ID and time. Defaults to `None`, which means only subject ID and time are used. Returns: Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`. """ event_conversion_cfg = OmegaConf.load(cfg.event_conversion_config_fp) - event_conversion_cfg.pop("patient_id_col", None) + event_conversion_cfg.pop("subject_id_col", None) read_fn = partial( merge_subdirs_and_sort, diff --git a/src/MEDS_transforms/extract/shard_events.py b/src/MEDS_transforms/extract/shard_events.py index 18450bb..5eebc71 100755 --- a/src/MEDS_transforms/extract/shard_events.py +++ b/src/MEDS_transforms/extract/shard_events.py @@ -11,6 +11,7 @@ import hydra import polars as pl from loguru import logger +from meds import subject_id_field from omegaconf import DictConfig, OmegaConf from MEDS_transforms.extract import CONFIG_YAML @@ -169,7 +170,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: event conversion configurations that are specific to each file based on its stem (filename without the extension). It compiles a list of column names needed for each file from the configuration, which includes both general - columns like row index and patient ID, as well as specific columns defined + columns like row index and subject ID, as well as specific columns defined for medical events and times formatted in a special 'col(column_name)' syntax. Args: @@ -185,7 +186,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: Examples: >>> cfg = DictConfig({ - ... "patient_id_col": "patient_id_global", + ... 
"subject_id_col": "subject_id_global", ... "hosp/patients": { ... "eye_color": { ... "code": ["EYE_COLOR", "col(eye_color)"], "time": None, "mod": "mod_col" @@ -195,7 +196,7 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }, ... "icu/chartevents": { - ... "patient_id_col": "patient_id_icu", + ... "subject_id_col": "subject_id_icu", ... "heart_rate": { ... "code": "HEART_RATE", "time": "charttime", "numeric_value": "HR" ... }, @@ -212,19 +213,19 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: ... } ... }) >>> retrieve_columns(cfg) # doctest: +NORMALIZE_WHITESPACE - {'hosp/patients': ['eye_color', 'height', 'mod_col', 'patient_id_global'], - 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'patient_id_icu', 'value', 'valuenum', + {'hosp/patients': ['eye_color', 'height', 'mod_col', 'subject_id_global'], + 'icu/chartevents': ['HR', 'charttime', 'itemid', 'mod_lab', 'subject_id_icu', 'value', 'valuenum', 'valueuom'], - 'icu/meds': ['medication', 'medtime', 'patient_id_global']} + 'icu/meds': ['medication', 'medtime', 'subject_id_global']} >>> cfg = DictConfig({ ... "subjects": { - ... "patient_id_col": "MRN", + ... "subject_id_col": "MRN", ... "eye_color": {"code": ["col(eye_color)"], "time": None}, ... }, ... "labs": {"lab": {"code": "col(labtest)", "time": "charttime"}}, ... }) >>> retrieve_columns(cfg) - {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'patient_id']} + {'subjects': ['MRN', 'eye_color'], 'labs': ['charttime', 'labtest', 'subject_id']} """ event_conversion_cfg = copy.deepcopy(event_conversion_cfg) @@ -232,11 +233,11 @@ def retrieve_columns(event_conversion_cfg: DictConfig) -> dict[str, list[str]]: # Initialize a dictionary to store file paths as keys and lists of column names as values. prefix_to_columns = {} - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", subject_id_field) for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.pop("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.pop("subject_id_col", default_subject_id_col) - prefix_to_columns[input_prefix] = {input_patient_id_column} + prefix_to_columns[input_prefix] = {input_subject_id_column} for event_cfg in event_cfgs.values(): # If the config has a 'code' key and it contains column fields, parse and add them. diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_subjects.py similarity index 71% rename from src/MEDS_transforms/extract/split_and_shard_patients.py rename to src/MEDS_transforms/extract/split_and_shard_subjects.py index 61f1726..3081405 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_subjects.py @@ -13,77 +13,77 @@ from MEDS_transforms.utils import stage_init -def shard_patients[ +def shard_subjects[ SUBJ_ID_T ]( - patients: np.ndarray, - n_patients_per_shard: int = 50000, + subjects: np.ndarray, + n_subjects_per_shard: int = 50000, external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None, split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, seed: int = 1, ) -> dict[str, list[SUBJ_ID_T]]: - """Shard a list of patients, nested within train/tuning/held-out splits. + """Shard a list of subjects, nested within train/tuning/held-out splits. 
- This function takes a list of patients and shards them into train/tuning/held-out splits, with the shards + This function takes a list of subjects and shards them into train/tuning/held-out splits, with the shards of a consistent size, nested within the splits. The function will also respect external splits, if provided, such that mandated splits (such as prospective held out sets or pre-existing, task-specific held out sets) are with-held and sharded as separate splits from the IID splits defined by `split_fracs_dict`. It returns a dictionary mapping the split and shard names (realized as f"{split}/{shard}") to the list of - patients in that shard. + subjects in that shard. Args: - patients: The list of patients to shard. - n_patients_per_shard: The maximum number of patients to include in each shard. + subjects: The list of subjects to shard. + n_subjects_per_shard: The maximum number of subjects to include in each shard. external_splits: The externally defined splits to respect. If provided, the keys of this dictionary - will be used as split names, and the values as the list of patients in that split. These + will be used as split names, and the values as the list of subjects in that split. These pre-defined splits will be excluded from IID splits generated by this function, but will be sharded like normal. Note that this is largely only appropriate for held-out sets for pre-defined - tasks or test cases (e.g., prospective tests); training patients should often still be included in + tasks or test cases (e.g., prospective tests); training subjects should often still be included in the IID splits to maximize the amount of data that can be used for training. - split_fracs_dict: A dictionary mapping the split name to the fraction of patients to include in that + split_fracs_dict: A dictionary mapping the split name to the fraction of subjects to include in that split. Defaults to 80% train, 10% tuning, 10% held-out. This can be None or empty only when external splits fully specify the population. - seed: The random seed to use for shuffling the patients before seeding and sharding. This is useful + seed: The random seed to use for shuffling the subjects before seeding and sharding. This is useful for ensuring reproducibility. Returns: - A dictionary mapping f"{split}/{shard}" to the list of patients in that shard. This may include - overlapping patients across a subset of these splits, but never across shards within a split. Any + A dictionary mapping f"{split}/{shard}" to the list of subjects in that shard. This may include + overlapping subjects across a subset of these splits, but never across shards within a split. Any overlap will solely occur between the an external split and another external split. Raises: ValueError: If the sum of the split fractions in `split_fracs_dict` is not equal to 1. Examples: - >>> patients = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) - >>> shard_patients(patients, n_patients_per_shard=3) + >>> subjects = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=int) + >>> shard_subjects(subjects, n_subjects_per_shard=3) {'train/0': [9, 4, 8], 'train/1': [2, 1, 10], 'train/2': [6, 5], 'tuning/0': [3], 'held_out/0': [7]} >>> external_splits = { ... 'taskA/held_out': np.array([8, 9, 10], dtype=int), ... 'taskB/held_out': np.array([10, 8, 9], dtype=int), ... 
} - >>> shard_patients(patients, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE + >>> shard_subjects(subjects, 3, external_splits) # doctest: +NORMALIZE_WHITESPACE {'train/0': [5, 7, 4], 'train/1': [1, 2], 'tuning/0': [3], 'held_out/0': [6], 'taskA/held_out/0': [8, 9, 10], 'taskB/held_out/0': [10, 8, 9]} - >>> shard_patients(patients, n_patients_per_shard=3, split_fracs_dict={'train': 0.5}) + >>> shard_subjects(subjects, n_subjects_per_shard=3, split_fracs_dict={'train': 0.5}) Traceback (most recent call last): ... ValueError: The sum of the split fractions must be equal to 1. - >>> shard_patients([1, 2], n_patients_per_shard=3) + >>> shard_subjects([1, 2], n_subjects_per_shard=3) Traceback (most recent call last): ... - ValueError: Unable to adjust splits to ensure all splits have at least 1 patient. + ValueError: Unable to adjust splits to ensure all splits have at least 1 subject. >>> external_splits = { ... 'train': np.array([1, 2, 3, 4, 5, 6], dtype=int), ... 'test': np.array([7, 8, 9, 10], dtype=int), ... } - >>> shard_patients(patients, 6, external_splits, split_fracs_dict=None) + >>> shard_subjects(subjects, 6, external_splits, split_fracs_dict=None) {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} - >>> shard_patients(patients, 3, external_splits) + >>> shard_subjects(subjects, 3, external_splits) {'train/0': [5, 1, 3], 'train/1': [2, 6, 4], 'test/0': [10, 7], 'test/1': [8, 9]} """ @@ -94,62 +94,62 @@ def shard_patients[ if not isinstance(external_splits[k], np.ndarray): logger.warning( f"External split {k} is not a numpy array and thus type safety is not guaranteed. " - f"Attempting to convert to numpy array of dtype {patients.dtype}." + f"Attempting to convert to numpy array of dtype {subjects.dtype}." ) - external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) + external_splits[k] = np.array(external_splits[k], dtype=subjects.dtype) - patients = np.unique(patients) + subjects = np.unique(subjects) # Splitting all_external_splits = set().union(*external_splits.values()) - is_in_external_split = np.isin(patients, list(all_external_splits)) - patient_ids_to_split = patients[~is_in_external_split] + is_in_external_split = np.isin(subjects, list(all_external_splits)) + subject_ids_to_split = subjects[~is_in_external_split] splits = external_splits rng = np.random.default_rng(seed) - if n_patients := len(patient_ids_to_split): + if n_subjects := len(subject_ids_to_split): if sum(split_fracs_dict.values()) != 1: raise ValueError("The sum of the split fractions must be equal to 1.") split_names_idx = rng.permutation(len(split_fracs_dict)) split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] split_fracs = np.array([split_fracs_dict[k] for k in split_names]) - split_lens = np.round(split_fracs[:-1] * n_patients).astype(int) - split_lens = np.append(split_lens, n_patients - split_lens.sum()) + split_lens = np.round(split_fracs[:-1] * n_subjects).astype(int) + split_lens = np.append(split_lens, n_subjects - split_lens.sum()) if split_lens.min() == 0: logger.warning( - "Some splits are empty. Adjusting splits to ensure all splits have at least 1 patient." + "Some splits are empty. Adjusting splits to ensure all splits have at least 1 subject." 
) max_split = split_lens.argmax() split_lens[max_split] -= 1 split_lens[split_lens.argmin()] += 1 if split_lens.min() == 0: - raise ValueError("Unable to adjust splits to ensure all splits have at least 1 patient.") + raise ValueError("Unable to adjust splits to ensure all splits have at least 1 subject.") - patients = rng.permutation(patient_ids_to_split) - patients_per_split = np.split(patients, split_lens.cumsum()) + subjects = rng.permutation(subject_ids_to_split) + subjects_per_split = np.split(subjects, split_lens.cumsum()) - splits = {**{k: v for k, v in zip(split_names, patients_per_split)}, **splits} + splits = {**{k: v for k, v in zip(split_names, subjects_per_split)}, **splits} else: if split_fracs_dict: logger.warning( - "External splits were provided covering all patients, but split_fracs_dict was not empty. " + "External splits were provided covering all subjects, but split_fracs_dict was not empty. " "Ignoring the split_fracs_dict." ) else: - logger.info("External splits were provided covering all patients.") + logger.info("External splits were provided covering all subjects.") # Sharding final_shards = {} for sp, pts in splits.items(): - if len(pts) <= n_patients_per_shard: + if len(pts) <= n_subjects_per_shard: final_shards[f"{sp}/0"] = pts.tolist() else: pts = rng.permutation(pts) n_pts = len(pts) - n_shards = int(np.ceil(n_pts / n_patients_per_shard)) + n_shards = int(np.ceil(n_pts / n_subjects_per_shard)) shards = np.array_split(pts, n_shards) for i, shard in enumerate(shards): final_shards[f"{sp}/{i}"] = shard.tolist() @@ -157,12 +157,12 @@ def shard_patients[ seen = {} for k, pts in final_shards.items(): - logger.info(f"Split {k} has {len(pts)} patients.") + logger.info(f"Split {k} has {len(pts)} subjects.") for kk, v in seen.items(): shared = set(pts).intersection(v) if shared: - logger.info(f" - intersects {kk} on {len(shared)} patients.") + logger.info(f" - intersects {kk} on {len(shared)} subjects.") seen[k] = set(pts) @@ -171,9 +171,9 @@ def shard_patients[ @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem) def main(cfg: DictConfig): - """Extracts the set of unique patients from the raw data and splits/shards them and saves the result. + """Extracts the set of unique subjects from the raw data and splits/shards them and saves the result. - This stage splits the patients into training, tuning, and held-out sets, and further splits those sets + This stage splits the subjects into training, tuning, and held-out sets, and further splits those sets into shards. All arguments are specified through the command line into the `cfg` object through Hydra. @@ -181,19 +181,19 @@ def main(cfg: DictConfig): The `cfg.stage_cfg` object is a special key that is imputed by OmegaConf to contain the stage-specific configuration arguments based on the global, pipeline-level configuration file. It cannot be overwritten directly on the command line, but can be overwritten implicitly by overwriting components of the - `stage_configs.split_and_shard_patients` key. + `stage_configs.split_and_shard_subjects` key. Args: - stage_configs.split_and_shard_patients.n_patients_per_shard: The maximum number of patients to include - in any shard. Realized shards will not necessarily have this many patients, though they will never - exceed this number. 
Instead, the number of shards necessary to include all patients in a split - such that no shard exceeds this number will be calculated, then the patients will be evenly, + stage_configs.split_and_shard_subjects.n_subjects_per_shard: The maximum number of subjects to include + in any shard. Realized shards will not necessarily have this many subjects, though they will never + exceed this number. Instead, the number of shards necessary to include all subjects in a split + such that no shard exceeds this number will be calculated, then the subjects will be evenly, randomly split amongst those shards so that all shards within a split have approximately the same number of patietns. - stage_configs.split_and_shard_patients.external_splits_json_fp: The path to a json file containing any + stage_configs.split_and_shard_subjects.external_splits_json_fp: The path to a json file containing any pre-defined splits for specialty held-out test sets beyond the IID held out set that will be produced (e.g., for prospective datasets, etc.). - stage_configs.split_and_shard_patients.split_fracs: The fraction of patients to include in the IID + stage_configs.split_and_shard_subjects.split_fracs: The fraction of subjects to include in the IID training, tuning, and held-out sets. Split fractions can be changed for the default names by adding a hydra-syntax command line argument for the nested name; e.g., `split_fracs.train=0.7 split_fracs.tuning=0.1 split_fracs.held_out=0.2`. A split can be removed with the `~` override @@ -209,38 +209,38 @@ def main(cfg: DictConfig): raise FileNotFoundError(f"Event conversion config file not found: {event_conversion_cfg_fp}") logger.info( - f"Reading event conversion config from {event_conversion_cfg_fp} (needed for patient ID columns)" + f"Reading event conversion config from {event_conversion_cfg_fp} (needed for subject ID columns)" ) event_conversion_cfg = OmegaConf.load(event_conversion_cfg_fp) logger.info(f"Event conversion config:\n{OmegaConf.to_yaml(event_conversion_cfg)}") dfs = [] - default_patient_id_col = event_conversion_cfg.pop("patient_id_col", "patient_id") + default_subject_id_col = event_conversion_cfg.pop("subject_id_col", "subject_id") for input_prefix, event_cfgs in event_conversion_cfg.items(): - input_patient_id_column = event_cfgs.get("patient_id_col", default_patient_id_col) + input_subject_id_column = event_cfgs.get("subject_id_col", default_subject_id_col) input_fps = list((subsharded_dir / input_prefix).glob("**/*.parquet")) input_fps_strs = "\n".join(f" - {str(fp.resolve())}" for fp in input_fps) - logger.info(f"Reading patient IDs from {input_prefix} files:\n{input_fps_strs}") + logger.info(f"Reading subject IDs from {input_prefix} files:\n{input_fps_strs}") for input_fp in input_fps: dfs.append( pl.scan_parquet(input_fp, glob=False) - .select(pl.col(input_patient_id_column).alias("patient_id")) + .select(pl.col(input_subject_id_column).alias("subject_id")) .unique() ) - logger.info(f"Joining all patient IDs from {len(dfs)} dataframes") - patient_ids = ( + logger.info(f"Joining all subject IDs from {len(dfs)} dataframes") + subject_ids = ( pl.concat(dfs) - .select(pl.col("patient_id").drop_nulls().drop_nans().unique()) - .collect(streaming=True)["patient_id"] + .select(pl.col("subject_id").drop_nulls().drop_nans().unique()) + .collect(streaming=True)["subject_id"] .to_numpy(use_pyarrow=True) ) - logger.info(f"Found {len(patient_ids)} unique patient IDs of type {patient_ids.dtype}") + logger.info(f"Found {len(subject_ids)} unique subject IDs of type 
{subject_ids.dtype}") if cfg.stage_cfg.external_splits_json_fp: external_splits_json_fp = Path(cfg.stage_cfg.external_splits_json_fp) @@ -255,21 +255,21 @@ def main(cfg: DictConfig): else: external_splits = None - logger.info("Sharding and splitting patients") + logger.info("Sharding and splitting subjects") - sharded_patients = shard_patients( - patients=patient_ids, + sharded_subjects = shard_subjects( + subjects=subject_ids, external_splits=external_splits, split_fracs_dict=cfg.stage_cfg.split_fracs, - n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + n_subjects_per_shard=cfg.stage_cfg.n_subjects_per_shard, seed=cfg.seed, ) shards_map_fp = Path(cfg.shards_map_fp) - logger.info(f"Writing sharded patients to {str(shards_map_fp.resolve())}") + logger.info(f"Writing sharded subjects to {str(shards_map_fp.resolve())}") shards_map_fp.parent.mkdir(parents=True, exist_ok=True) - shards_map_fp.write_text(json.dumps(sharded_patients)) - logger.info("Done writing sharded patients") + shards_map_fp.write_text(json.dumps(sharded_subjects)) + logger.info("Done writing sharded subjects") if __name__ == "__main__": diff --git a/src/MEDS_transforms/filters/README.md b/src/MEDS_transforms/filters/README.md index 9d582f0..22baa4b 100644 --- a/src/MEDS_transforms/filters/README.md +++ b/src/MEDS_transforms/filters/README.md @@ -1,5 +1,5 @@ # Filters -Filters remove wholesale events within the data, either at the patient or event level. For transformations +Filters remove wholesale events within the data, either at the subject or event level. For transformations that simply _occlude_ aspects of the data (e.g., by setting a code variable to `UNK`), see the `transforms` library section. diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 36a6938..4c0db29 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def filter_measurements_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Returns a function that filters patient events to only encompass those with a set of permissible codes. + """Returns a function that filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -26,44 +26,44 @@ def filter_measurements_fntr( >>> code_metadata_df = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_subjects": [2, 1, 3, 2], ... "code/n_occurrences": [4, 5, 3, 2], ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... 
}).lazy() - >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 1 ┆ B ┆ 1 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1, "min_occurrences_per_code": 4}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1, "min_occurrences_per_code": 4}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ │ 1 ┆ A ┆ 1 │ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": 1}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 1}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -72,12 +72,12 @@ def filter_measurements_fntr( │ 2 ┆ A ┆ 2 │ │ 2 ┆ C ┆ 2 │ └────────────┴──────┴───────────┘ - >>> stage_cfg = DictConfig({"min_patients_per_code": None, "min_occurrences_per_code": None}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": None, "min_occurrences_per_code": None}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (4, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -91,7 +91,7 @@ def filter_measurements_fntr( >>> fn(data).collect() shape: (1, 3) ┌────────────┬──────┬───────────┐ - │ patient_id ┆ code ┆ modifier1 │ + │ subject_id ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪═══════════╡ @@ -99,12 +99,12 @@ def filter_measurements_fntr( └────────────┴──────┴───────────┘ """ - min_patients_per_code = stage_cfg.get("min_patients_per_code", None) + min_subjects_per_code = stage_cfg.get("min_subjects_per_code", None) min_occurrences_per_code = stage_cfg.get("min_occurrences_per_code", None) filter_exprs = [] - if min_patients_per_code is not None: - filter_exprs.append(pl.col("code/n_patients") >= min_patients_per_code) + if min_subjects_per_code is not None: + filter_exprs.append(pl.col("code/n_subjects") >= min_subjects_per_code) if min_occurrences_per_code is not None: filter_exprs.append(pl.col("code/n_occurrences") >= min_occurrences_per_code) @@ -118,10 +118,10 @@ def filter_measurements_fntr( allowed_code_metadata = (code_metadata.filter(pl.all_horizontal(filter_exprs)).select(join_cols)).lazy() def filter_measurements_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters patient events to only encompass those with a set of permissible codes. + f"""Filters subject events to only encompass those with a set of permissible codes. 
In particular, this function filters the DataFrame to only include (code, modifier) pairs that have - at least {min_patients_per_code} patients and {min_occurrences_per_code} occurrences. + at least {min_subjects_per_code} subjects and {min_occurrences_per_code} occurrences. """ idx_col = "_row_idx" diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_subjects.py similarity index 67% rename from src/MEDS_transforms/filters/filter_patients.py rename to src/MEDS_transforms/filters/filter_subjects.py index 36dc398..007168d 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_subjects.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable from functools import partial @@ -12,25 +12,25 @@ from MEDS_transforms.mapreduce.mapper import map_over -def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of measurements they have. +def filter_subjects_by_num_measurements(df: pl.LazyFrame, min_measurements_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of measurements they have. Args: df: The input DataFrame. - min_measurements_per_patient: The minimum number of measurements a patient must have to be included. + min_measurements_per_subject: The minimum number of measurements a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 3], ... "time": [1, 2, 1, 1, 2, 1], ... }) - >>> filter_patients_by_num_measurements(df, 1) + >>> filter_subjects_by_num_measurements(df, 1) shape: (6, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -41,10 +41,10 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 2 │ │ 3 ┆ 1 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2) + >>> filter_subjects_by_num_measurements(df, 2) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -54,10 +54,10 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 3) + >>> filter_subjects_by_num_measurements(df, 3) shape: (3, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -65,47 +65,47 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 1 ┆ 2 │ │ 1 ┆ 1 │ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 4) + >>> filter_subjects_by_num_measurements(df, 4) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_measurements(df, 2.2) + >>> filter_subjects_by_num_measurements(df, 2.2) Traceback (most recent call last): ... 
- TypeError: min_measurements_per_patient must be an integer; got 2.2 + TypeError: min_measurements_per_subject must be an integer; got 2.2 """ - if not isinstance(min_measurements_per_patient, int): + if not isinstance(min_measurements_per_subject, int): raise TypeError( - f"min_measurements_per_patient must be an integer; got {type(min_measurements_per_patient)} " - f"{min_measurements_per_patient}" + f"min_measurements_per_subject must be an integer; got {type(min_measurements_per_subject)} " + f"{min_measurements_per_subject}" ) - return df.filter(pl.col("time").count().over("patient_id") >= min_measurements_per_patient) + return df.filter(pl.col("time").count().over("subject_id") >= min_measurements_per_subject) -def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of events (unique timepoints) they have. +def filter_subjects_by_num_events(df: pl.LazyFrame, min_events_per_subject: int) -> pl.LazyFrame: + """Filters subjects by the number of events (unique timepoints) they have. Args: df: The input DataFrame. - min_events_per_patient: The minimum number of events a patient must have to be included. + min_events_per_subject: The minimum number of events a subject must have to be included. Returns: The filtered DataFrame. Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], ... "time": [1, 1, 1, 1, 2, 1, 1, 2, 3, None, None, 1, 2, 3], ... }) - >>> filter_patients_by_num_events(df, 1) + >>> filter_subjects_by_num_events(df, 1) shape: (14, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -124,10 +124,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2) + >>> filter_subjects_by_num_events(df, 2) shape: (11, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -143,10 +143,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 3) + >>> filter_subjects_by_num_events(df, 3) shape: (8, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -159,10 +159,10 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 4) + >>> filter_subjects_by_num_events(df, 4) shape: (5, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ @@ -172,48 +172,48 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 5) + >>> filter_subjects_by_num_events(df, 5) shape: (0, 2) ┌────────────┬──────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞════════════╪══════╡ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2.2) + >>> filter_subjects_by_num_events(df, 2.2) Traceback (most recent call last): ... 
- TypeError: min_events_per_patient must be an integer; got 2.2 + TypeError: min_events_per_subject must be an integer; got 2.2 """ - if not isinstance(min_events_per_patient, int): + if not isinstance(min_events_per_subject, int): raise TypeError( - f"min_events_per_patient must be an integer; got {type(min_events_per_patient)} " - f"{min_events_per_patient}" + f"min_events_per_subject must be an integer; got {type(min_events_per_subject)} " + f"{min_events_per_subject}" ) - return df.filter(pl.col("time").n_unique().over("patient_id") >= min_events_per_patient) + return df.filter(pl.col("time").n_unique().over("subject_id") >= min_events_per_subject) -def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: +def filter_subjects_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: compute_fns = [] - if stage_cfg.min_measurements_per_patient: + if stage_cfg.min_measurements_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_measurements_per_patient} measurements " + f"Filtering subjects with fewer than {stage_cfg.min_measurements_per_subject} measurements " "(observations of any kind)." ) compute_fns.append( partial( - filter_patients_by_num_measurements, - min_measurements_per_patient=stage_cfg.min_measurements_per_patient, + filter_subjects_by_num_measurements, + min_measurements_per_subject=stage_cfg.min_measurements_per_subject, ) ) - if stage_cfg.min_events_per_patient: + if stage_cfg.min_events_per_subject: logger.info( - f"Filtering patients with fewer than {stage_cfg.min_events_per_patient} events " + f"Filtering subjects with fewer than {stage_cfg.min_events_per_subject} events " "(unique timepoints)." ) compute_fns.append( - partial(filter_patients_by_num_events, min_events_per_patient=stage_cfg.min_events_per_patient) + partial(filter_subjects_by_num_events, min_events_per_subject=stage_cfg.min_events_per_subject) ) def fn(data: pl.LazyFrame) -> pl.LazyFrame: @@ -230,7 +230,7 @@ def fn(data: pl.LazyFrame) -> pl.LazyFrame: def main(cfg: DictConfig): """TODO.""" - map_over(cfg, compute_fn=filter_patients_fntr) + map_over(cfg, compute_fn=filter_subjects_fntr) if __name__ == "__main__": diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 7649082..9888d9d 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -309,7 +309,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO compute function with the ``local_arg_1=baz`` parameter. Both of these local compute functions will be applied to the input DataFrame in sequence, and the resulting DataFrames will be concatenated alongside any of the dataframe that matches no matcher (which will be left unmodified) and merged in a sorted way - that respects the ``patient_id``, ``time`` ordering first, then the order of the match & revise blocks + that respects the ``subject_id``, ``time`` ordering first, then the order of the match & revise blocks themselves, then the order of the rows in each match & revise block output. Each local compute function will also use the ``global_arg_1=foo`` parameter. @@ -331,7 +331,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 2, 2, 2], ... "time": [1, 2, 2, 1, 1, 2], ... "initial_idx": [0, 1, 2, 3, 4, 5], ... 
"code": ["FINAL", "CODE//TEMP_2", "CODE//TEMP_1", "FINAL", "CODE//TEMP_2", "CODE//TEMP_1"] @@ -353,7 +353,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪════════════════╡ @@ -376,7 +376,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> match_revise_fn(df.lazy()).collect() shape: (6, 4) ┌────────────┬──────┬─────────────┬─────────────────┐ - │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ subject_id ┆ time ┆ initial_idx ┆ code │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪═════════════╪═════════════════╡ @@ -397,7 +397,7 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO ... ValueError: Missing needed columns {'missing'} for local matcher 0: [(col("missing")) == (String(CODE//TEMP_2))].all_horizontal() - Columns available: 'code', 'initial_idx', 'patient_id', 'time' + Columns available: 'code', 'initial_idx', 'subject_id', 'time' >>> stage_cfg = DictConfig({"global_code_end": "foo"}) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) @@ -439,7 +439,7 @@ def match_revise_fn(df: DF_T) -> DF_T: revision_parts.append(local_compute_fn(matched_df)) revision_parts.append(unmatched_df) - return pl.concat(revision_parts, how="vertical").sort(["patient_id", "time"], maintain_order=True) + return pl.concat(revision_parts, how="vertical").sort(["subject_id", "time"], maintain_order=True) return match_revise_fn @@ -580,7 +580,7 @@ def map_over( start = datetime.now() train_only = cfg.stage_cfg.get("train_only", False) - split_fp = Path(cfg.input_dir) / "metadata" / "patient_split.parquet" + split_fp = Path(cfg.input_dir) / "metadata" / "subject_split.parquet" shards, includes_only_train = shard_iterator_fntr(cfg) @@ -591,18 +591,18 @@ def map_over( ) elif split_fp.exists(): logger.info(f"Processing train split only by filtering read dfs via {str(split_fp.resolve())}") - train_patients = ( + train_subjects = ( pl.scan_parquet(split_fp) .filter(pl.col("split") == "train") - .select(pl.col("patient_id")) + .select(pl.col("subject_id")) .collect() .to_list() ) - read_fn = read_and_filter_fntr(train_patients, read_fn) + read_fn = read_and_filter_fntr(train_subjects, read_fn) else: raise FileNotFoundError( f"Train split requested, but shard prefixes can't be used and " - f"patient split file not found at {str(split_fp.resolve())}." + f"subject split file not found at {str(split_fp.resolve())}." ) elif includes_only_train: raise ValueError("All splits should be used, but shard iterator is returning only train splits?!?") diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 4aa4c53..2832c00 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -315,14 +315,14 @@ def shard_iterator( >>> from tempfile import TemporaryDirectory >>> import polars as pl >>> df = pl.DataFrame({ - ... "patient_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + ... "subject_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], ... "code": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], ... "time": [1, 2, 3, 4, 5, 6, 1, 2, 3], ... 
}) >>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]} >>> def write_dfs(input_dir: Path, df: pl.DataFrame=df, shards: dict=shards, sfx: str=".parquet"): - ... for shard_name, patient_ids in shards.items(): - ... df = df.filter(pl.col("patient_id").is_in(patient_ids)) + ... for shard_name, subject_ids in shards.items(): + ... df = df.filter(pl.col("subject_id").is_in(subject_ids)) ... shard_fp = input_dir / f"{shard_name}{sfx}" ... shard_fp.parent.mkdir(exist_ok=True, parents=True) ... if sfx == ".parquet": df.write_parquet(shard_fp) @@ -485,7 +485,7 @@ def shard_iterator( elif train_only: logger.info( f"train_only={train_only} requested but no dedicated train shards found; processing all shards " - "and relying on `patient_splits.parquet` for filtering." + "and relying on `subject_splits.parquet` for filtering." ) shards = shuffle_shards(shards, cfg) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index f936196..fb3358c 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -13,7 +13,7 @@ from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.extract.split_and_shard_patients import shard_patients +from MEDS_transforms.extract.split_and_shard_subjects import shard_subjects from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator, shuffle_shards from MEDS_transforms.utils import stage_init, write_lazyframe @@ -34,9 +34,9 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) - return shard_patients( - patients=df["patient_id"].to_numpy(), - n_patients_per_shard=stage_cfg.n_patients_per_shard, + return shard_subjects( + subjects=df["subject_id"].to_numpy(), + n_subjects_per_shard=stage_cfg.n_subjects_per_shard, external_splits=splits_map, split_fracs_dict=None, seed=cfg.get("seed", 1), @@ -51,13 +51,13 @@ def write_json(d: dict, fp: Path) -> None: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" + """Re-shard a MEDS cohort to in a manner that subdivides subject splits.""" stage_init(cfg) output_dir = Path(cfg.stage_cfg.output_dir) - splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" + splits_file = Path(cfg.input_dir) / "metadata" / "subject_splits.parquet" shards_fp = output_dir / ".shards.json" rwlock_wrap( @@ -92,22 +92,22 @@ def main(cfg: DictConfig): logger.info("Starting sub-sharding") for subshard_name, out_fp in new_shards_iter: - patients = new_sharded_splits[subshard_name] + subjects = new_sharded_splits[subshard_name] def read_fn(input_dir: Path) -> pl.LazyFrame: df = None logger.info(f"Reading shards for {subshard_name} (file names are in the input sharding scheme):") for in_fp, _ in orig_shards_iter: logger.info(f" - {str(in_fp.relative_to(input_dir).resolve())}") - new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("patient_id").is_in(patients)) + new_df = pl.scan_parquet(in_fp, glob=False).filter(pl.col("subject_id").is_in(subjects)) if df is None: df = new_df else: - df = df.merge_sorted(new_df, key="patient_id") + df = df.merge_sorted(new_df, key="subject_id") return df def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: - return df.sort(by=["patient_id", "time"], maintain_order=True, 
multithreaded=False) + return df.sort(by=["subject_id", "time"], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: write_lazyframe(df, out_fp) diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index 19a7abf..3f48a8c 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""Transformations for adding time-derived measurements (e.g., a patient's age) to a MEDS dataset.""" +"""Transformations for adding time-derived measurements (e.g., a subject's age) to a MEDS dataset.""" from collections.abc import Callable import hydra @@ -25,7 +25,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1), @@ -38,12 +38,12 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -62,7 +62,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> age_fn(df) shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -74,7 +74,7 @@ def add_new_events_fntr(fn: Callable[[pl.DataFrame], pl.DataFrame]) -> Callable[ >>> add_age_fn(df) shape: (10, 4) ┌────────────┬─────────────────────┬────────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪════════╪═══════════════╡ @@ -96,7 +96,7 @@ def out_fn(df: pl.DataFrame) -> pl.DataFrame: df = df.with_row_index("__idx") new_events = new_events.with_columns(pl.lit(0, dtype=df.schema["__idx"]).alias("__idx")) return ( - pl.concat([df, new_events], how="diagonal").sort(by=["patient_id", "time", "__idx"]).drop("__idx") + pl.concat([df, new_events], how="diagonal").sort(by=["subject_id", "time", "__idx"]).drop("__idx") ) return out_fn @@ -170,7 +170,7 @@ def normalize_time_unit(unit: str) -> tuple[str, float]: def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - """Create a function that adds a patient's age to a DataFrame. + """Create a function that adds a subject's age to a DataFrame. Args: cfg: The configuration for the age function. This must contain the following mandatory keys: @@ -179,8 +179,8 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: - "age_unit": The unit for the age event when converted to a numeric value in the output data. 
Returns: - A function that returns the to-be-added "age" events with the patient's age for all input events with - unique, non-null times in the data, for all patients who have an observed date of birth. It does + A function that returns the to-be-added "age" events with the subject's age for all input events with + unique, non-null times in the data, for all subjects who have an observed date of birth. It does not add an event for times that are equal to the date of birth. Raises: @@ -190,7 +190,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1), @@ -204,12 +204,12 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "rx", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... ) >>> df shape: (9, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -228,7 +228,7 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> age_fn(df) shape: (3, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str ┆ f32 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -248,15 +248,15 @@ def age_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: microseconds_in_unit = int(1e6) * seconds_in_unit def fn(df: pl.LazyFrame) -> pl.LazyFrame: - dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("patient_id") + dob_expr = pl.when(pl.col("code") == cfg.DOB_code).then(pl.col("time")).min().over("subject_id") age_expr = (pl.col("time") - dob_expr).dt.total_microseconds() / microseconds_in_unit age_expr = age_expr.cast(pl.Float32, strict=False) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) + .unique(subset=["subject_id", "time"], maintain_order=True) .select( - "patient_id", + "subject_id", "time", pl.lit(cfg.age_code, dtype=df.schema["code"]).alias("code"), age_expr.alias("numeric_value"), @@ -283,7 +283,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> from datetime import datetime >>> df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 1, 2, 2, 3, 3], + ... "subject_id": [1, 1, 1, 1, 2, 2, 3, 3], ... "time": [ ... None, ... datetime(1990, 1, 1, 1, 0), @@ -296,12 +296,12 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: ... ], ... "code": ["static", "DOB", "lab//A", "lab//B", "DOB", "lab//A", "lab//B", "dx//1"], ... }, - ... schema={"patient_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, + ... schema={"subject_id": pl.UInt32, "time": pl.Datetime, "code": pl.Utf8}, ... 
) >>> df shape: (8, 3) ┌────────────┬─────────────────────┬────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪════════╡ @@ -319,7 +319,7 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: >>> time_of_day_fn(df) shape: (6, 3) ┌────────────┬─────────────────────┬──────────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ str │ ╞════════════╪═════════════════════╪══════════════════════╡ @@ -357,8 +357,8 @@ def tod_code(start: int, end: int) -> str: time_of_day = time_of_day.when(hour >= end).then(tod_code(end, 24)) return ( df.drop_nulls(subset=["time"]) - .unique(subset=["patient_id", "time"], maintain_order=True) - .select("patient_id", "time", time_of_day.alias("code")) + .unique(subset=["subject_id", "time"], maintain_order=True) + .select("subject_id", "time", time_of_day.alias("code")) ) return fn diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index fbad9ac..363192a 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -16,7 +16,7 @@ def normalize( """Normalize a MEDS dataset across both categorical and continuous dimensions. This function expects a MEDS dataset in flattened form, with columns for: - - `patient_id` + - `subject_id` - `time` - `code` - `numeric_value` @@ -61,7 +61,7 @@ def normalize( >>> from datetime import datetime >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -76,7 +76,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... "time": pl.Datetime, ... "code": pl.Utf8, ... "numeric_value": pl.Float64, @@ -100,7 +100,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -113,7 +113,7 @@ def normalize( └────────────┴─────────────────────┴──────┴───────────────┘ >>> MEDS_df = pl.DataFrame( ... { - ... "patient_id": [1, 1, 1, 2, 2, 2, 3], + ... "subject_id": [1, 1, 1, 2, 2, 2, 3], ... "time": [ ... datetime(2021, 1, 1), ... datetime(2021, 1, 1), @@ -128,7 +128,7 @@ def normalize( ... "unit": ["mg/dL", "g/dL", None, "mg/dL", None, None, None], ... }, ... schema = { - ... "patient_id": pl.UInt32, + ... "subject_id": pl.UInt32, ... "time": pl.Datetime, ... "code": pl.Utf8, ... 
"numeric_value": pl.Float64, @@ -154,7 +154,7 @@ def normalize( >>> normalize(MEDS_df.lazy(), code_metadata, ["unit"]).collect() shape: (6, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -201,7 +201,7 @@ def normalize( ) .select( idx_col, - "patient_id", + "subject_id", "time", pl.col("code/vocab_index").alias("code"), ((pl.col("numeric_value") - pl.col("values/mean")) / pl.col("values/std")).alias("numeric_value"), diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index 107407d..528977b 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -13,7 +13,7 @@ def occlude_outliers_fntr( stage_cfg: DictConfig, code_metadata: pl.LazyFrame, code_modifiers: list[str] | None = None ) -> Callable[[pl.LazyFrame], pl.LazyFrame]: - """Filters patient events to only encompass those with a set of permissible codes. + """Filters subject events to only encompass those with a set of permissible codes. Args: df: The input DataFrame. @@ -33,7 +33,7 @@ def occlude_outliers_fntr( ... # for clarity: --- stddev = [3.0, 0.0, 3.0, 1.0] ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... # for clarity: mean [0.0, 4.0, 4.0, 1.0] @@ -45,7 +45,7 @@ def occlude_outliers_fntr( >>> fn(data).collect() shape: (4, 5) ┌────────────┬──────┬───────────┬───────────────┬─────────────────────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ + │ subject_id ┆ code ┆ modifier1 ┆ numeric_value ┆ numeric_value/is_inlier │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 ┆ f64 ┆ bool │ ╞════════════╪══════╪═══════════╪═══════════════╪═════════════════════════╡ @@ -78,7 +78,7 @@ def occlude_outliers_fntr( code_metadata = code_metadata.lazy().select(cols_to_select) def occlude_outliers_fn(df: pl.LazyFrame) -> pl.LazyFrame: - f"""Filters out outlier numeric values from patient events. + f"""Filters out outlier numeric values from subject events. In particular, this function filters the DataFrame to only include numeric values that are within {stddev_cutoff} standard deviations of the mean for the corresponding (code, modifier) pair. 
diff --git a/src/MEDS_transforms/transforms/reorder_measurements.py b/src/MEDS_transforms/transforms/reorder_measurements.py index 1205f77..5218551 100644 --- a/src/MEDS_transforms/transforms/reorder_measurements.py +++ b/src/MEDS_transforms/transforms/reorder_measurements.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -"""A polars-to-polars transformation function for filtering patients by sequence length.""" +"""A polars-to-polars transformation function for filtering subjects by sequence length.""" from collections.abc import Callable import hydra @@ -19,7 +19,7 @@ def reorder_by_code_fntr( Args: stage_cfg: The stage-specific configuration object which contains the `ordered_code_patterns` field - that defines the order of the codes within each patient event (unique timepoint). Each element of + that defines the order of the codes within each subject event (unique timepoint). Each element of this list should be a regex pattern that matches codes that should be re-ordered at the index of the regex pattern in the list. Codes are matched in the order of the list, and if a code matches multiple regex patterns, it will be ordered by the first regex pattern that matches it. @@ -34,7 +34,7 @@ def reorder_by_code_fntr( Examples: >>> code_metadata_df = pl.DataFrame({"code": ["A", "A", "B", "C"], "modifier1": [1, 2, 1, 2]}) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], + ... "subject_id":[1, 1, 2, 2], "time": [1, 1, 1, 1], ... "code": ["A", "B", "A", "C"], "modifier1": [1, 2, 1, 2] ... }) >>> stage_cfg = DictConfig({"ordered_code_patterns": ["B", "A"]}) @@ -42,7 +42,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (4, 4) ┌────────────┬──────┬──────┬───────────┐ - │ patient_id ┆ time ┆ code ┆ modifier1 │ + │ subject_id ┆ time ┆ code ┆ modifier1 │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str ┆ i64 │ ╞════════════╪══════╪══════╪═══════════╡ @@ -55,7 +55,7 @@ def reorder_by_code_fntr( ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE"], ... }) >>> data = pl.DataFrame({ - ... "patient_id":[1, 1, 1, 2, 2, 2], + ... "subject_id":[1, 1, 1, 2, 2, 2], ... "time": [1, 1, 1, 1, 2, 3], ... "code": ["LAB//foo", "ADMISSION//bar", "LAB//baz", "ADMISSION//qux", "DISCHARGE", "LAB//baz"], ... }) @@ -66,7 +66,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -81,7 +81,7 @@ def reorder_by_code_fntr( >>> fn(data.lazy()).collect() shape: (6, 3) ┌────────────┬──────┬────────────────┐ - │ patient_id ┆ time ┆ code │ + │ subject_id ┆ time ┆ code │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ str │ ╞════════════╪══════╪════════════════╡ @@ -141,7 +141,7 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: return ( df.join(code_indices, on=join_cols, how="left", coalesce=True) - .sort("patient_id", "time", "code_order_idx", maintain_order=True) + .sort("subject_id", "time", "code_order_idx", maintain_order=True) .drop("code_order_idx") ) @@ -152,13 +152,13 @@ def reorder_fn(df: pl.LazyFrame) -> pl.LazyFrame: version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem ) def main(cfg: DictConfig): - """Reorders measurements within each patient event (unique timepoint) by the specified code order. + """Reorders measurements within each subject event (unique timepoint) by the specified code order. 
In particular, given a set of [regex crate compatible](https://docs.rs/regex/latest/regex/) regexes in the `stage_cfg.ordered_code_patterns` list, this script will re-order the measurements within each event (unique timepoint) such that the measurements are sorted by the index of the first regex that matches their code in the `ordered_code_patterns` list. So, if the `ordered_code_patterns` list is - `["foo$", "bar", "foo.*"]`, and a single patient event has measurements with codes + `["foo$", "bar", "foo.*"]`, and a single subject event has measurements with codes `["foobar", "barbaz", "foo", "quat"]`, the measurements will be re-ordered to the order: `["foo", "foobar", "barbaz", "quat"]`, because: - "foo" matches the first regex in the list (the `foo$` matches any string with "foo" at the end). @@ -176,7 +176,7 @@ def main(cfg: DictConfig): Args: stage_configs.reorder_measurements.ordered_code_patterns: A list of regex patterns that specify the - order of the codes within each patient event (unique timepoint). To specify this on the command + order of the codes within each subject event (unique timepoint). To specify this on the command line, use the hydra list syntax by enclosing the entire key-value string argument in single quotes: ``'stage_configs.reorder_measurements.ordered_code_patterns=["foo$", "bar", "foo.*"]'``. """ diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 0266ce2..92fa7e0 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -32,7 +32,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 2], + ... "subject_id": [1, 2], ... "time_delta_days": [[float("nan"), 12.0], [float("nan")]], ... "code": [[[101.0, 102.0], [103.0]], [[201.0, 202.0]]], ... "numeric_value": [[[2.0, 3.0], [4.0]], [[6.0, 7.0]]] @@ -40,7 +40,7 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: >>> df shape: (2, 4) ┌────────────┬─────────────────┬───────────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f64] ┆ list[list[f64]] ┆ list[list[f64]] │ ╞════════════╪═════════════════╪═══════════════════════════╪═════════════════════╡ diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index d6f5003..31cb264 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -6,7 +6,7 @@ All these functions take in _normalized_ data -- meaning data where there are _no longer_ any code modifiers, as those have been normalized alongside codes into integer indices (in the output code column). The only -columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`. +columns of concern here thus are `subject_id`, `time`, `code`, `numeric_value`. """ from pathlib import Path @@ -69,7 +69,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "time": [None, datetime(2021, 1, 1), None, datetime(2021, 1, 2)], ... "code": [100, 101, 200, 201], ... 
"numeric_value": [1.0, 2.0, 3.0, 4.0] @@ -78,7 +78,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> static.collect() shape: (2, 3) ┌────────────┬──────┬───────────────┐ - │ patient_id ┆ code ┆ numeric_value │ + │ subject_id ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ f64 │ ╞════════════╪══════╪═══════════════╡ @@ -88,7 +88,7 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra >>> dynamic.collect() shape: (2, 4) ┌────────────┬─────────────────────┬──────┬───────────────┐ - │ patient_id ┆ time ┆ code ┆ numeric_value │ + │ subject_id ┆ time ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ datetime[μs] ┆ i64 ┆ f64 │ ╞════════════╪═════════════════════╪══════╪═══════════════╡ @@ -103,19 +103,19 @@ def split_static_and_dynamic(df: pl.LazyFrame) -> tuple[pl.LazyFrame, pl.LazyFra def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts static data and schema information (sequence of patient unique times). + """This function extracts static data and schema information (sequence of subject unique times). Args: df: The input data. Returns: - A `pl.LazyFrame` object containing the static data and the unique times of the patient, grouped - by patient as lists, in the same order as the patient IDs occurred in the original file. + A `pl.LazyFrame` object containing the static data and the unique times of the subject, grouped + by subject as lists, in the same order as the subject IDs occurred in the original file. Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], @@ -126,17 +126,17 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: >>> df.drop("time") shape: (2, 4) ┌────────────┬───────────┬───────────────┬─────────────────────┐ - │ patient_id ┆ code ┆ numeric_value ┆ start_time │ + │ subject_id ┆ code ┆ numeric_value ┆ start_time │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[i64] ┆ list[f64] ┆ datetime[μs] │ ╞════════════╪═══════════╪═══════════════╪═════════════════════╡ │ 1 ┆ [100] ┆ [1.0] ┆ 2021-01-01 00:00:00 │ │ 2 ┆ [200] ┆ [5.0] ┆ 2021-01-02 00:00:00 │ └────────────┴───────────┴───────────────┴─────────────────────┘ - >>> df.select("patient_id", "time").explode("time") + >>> df.select("subject_id", "time").explode("time") shape: (3, 2) ┌────────────┬─────────────────────┐ - │ patient_id ┆ time │ + │ subject_id ┆ time │ │ --- ┆ --- │ │ i64 ┆ datetime[μs] │ ╞════════════╪═════════════════════╡ @@ -148,21 +148,21 @@ def extract_statics_and_schema(df: pl.LazyFrame) -> pl.LazyFrame: static, dynamic = split_static_and_dynamic(df) - # This collects static data by patient ID and stores only (as a list) the codes and numeric values. - static_by_patient = static.group_by("patient_id", maintain_order=True).agg("code", "numeric_value") + # This collects static data by subject ID and stores only (as a list) the codes and numeric values. + static_by_subject = static.group_by("subject_id", maintain_order=True).agg("code", "numeric_value") - # This collects the unique times for each patient. - schema_by_patient = dynamic.group_by("patient_id", maintain_order=True).agg( + # This collects the unique times for each subject. 
+ schema_by_subject = dynamic.group_by("subject_id", maintain_order=True).agg( pl.col("time").min().alias("start_time"), pl.col("time").unique(maintain_order=True) ) - # TODO(mmd): Consider tracking patient offset explicitly here. + # TODO(mmd): Consider tracking subject offset explicitly here. - return static_by_patient.join(schema_by_patient, on="patient_id", how="inner") + return static_by_subject.join(schema_by_subject, on="subject_id", how="inner") -def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: - """This function extracts sequences of patient events, which are sequences of measurements. +def extract_seq_of_subject_events(df: pl.LazyFrame) -> pl.LazyFrame: + """This function extracts sequences of subject events, which are sequences of measurements. The result of this can be naturally tensorized into a `JointNestedRaggedTensorDict` object. @@ -170,8 +170,8 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: df: The input data. Returns: - A `pl.LazyFrame` object containing the sequences of patient events, with the following columns: - - `patient_id`: The patient ID. + A `pl.LazyFrame` object containing the sequences of subject events, with the following columns: + - `subject_id`: The subject ID. - `time_delta_days`: The time delta in days, as a list of floats (ragged). - `code`: The code, as a list of lists of ints (ragged in both levels). - `numeric_value`: The numeric value as a list of lists of floats (ragged in both levels). @@ -179,17 +179,17 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: Examples: >>> from datetime import datetime >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 1, 2, 2, 2], + ... "subject_id": [1, 1, 1, 1, 2, 2, 2], ... "time": [ ... None, datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 13), ... None, datetime(2021, 1, 2), datetime(2021, 1, 2)], ... "code": [100, 101, 102, 103, 200, 201, 202], ... "numeric_value": pl.Series([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=pl.Float32) ... 
}).lazy() - >>> extract_seq_of_patient_events(df).collect() + >>> extract_seq_of_subject_events(df).collect() shape: (2, 4) ┌────────────┬─────────────────┬─────────────────────┬─────────────────────┐ - │ patient_id ┆ time_delta_days ┆ code ┆ numeric_value │ + │ subject_id ┆ time_delta_days ┆ code ┆ numeric_value │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ list[f32] ┆ list[list[i64]] ┆ list[list[f32]] │ ╞════════════╪═════════════════╪═════════════════════╪═════════════════════╡ @@ -203,9 +203,9 @@ def extract_seq_of_patient_events(df: pl.LazyFrame) -> pl.LazyFrame: time_delta_days_expr = (pl.col("time").diff().dt.total_seconds() / SECONDS_PER_DAY).cast(pl.Float32) return ( - dynamic.group_by("patient_id", "time", maintain_order=True) + dynamic.group_by("subject_id", "time", maintain_order=True) .agg(pl.col("code").name.keep(), fill_to_nans("numeric_value").name.keep()) - .group_by("patient_id", maintain_order=True) + .group_by("subject_id", maintain_order=True) .agg( fill_to_nans(time_delta_days_expr).alias("time_delta_days"), "code", @@ -258,7 +258,7 @@ def main(cfg: DictConfig): event_seq_out_fp, pl.scan_parquet, write_lazyframe, - extract_seq_of_patient_events, + extract_seq_of_subject_events, do_overwrite=cfg.do_overwrite, ) diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 59a7cd6..b62f7d1 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -412,15 +412,15 @@ def is_col_field(field: str | None) -> bool: bool: True if the field is formatted as "col(column_name)", False otherwise. Examples: - >>> is_col_field("col(patient_id)") + >>> is_col_field("col(subject_id)") True - >>> is_col_field("col(patient_id") + >>> is_col_field("col(subject_id") False - >>> is_col_field("patient_id)") + >>> is_col_field("subject_id)") False - >>> is_col_field("column(patient_id)") + >>> is_col_field("column(subject_id)") False - >>> is_col_field("patient_id") + >>> is_col_field("subject_id") False >>> is_col_field(None) False @@ -440,16 +440,16 @@ def parse_col_field(field: str) -> str: ValueError: If the input string does not match the expected format. Examples: - >>> parse_col_field("col(patient_id)") - 'patient_id' - >>> parse_col_field("col(patient_id") + >>> parse_col_field("col(subject_id)") + 'subject_id' + >>> parse_col_field("col(subject_id") Traceback (most recent call last): ... - ValueError: Invalid column field: col(patient_id - >>> parse_col_field("column(patient_id)") + ValueError: Invalid column field: col(subject_id + >>> parse_col_field("column(subject_id)") Traceback (most recent call last): ... - ValueError: Invalid column field: column(patient_id) + ValueError: Invalid column field: column(subject_id) """ if not is_col_field(field): raise ValueError(f"Invalid column field: {field}") diff --git a/tests/test_add_time_derived_measurements.py b/tests/test_add_time_derived_measurements.py index e5653a1..964cad9 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/test_add_time_derived_measurements.py @@ -4,6 +4,7 @@ scripts. 
""" +from meds import subject_id_field from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs @@ -96,8 +97,8 @@ ``` """ -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -156,9 +157,9 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00","TIME_OF_DAY//[00,06)", @@ -185,8 +186,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00","TIME_OF_DAY//[00,06)", @@ -201,8 +202,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", diff --git a/tests/test_aggregate_code_metadata.py b/tests/test_aggregate_code_metadata.py index 21698cb..7d3d2a4 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/test_aggregate_code_metadata.py @@ -13,7 +13,7 @@ ) WANT_OUTPUT_CODE_METADATA_FILE = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/n_patients,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/n_subjects,values/sum,values/sum_sqd,values/n_ints,values/min,values/max,description,parent_codes ,44,4,28,4,3198.8389005974336,382968.28937288234,6,86.0,175.271118,, ADMISSION//CARDIAC,2,2,0,0,0,0,0,,,, ADMISSION//ORTHOPEDIC,1,1,0,0,0,0,0,,,, @@ -45,9 +45,9 @@ "TEMP", ], "code/n_occurrences": [44, 2, 1, 1, 4, 4, 1, 1, 2, 4, 12, 12], - "code/n_patients": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], + "code/n_subjects": [4, 2, 1, 1, 4, 4, 1, 1, 2, 4, 4, 4], "values/n_occurrences": [28, 0, 0, 0, 0, 0, 0, 0, 0, 4, 12, 12], - "values/n_patients": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], + "values/n_subjects": [4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 4], "values/sum": [ 3198.8389005974336, 0, @@ -163,9 +163,9 @@ AGGREGATIONS = [ "code/n_occurrences", - "code/n_patients", + "code/n_subjects", "values/n_occurrences", - "values/n_patients", + "values/n_subjects", "values/sum", "values/sum_sqd", "values/n_ints", diff --git a/tests/test_extract.py b/tests/test_extract.py index d8a3c3e..787d5d8 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -15,7 +15,7 @@ if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" @@ -23,7 +23,7 @@ FINALIZE_METADATA_SCRIPT = 
extraction_root / "finalize_MEDS_metadata.py" else: SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" @@ -53,7 +53,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +88,7 @@ EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -144,9 +144,9 @@ "held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -156,17 +156,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -176,7 +176,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -204,7 +204,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -214,7 +214,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -226,14 +226,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -241,14 +241,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 
1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -338,19 +338,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -405,11 +405,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -435,12 +435,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." ) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -449,8 +449,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -459,7 +459,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -511,12 +511,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." 
+ assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -593,12 +593,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -651,14 +651,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/test_extract_no_metadata.py b/tests/test_extract_no_metadata.py index f1945af..2391a97 100644 --- a/tests/test_extract_no_metadata.py +++ b/tests/test_extract_no_metadata.py @@ -15,7 +15,7 @@ if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_patients.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" @@ -23,7 +23,7 @@ FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" else: SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_patients" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" @@ -53,7 +53,7 @@ """ ADMIT_VITALS_CSV = """ -patient_id,admit_date,disch_date,department,vitals_date,HR,temp +subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 @@ -88,7 +88,7 @@ 
EVENT_CFGS_YAML = """ subjects: - patient_id_col: MRN + subject_id_col: MRN eye_color: code: - EYE_COLOR @@ -133,9 +133,9 @@ "held_out/0": [1500733], } -PATIENT_SPLITS_DF = pl.DataFrame( +SUBJECT_SPLITS_DF = pl.DataFrame( { - "patient_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], "split": ["train", "train", "train", "train", "tuning", "held_out"], } ) @@ -145,17 +145,17 @@ def get_expected_output(df: str) -> pl.DataFrame: return ( pl.read_csv(source=StringIO(df)) .select( - "patient_id", + "subject_id", pl.col("time").str.strptime(pl.Datetime, "%m/%d/%Y, %H:%M:%S").alias("time"), pl.col("code"), "numeric_value", ) - .sort(by=["patient_id", "time"]) + .sort(by=["subject_id", "time"]) ) MEDS_OUTPUT_TRAIN_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -165,7 +165,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, 239684,"05/11/2010, 17:41:51",HR,102.6 239684,"05/11/2010, 17:41:51",TEMP,96.0 @@ -193,7 +193,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -203,7 +203,7 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TRAIN_1_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, 68729,"05/26/2010, 02:30:56",HR,86.0 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -215,14 +215,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_TUNING_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, """ MEDS_OUTPUT_TUNING_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, 754281,"01/03/2010, 06:27:59",HR,142.0 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -230,14 +230,14 @@ def get_expected_output(df: str) -> pl.DataFrame: """ MEDS_OUTPUT_HELD_OUT_0_SUBJECTS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, """ MEDS_OUTPUT_HELD_OUT_0_ADMIT_VITALS = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, 1500733,"06/03/2010, 14:54:38",HR,91.4 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -322,19 +322,19 @@ def test_extraction(): # Run the extraction script # 1. Sub-shard the data (this will be a null operation in this case, but it is worth doing just in # case. - # 2. Collect the patient splits. - # 3. Extract the events and sub-shard by patient. + # 2. Collect the subject splits. + # 3. Extract the events and sub-shard by subject. # 4. Merge to the final output. 
extraction_config_kwargs = { "input_dir": str(raw_cohort_dir.resolve()), "cohort_dir": str(MEDS_cohort_dir.resolve()), "event_conversion_config_fp": str(event_cfgs_yaml.resolve()), - "stage_configs.split_and_shard_patients.split_fracs.train": 4 / 6, - "stage_configs.split_and_shard_patients.split_fracs.tuning": 1 / 6, - "stage_configs.split_and_shard_patients.split_fracs.held_out": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.train": 4 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.tuning": 1 / 6, + "stage_configs.split_and_shard_subjects.split_fracs.held_out": 1 / 6, "stage_configs.shard_events.row_chunksize": 10, - "stage_configs.split_and_shard_patients.n_patients_per_shard": 2, + "stage_configs.split_and_shard_subjects.n_subjects_per_shard": 2, "hydra.verbose": True, "etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0", @@ -389,11 +389,11 @@ def test_extraction(): check_row_order=False, ) - # Stage 2: Collect the patient splits + # Stage 2: Collect the subject splits stderr, stdout = run_command( SPLIT_AND_SHARD_SCRIPT, extraction_config_kwargs, - "split_and_shard_patients", + "split_and_shard_subjects", ) all_stderrs.append(stderr) @@ -419,12 +419,12 @@ def test_extraction(): "NEEDING TO BE UPDATED." ) except AssertionError as e: - print("Failed to split patients") + print("Failed to split subjects") print(f"stderr:\n{stderr}") print(f"stdout:\n{stdout}") raise e - # Stage 3: Extract the events and sub-shard by patient + # Stage 3: Extract the events and sub-shard by subject stderr, stdout = run_command( CONVERT_TO_SHARDED_EVENTS_SCRIPT, extraction_config_kwargs, @@ -433,8 +433,8 @@ def test_extraction(): all_stderrs.append(stderr) all_stdouts.append(stdout) - patient_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" - assert patient_subsharded_folder.is_dir(), f"Expected {patient_subsharded_folder} to be a directory." + subject_subsharded_folder = MEDS_cohort_dir / "convert_to_sharded_events" + assert subject_subsharded_folder.is_dir(), f"Expected {subject_subsharded_folder} to be a directory." for split, expected_outputs in SUB_SHARDED_OUTPUTS.items(): for prefix, expected_df_L in expected_outputs.items(): @@ -443,7 +443,7 @@ def test_extraction(): expected_df = pl.concat([get_expected_output(df) for df in expected_df_L]) - fps = list((patient_subsharded_folder / split / prefix).glob("*.parquet")) + fps = list((subject_subsharded_folder / split / prefix).glob("*.parquet")) assert len(fps) > 0 # We add a "unique" here as there may be some duplicates across the row-group sub-shards. @@ -495,12 +495,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -560,12 +560,12 @@ def test_extraction(): check_row_order=False, ) - assert got_df["patient_id"].is_sorted(), f"Patient IDs should be sorted for split {split}." + assert got_df["subject_id"].is_sorted(), f"Subject IDs should be sorted for split {split}." 
for subj in splits[split]: - got_df_subj = got_df.filter(pl.col("patient_id") == subj) + got_df_subj = got_df.filter(pl.col("subject_id") == subj) assert got_df_subj[ "time" - ].is_sorted(), f"Times should be sorted for patient {subj} in split {split}." + ].is_sorted(), f"Times should be sorted for subject {subj} in split {split}." except AssertionError as e: print(f"Failed on split {split}") @@ -618,14 +618,14 @@ def test_extraction(): assert got_json == MEDS_OUTPUT_DATASET_METADATA_JSON, f"Dataset metadata differs: {got_json}" # Check the splits parquet - output_file = MEDS_cohort_dir / "metadata" / "patient_splits.parquet" + output_file = MEDS_cohort_dir / "metadata" / "subject_splits.parquet" assert output_file.is_file(), f"Expected {output_file} to exist: stderr:\n{stderr}\nstdout:\n{stdout}" got_df = pl.read_parquet(output_file, glob=False, use_pyarrow=True) assert_df_equal( - PATIENT_SPLITS_DF, + SUBJECT_SPLITS_DF, got_df, - "Patient splits should be equal to the expected splits.", + "Subject splits should be equal to the expected splits.", check_column_order=False, check_row_order=False, ) diff --git a/tests/test_filter_measurements.py b/tests/test_filter_measurements.py index cb919d1..cb5b4ee 100644 --- a/tests/test_filter_measurements.py +++ b/tests/test_filter_measurements.py @@ -10,7 +10,7 @@ # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -25,11 +25,11 @@ # TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 # """ # -# We'll keep only the codes that occur for at least 2 patients, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, +# We'll keep only the codes that occur for at least 2 subjects, which are: ADMISSION//CARDIAC, DISCHARGE, DOB, # EYE_COLOR//HAZEL, HEIGHT, HR, TEMP WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -61,7 +61,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -77,7 +77,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",HR,142.0 @@ -86,7 +86,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",HR,91.4 @@ -112,14 +112,14 @@ def test_filter_measurements(): single_stage_transform_tester( transform_script=FILTER_MEASUREMENTS_SCRIPT, stage_name="filter_measurements", - transform_stage_kwargs={"min_patients_per_code": 2}, + transform_stage_kwargs={"min_subjects_per_code": 2}, want_data=WANT_SHARDS, ) # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # 
,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -144,7 +144,7 @@ def test_filter_measurements(): # - Other codes won't be filtered, so we will retain HEIGHT, DISCHARGE, DOB, TEMP MR_WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, @@ -166,7 +166,7 @@ def test_filter_measurements(): """ MR_WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, 68729,"05/26/2010, 02:30:56",TEMP,97.8 @@ -178,7 +178,7 @@ def test_filter_measurements(): """ MR_WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, 754281,"01/03/2010, 06:27:59",TEMP,99.8 @@ -186,7 +186,7 @@ def test_filter_measurements(): """ MR_WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, 1500733,"06/03/2010, 14:54:38",TEMP,100.0 @@ -214,13 +214,13 @@ def test_match_revise_filter_measurements(): stage_name="filter_measurements", transform_stage_kwargs={ "_match_revise": [ - {"_matcher": {"code": "ADMISSION//CARDIAC"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "ADMISSION//ORTHOPEDIC"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "ADMISSION//PULMONARY"}, "min_patients_per_code": 2}, - {"_matcher": {"code": "HR"}, "min_patients_per_code": 15}, - {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_patients_per_code": 4}, - {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4}, + {"_matcher": {"code": "ADMISSION//CARDIAC"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "ADMISSION//ORTHOPEDIC"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "ADMISSION//PULMONARY"}, "min_subjects_per_code": 2}, + {"_matcher": {"code": "HR"}, "min_subjects_per_code": 15}, + {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_subjects_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_subjects_per_code": 4}, ], }, want_data=MR_WANT_SHARDS, diff --git a/tests/test_filter_patients.py b/tests/test_filter_subjects.py similarity index 75% rename from tests/test_filter_patients.py rename to tests/test_filter_subjects.py index 0b07836..1defee4 100644 --- a/tests/test_filter_patients.py +++ b/tests/test_filter_subjects.py @@ -1,15 +1,16 @@ -"""Tests the filter patients script. +"""Tests the filter subjects script. Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +from meds import subject_id_field -from .transform_tester_base import FILTER_PATIENTS_SCRIPT, single_stage_transform_tester +from .transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -42,18 +43,18 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -# All patients in this shard had only 4 events. 
-WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value """ -# All patients in this shard had only 4 events. -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +# All subjects in this shard had only 4 events. +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -77,10 +78,10 @@ ) -def test_filter_patients(): +def test_filter_subjects(): single_stage_transform_tester( - transform_script=FILTER_PATIENTS_SCRIPT, - stage_name="filter_patients", - transform_stage_kwargs={"min_events_per_patient": 5}, + transform_script=FILTER_SUBJECTS_SCRIPT, + stage_name="filter_subjects", + transform_stage_kwargs={"min_events_per_subject": 5}, want_data=WANT_SHARDS, ) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/test_fit_vocabulary_indices.py index ce7c40a..f67ebe1 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/test_fit_vocabulary_indices.py @@ -12,7 +12,7 @@ ) WANT_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index ,44,4,28,3198.8389005974336,382968.28937288234,,,1 ADMISSION//CARDIAC,2,2,0,,,,,2 ADMISSION//ORTHOPEDIC,1,1,0,,,,,3 diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/test_multi_stage_preprocess_pipeline.py index eda9060..32c4a27 100644 --- a/tests/test_multi_stage_preprocess_pipeline.py +++ b/tests/test_multi_stage_preprocess_pipeline.py @@ -4,7 +4,7 @@ scripts. 
In this test, the following stages are run: - - filter_patients + - filter_subjects - add_time_derived_measurements - fit_outlier_detection - occlude_outliers @@ -20,12 +20,13 @@ from datetime import datetime import polars as pl +from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from .transform_tester_base import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, FIT_VOCABULARY_INDICES_SCRIPT, NORMALIZATION_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, @@ -51,8 +52,8 @@ ) STAGE_CONFIG_YAML = """ -filter_patients: - min_events_per_patient: 5 +filter_subjects: + min_events_per_subject: 5 add_time_derived_measurements: age: DOB_code: "DOB" # This is the MEDS official code for BIRTH @@ -71,17 +72,17 @@ fit_normalization: aggregations: - "code/n_occurrences" - - "code/n_patients" + - "code/n_subjects" - "values/n_occurrences" - "values/sum" - "values/sum_sqd" """ -# After filtering out patients with fewer than 5 events: +# After filtering out subjects with fewer than 5 events: WANT_FILTER = parse_shards_yaml( - """ - "filter_patients/train/0": |-2 - patient_id,time,code,numeric_value + f""" + "filter_subjects/train/0": |-2 + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -113,14 +114,14 @@ 1195293,"06/20/2010, 20:41:33",TEMP,100.4 1195293,"06/20/2010, 20:50:04",DISCHARGE, - "filter_patients/train/1": |-2 - patient_id,time,code,numeric_value + "filter_subjects/train/1": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/tuning/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/tuning/0": |-2 + {subject_id_field},time,code,numeric_value - "filter_patients/held_out/0": |-2 - patient_id,time,code,numeric_value + "filter_subjects/held_out/0": |-2 + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -136,9 +137,9 @@ ) WANT_TIME_DERIVED = parse_shards_yaml( - """ + f""" "add_time_derived_measurements/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)", @@ -197,13 +198,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, "add_time_derived_measurements/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "add_time_derived_measurements/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)", @@ -387,9 +388,9 @@ """ WANT_OCCLUDE_OUTLIERS = parse_shards_yaml( - """ + f""" "occlude_outliers/train/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -448,13 +449,13 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE,, "occlude_outliers/train/1": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/tuning/0": |-2 - 
patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier "occlude_outliers/held_out/0": |-2 - patient_id,time,code,numeric_value,numeric_value/is_inlier + {subject_id_field},time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,,false 1500733,"07/20/1986, 00:00:00","TIME_OF_DAY//[00,06)",, @@ -488,7 +489,7 @@ ... .group_by("code") ... .agg( ... pl.len().alias("code/n_occurrences"), -... pl.col("patient_id").n_unique().alias("code/n_patients"), +... pl.col("subject_id").n_unique().alias("code/n_subjects"), ... VALS.len().alias("values/n_occurrences"), ... VALS.sum().alias("values/sum"), ... (VALS**2).sum().alias("values/sum_sqd") @@ -497,7 +498,7 @@ >>> post_transform.filter(pl.col("values/n_occurrences") > 0) shape: (3, 6) ┌──────┬────────────────────┬─────────────────┬──────────────────────┬────────────┬────────────────┐ -│ code ┆ code/n_occurrences ┆ code/n_patients ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ +│ code ┆ code/n_occurrences ┆ code/n_subjects ┆ values/n_occurrences ┆ values/sum ┆ values/sum_sqd │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ │ str ┆ u32 ┆ u32 ┆ u32 ┆ f32 ┆ f32 │ ╞══════╪════════════════════╪═════════════════╪══════════════════════╪════════════╪════════════════╡ @@ -508,7 +509,7 @@ >>> print(post_transform.filter(pl.col("values/n_occurrences") > 0).to_dict(as_series=False)) {'code': ['HR', 'TEMP', 'AGE'], 'code/n_occurrences': [10, 10, 12], - 'code/n_patients': [2, 2, 2], + 'code/n_subjects': [2, 2, 2], 'values/n_occurrences': [7, 6, 7], 'values/sum': [776.7999877929688, 600.1000366210938, 224.02084350585938], 'values/sum_sqd': [86249.921875, 60020.21484375, 7169.33349609375]} @@ -533,7 +534,7 @@ "DOB", ], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -597,7 +598,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "values/n_occurrences": pl.UInt8, # In the real stage, this is shrunk, so it differs from the ex. "values/sum": pl.Float32, "values/sum_sqd": pl.Float32, @@ -625,7 +626,7 @@ ], "code/vocab_index": [5, 6, 8, 9, 2, 7, 12, 11, 10, 1, 3, 4], "code/n_occurrences": [1, 1, 10, 10, 12, 2, 10, 2, 2, 2, 2, 2], - "code/n_patients": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], + "code/n_subjects": [1, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2], "values/n_occurrences": [0, 0, 7, 6, 7, 0, 0, 0, 0, 0, 0, 0], "values/sum": [ 0.0, @@ -689,7 +690,7 @@ "description": pl.String, "parent_codes": pl.List(pl.String), "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "code/vocab_index": pl.UInt8, "values/n_occurrences": pl.UInt8, "values/sum": pl.Float32, @@ -772,9 +773,9 @@ # Note we have dropped the row in the held out shard that doesn't have a code in the vocabulary! 
WANT_NORMALIZATION = parse_shards_yaml( - """ + f""" "normalization/train/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 239684,,6, 239684,,7, 239684,"12/28/1980, 00:00:00",10, @@ -833,13 +834,13 @@ 1195293,"06/20/2010, 20:50:04",3, "normalization/train/1": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/tuning/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value "normalization/held_out/0": |-2 - patient_id,time,code,numeric_value + {subject_id_field},time,code,numeric_value 1500733,,6, 1500733,,7, 1500733,"07/20/1986, 00:00:00",10, @@ -864,7 +865,7 @@ ) TOKENIZATION_SCHEMA_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.UInt8), "numeric_value": pl.List(pl.Float32), "start_time": pl.Datetime("us"), @@ -873,7 +874,7 @@ WANT_TOKENIZATION_SCHEMAS = { "tokenization/schemas/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [[6, 7], [5, 7]], "numeric_value": [[None, None], [None, None]], "start_time": [datetime(1980, 12, 28), datetime(1978, 6, 20)], @@ -901,16 +902,16 @@ schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "start_time", "time"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "start_time", "time"]}, schema=TOKENIZATION_SCHEMA_DF_SCHEMA, ), "tokenization/schemas/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [[6, 7]], "numeric_value": [[None, None]], "start_time": [datetime(1986, 7, 20)], @@ -928,18 +929,9 @@ ), } -TOKENIZATION_CODE = """ -```python - ->>> import polars as pl ->>> from tests.test_multi_stage_preprocess_pipeline import WANT_NORMALIZATION as dfs ->>> - -``` -""" TOKENIZATION_EVENT_SEQS_DF_SCHEMA = { - "patient_id": pl.UInt32, + subject_id_field: pl.Int64, "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(pl.Float32)), "time_delta_days": pl.List(pl.Float32), @@ -948,7 +940,7 @@ WANT_TOKENIZATION_EVENT_SEQS = { "tokenization/event_seqs/train/0": pl.DataFrame( { - "patient_id": [239684, 1195293], + subject_id_field: [239684, 1195293], "code": [ [[10, 4], [11, 2, 1, 8, 9], [11, 2, 8, 9], [12, 2, 8, 9], [12, 2, 8, 9], [12, 2, 3]], [ @@ -995,16 +987,16 @@ schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/train/1": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/tuning/0": pl.DataFrame( - {k: [] for k in ["patient_id", "code", "numeric_value", "time_delta_days"]}, + {k: [] for k in [subject_id_field, "code", "numeric_value", "time_delta_days"]}, schema=TOKENIZATION_EVENT_SEQS_DF_SCHEMA, ), "tokenization/event_seqs/held_out/0": pl.DataFrame( { - "patient_id": [1500733], + subject_id_field: [1500733], "code": [ [ [10, 4], @@ -1057,7 +1049,7 @@ def test_pipeline(): multi_stage_transform_tester( transform_scripts=[ - FILTER_PATIENTS_SCRIPT, + FILTER_SUBJECTS_SCRIPT, ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, 
AGGREGATE_CODE_METADATA_SCRIPT, OCCLUDE_OUTLIERS_SCRIPT, @@ -1068,7 +1060,7 @@ def test_pipeline(): TENSORIZATION_SCRIPT, ], stage_names=[ - "filter_patients", + "filter_subjects", "add_time_derived_measurements", "fit_outlier_detection", "occlude_outliers", diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 46992ed..14207c4 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -12,7 +12,7 @@ # This is the code metadata file we'll use in this transform test. It is different than the default as we need # a code/vocab_index MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,code/vocab_index ADMISSION//CARDIAC,2,2,0,,,1 ADMISSION//ORTHOPEDIC,1,1,0,,,2 ADMISSION//PULMONARY,1,1,0,,,3 @@ -129,7 +129,7 @@ # TEMP: 11 WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,7, 239684,,9,1.5770268440246582 239684,"12/28/1980, 00:00:00",5, @@ -163,7 +163,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,8, 68729,,9,-0.5438239574432373 68729,"03/09/1978, 00:00:00",5, @@ -181,7 +181,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,7, 754281,,9,0.28697699308395386 754281,"12/19/1988, 00:00:00",5, @@ -192,7 +192,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,7, 1500733,,9,-0.7995940446853638 1500733,"07/20/1986, 00:00:00",5, diff --git a/tests/test_occlude_outliers.py b/tests/test_occlude_outliers.py index 63e9376..f13a4fa 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/test_occlude_outliers.py @@ -12,7 +12,7 @@ # This is the code metadata # MEDS_CODE_METADATA_CSV = """ -# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code # ,44,4,28,3198.8389005974336,382968.28937288234,, # ADMISSION//CARDIAC,2,2,0,,,, # ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -75,7 +75,7 @@ """ # noqa: E501 WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 239684,,EYE_COLOR//BROWN,, 239684,,HEIGHT,,false 239684,"12/28/1980, 00:00:00",DOB,, @@ -109,7 +109,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 68729,,EYE_COLOR//HAZEL,, 68729,,HEIGHT,160.3953106166676,true 68729,"03/09/1978, 00:00:00",DOB,, @@ -127,7 +127,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 754281,,EYE_COLOR//BROWN,, 754281,,HEIGHT,166.22261567137025,true 754281,"12/19/1988, 00:00:00",DOB,, @@ -138,7 +138,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value,numeric_value/is_inlier +subject_id,time,code,numeric_value,numeric_value/is_inlier 1500733,,EYE_COLOR//BROWN,, 1500733,,HEIGHT,158.60131573580904,true 1500733,"07/20/1986, 00:00:00",DOB,, diff --git a/tests/test_reorder_measurements.py b/tests/test_reorder_measurements.py index c90dee4..7cc7aaa 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/test_reorder_measurements.py @@ -19,7 +19,7 @@ 
WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -53,7 +53,7 @@ """ WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,HEIGHT,160.3953106166676 68729,,EYE_COLOR//HAZEL, 68729,"03/09/1978, 00:00:00",DOB, @@ -71,7 +71,7 @@ """ WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -82,7 +82,7 @@ """ WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 65056e5..7975103 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -5,6 +5,8 @@ """ +from meds import subject_id_field + from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs @@ -14,8 +16,8 @@ "2": [239684, 1500733], } -IN_SHARD_0 = """ -patient_id,time,code,numeric_value +IN_SHARD_0 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -42,8 +44,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -IN_SHARD_1 = """ -patient_id,time,code,numeric_value +IN_SHARD_1 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -60,8 +62,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -IN_SHARD_2 = """ -patient_id,time,code,numeric_value +IN_SHARD_2 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -94,8 +96,8 @@ "held_out": [1500733], } -WANT_TRAIN_0 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_0 = f""" +{subject_id_field},time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -128,8 +130,8 @@ 1195293,"06/20/2010, 20:50:04",DISCHARGE, """ -WANT_TRAIN_1 = """ -patient_id,time,code,numeric_value +WANT_TRAIN_1 = f""" +{subject_id_field},time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -146,8 +148,8 @@ 814703,"02/05/2010, 07:02:30",DISCHARGE, """ -WANT_TUNING_0 = """ -patient_id,time,code,numeric_value +WANT_TUNING_0 = f""" +{subject_id_field},time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -157,8 +159,8 @@ 754281,"01/03/2010, 08:22:13",DISCHARGE, """ -WANT_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +WANT_HELD_OUT_0 = f""" +{subject_id_field},time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -194,7 +196,7 @@ def test_reshard_to_split(): single_stage_transform_tester( transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", - transform_stage_kwargs={"n_patients_per_shard": 2}, + transform_stage_kwargs={"n_subjects_per_shard": 2}, want_data=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 693add1..cf5883e 100644 --- a/tests/test_tokenization.py 
+++ b/tests/test_tokenization.py @@ -20,17 +20,17 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: """TODO: Doctests""" out = [] - for patient_ts in ts: + for subject_ts in ts: out.append([float("nan")]) - for i in range(1, len(patient_ts)): - out[-1].append((patient_ts[i] - patient_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) + for i in range(1, len(subject_ts)): + out[-1].append((subject_ts[i] - subject_ts[i - 1]).total_seconds() / SECONDS_PER_DAY) return out # TODO: Make these schemas exportable, maybe??? # TODO: Why is the code getting converted to a float? SCHEMAS_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(NORMALIZED_MEDS_SCHEMA["code"]), "numeric_value": pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"]), "start_time": NORMALIZED_MEDS_SCHEMA["time"], @@ -38,7 +38,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: } SEQ_SCHEMA = { - "patient_id": NORMALIZED_MEDS_SCHEMA["patient_id"], + "subject_id": NORMALIZED_MEDS_SCHEMA["subject_id"], "code": pl.List(pl.List(pl.UInt8)), "numeric_value": pl.List(pl.List(NORMALIZED_MEDS_SCHEMA["numeric_value"])), "time_delta_days": pl.List(pl.Float32), @@ -66,7 +66,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: ] WANT_SCHEMAS_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "code": [[7, 9], [6, 9]], "numeric_value": [[None, 1.5770268440246582], [None, 0.06802856922149658]], "start_time": [ts[0] for ts in TRAIN_0_TIMES], @@ -77,7 +77,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_0 = pl.DataFrame( { - "patient_id": [239684, 1195293], + "subject_id": [239684, 1195293], "time_delta_days": ts_to_time_delta_days(TRAIN_0_TIMES), "code": [ [[5], [1, 10, 11], [10, 11], [10, 11], [10, 11], [4]], @@ -114,7 +114,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "code": [[8, 9], [8, 9]], "numeric_value": [[None, -0.5438239574432373], [None, -1.1012336015701294]], "start_time": [ts[0] for ts in TRAIN_1_TIMES], @@ -125,7 +125,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TRAIN_1 = pl.DataFrame( { - "patient_id": [68729, 814703], + "subject_id": [68729, 814703], "time_delta_days": ts_to_time_delta_days(TRAIN_1_TIMES), "code": [[[5], [3, 10, 11], [4]], [[5], [2, 10, 11], [4]]], "numeric_value": [ @@ -140,7 +140,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "code": [[7, 9]], "numeric_value": [[None, 0.28697699308395386]], "start_time": [ts[0] for ts in TUNING_0_TIMES], @@ -151,7 +151,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_TUNING_0 = pl.DataFrame( { - "patient_id": [754281], + "subject_id": [754281], "time_delta_days": ts_to_time_delta_days(TUNING_0_TIMES), "code": [[[5], [3, 10, 11], [4]]], "numeric_value": [ @@ -174,7 +174,7 @@ def ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_SCHEMAS_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "code": [[7, 9]], "numeric_value": [[None, -0.7995940446853638]], "start_time": [ts[0] for ts in HELD_OUT_0_TIMES], @@ -185,7 +185,7 @@ def 
ts_to_time_delta_days(ts: list[list[datetime]]) -> list[list[float]]: WANT_EVENT_SEQ_HELD_OUT_0 = pl.DataFrame( { - "patient_id": [1500733], + "subject_id": [1500733], "time_delta_days": ts_to_time_delta_days(HELD_OUT_0_TIMES), "code": [[[5], [2, 10, 11], [10, 11], [10, 11], [4]]], "numeric_value": [ diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 6845e0c..2deddd4 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -22,6 +22,7 @@ import numpy as np import polars as pl import rootutils +from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command @@ -40,7 +41,7 @@ # Filters FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" - FILTER_PATIENTS_SCRIPT = filters_root / "filter_patients.py" + FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" # Transforms ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" @@ -57,7 +58,7 @@ # Filters FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" - FILTER_PATIENTS_SCRIPT = "MEDS_transform-filter_patients" + FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" # Transforms ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" @@ -83,7 +84,7 @@ } MEDS_TRAIN_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, 239684,,HEIGHT,175.271115221764 239684,"12/28/1980, 00:00:00",DOB, @@ -117,7 +118,7 @@ """ MEDS_TRAIN_1 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, 68729,,HEIGHT,160.3953106166676 68729,"03/09/1978, 00:00:00",DOB, @@ -135,7 +136,7 @@ """ MEDS_TUNING_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 754281,,EYE_COLOR//BROWN, 754281,,HEIGHT,166.22261567137025 754281,"12/19/1988, 00:00:00",DOB, @@ -146,7 +147,7 @@ """ MEDS_HELD_OUT_0 = """ -patient_id,time,code,numeric_value +subject_id,time,code,numeric_value 1500733,,EYE_COLOR//BROWN, 1500733,,HEIGHT,158.60131573580904 1500733,"07/20/1986, 00:00:00",DOB, @@ -171,7 +172,7 @@ MEDS_CODE_METADATA_CSV = """ -code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes +code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes ,44,4,28,3198.8389005974336,382968.28937288234,, ADMISSION//CARDIAC,2,2,0,,,, ADMISSION//ORTHOPEDIC,1,1,0,,,, @@ -189,9 +190,9 @@ MEDS_CODE_METADATA_SCHEMA = { "code": pl.Utf8, "code/n_occurrences": pl.UInt8, - "code/n_patients": pl.UInt8, + "code/n_subjects": pl.UInt8, "values/n_occurrences": pl.UInt8, - "values/n_patients": pl.UInt8, + "values/n_subjects": pl.UInt8, "values/sum": pl.Float32, "values/sum_sqd": pl.Float32, "values/n_ints": pl.UInt8, @@ -323,11 +324,11 @@ def input_MEDS_dataset( if input_splits_map is None: input_splits_map = SPLITS input_splits_as_df = defaultdict(list) - for split_name, patient_ids in input_splits_map.items(): - input_splits_as_df["patient_id"].extend(patient_ids) - input_splits_as_df["split"].extend([split_name] * len(patient_ids)) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) input_splits_df = pl.DataFrame(input_splits_as_df) - input_splits_fp = 
MEDS_metadata_dir / "patient_splits.parquet" + input_splits_fp = MEDS_metadata_dir / "subject_splits.parquet" input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) if input_shards is None: diff --git a/tests/utils.py b/tests/utils.py index e6c9d3f..74efae5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,7 +11,7 @@ # TODO: Make use meds library MEDS_PL_SCHEMA = { - "patient_id": pl.UInt32, + "subject_id": pl.Int64, "time": pl.Datetime("us"), "code": pl.Utf8, "numeric_value": pl.Float32, From 0a61775f44825c9e7a1f63bcfff0f71484b3c8ea Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 08:48:44 -0400 Subject: [PATCH 02/23] Added tests to reshard stage --- src/MEDS_transforms/mapreduce/utils.py | 2 +- src/MEDS_transforms/reshard_to_split.py | 48 ++++++++++++++++++++++--- tests/test_reshard_to_split.py | 11 ++++++ tests/transform_tester_base.py | 4 +++ tests/utils.py | 14 +++++--- 5 files changed, 69 insertions(+), 10 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 300e4c3..9a7151f 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -483,7 +483,7 @@ def shard_iterator( shards = train_shards includes_only_train = True elif train_only: - logger.info( + logger.warning( f"train_only={train_only} requested but no dedicated train shards found; processing all shards " "and relying on `patient_splits.parquet` for filtering." ) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index f936196..a4a0276 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -19,7 +19,31 @@ def valid_json_file(fp: Path) -> bool: - """Check if a file is a valid JSON file.""" + """Check if a file is a valid JSON file. + + Args: + fp: Path to the file. + + Returns: + True if the file is a valid JSON file, False otherwise. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text("foobar not a json file.\tHello, world!") + ... valid_json_file(fp) + False + >>> with tempfile.NamedTemporaryFile(suffix=".json") as tmpfile: + ... fp = Path(tmpfile.name) + ... _ = fp.write_text('{"foo": "bar"}') + ... valid_json_file(fp) + True + """ if not fp.is_file(): return False try: @@ -30,6 +54,7 @@ def valid_json_file(fp: Path) -> bool: def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) -> dict[str, list[str]]: + """This function creates a new sharding scheme for the MEDS cohort.""" splits_map = defaultdict(list) for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) @@ -44,6 +69,20 @@ def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) def write_json(d: dict, fp: Path) -> None: + """Write a dictionary to a JSON file. + + Args: + d: Dictionary to write. + fp: Path to the file. + + Examples: + >>> import tempfile + >>> with tempfile.TemporaryDirectory() as tmpdir: + ... fp = Path(tmpdir) / "test.json" + ... write_json({"foo": "bar"}, fp) + ... 
fp.read_text() + '{"foo": "bar"}' + """ fp.write_text(json.dumps(d)) @@ -79,9 +118,10 @@ def main(cfg: DictConfig): new_sharded_splits = json.loads(shards_fp.read_text()) - orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") - if include_only_train: - raise ValueError("This stage does not support include_only_train=True") + if cfg.stage_cfg.get("train_only", False): + raise ValueError("This stage does not support train_only=True") + + orig_shards_iter, _ = shard_iterator(cfg, out_suffix="") orig_shards_iter = [(in_fp, out_fp.relative_to(output_dir)) for in_fp, out_fp in orig_shards_iter] diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 65056e5..3af7f19 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -200,3 +200,14 @@ def test_reshard_to_split(): input_shards_map=IN_SHARDS_MAP, input_splits_map=SPLITS, ) + + single_stage_transform_tester( + transform_script=RESHARD_TO_SPLIT_SCRIPT, + stage_name="reshard_to_split", + transform_stage_kwargs={"n_patients_per_shard": 2, "+train_only": True}, + want_data=WANT_SHARDS, + input_shards=IN_SHARDS, + input_shards_map=IN_SHARDS_MAP, + input_splits_map=SPLITS, + should_error=True, + ) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index bca36ad..ec061fc 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -404,6 +404,7 @@ def single_stage_transform_tester( want_data: dict[str, pl.DataFrame] | None = None, want_metadata: pl.DataFrame | None = None, assert_no_other_outputs: bool = True, + should_error: bool = False, **input_data_kwargs, ): with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): @@ -421,6 +422,7 @@ def single_stage_transform_tester( "script": transform_script, "hydra_kwargs": pipeline_config_kwargs, "test_name": f"Single stage transform: {stage_name}", + "should_error": should_error, } if do_use_config_yaml: run_command_kwargs["do_use_config_yaml"] = True @@ -431,6 +433,8 @@ def single_stage_transform_tester( # Run the transform stderr, stdout = run_command(**run_command_kwargs) + if should_error: + return try: check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) diff --git a/tests/utils.py b/tests/utils.py index e7220c9..f562bce 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -70,6 +70,8 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str: ValueError: Unexpected type for value for key a: : 2021-11-01 00:00:00 """ + modifier_chars = ["~", "'", "++", "+"] + out = [] for k, v in d.items(): if not isinstance(k, str): @@ -86,11 +88,13 @@ def dict_to_hydra_kwargs(d: dict[str, str]) -> str: case dict(): inner_kwargs = dict_to_hydra_kwargs(v) for inner_kv in inner_kwargs: - if inner_kv.startswith("~"): - out.append(f"~{k}.{inner_kv[1:]}") - elif inner_kv.startswith("'"): - out.append(f"'{k}.{inner_kv[1:]}") - else: + handled = False + for mod in modifier_chars: + if inner_kv.startswith(mod): + out.append(f"{mod}{k}.{inner_kv[len(mod):]}") + handled = True + break + if not handled: out.append(f"{k}.{inner_kv}") case list() | tuple(): v = list(v) From 19502ebe07b945bef5b0ce773ef2abe38e0a493e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 08:58:39 -0400 Subject: [PATCH 03/23] Added error case test for fitting vocabulary indices. 
--- tests/test_fit_vocabulary_indices.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_fit_vocabulary_indices.py b/tests/test_fit_vocabulary_indices.py index ce7c40a..c468050 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/test_fit_vocabulary_indices.py @@ -35,3 +35,11 @@ def test_fit_vocabulary_indices_with_default_stage_config(): transform_stage_kwargs=None, want_metadata=parse_code_metadata_csv(WANT_CSV), ) + + single_stage_transform_tester( + transform_script=FIT_VOCABULARY_INDICES_SCRIPT, + stage_name="fit_vocabulary_indices", + transform_stage_kwargs={"ordering_method": "file"}, + want_metadata=parse_code_metadata_csv(WANT_CSV), + should_error=True, + ) From 35500190d7f048a0d0bb26c26e48116759626c9d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:11:58 -0400 Subject: [PATCH 04/23] Added a bunch of no-cover lines for things that will get skipped in GitHub CI --- src/MEDS_transforms/__init__.py | 2 +- src/MEDS_transforms/aggregate_code_metadata.py | 2 +- src/MEDS_transforms/extract/convert_to_sharded_events.py | 2 +- src/MEDS_transforms/extract/extract_code_metadata.py | 2 +- src/MEDS_transforms/extract/finalize_MEDS_data.py | 2 +- src/MEDS_transforms/extract/finalize_MEDS_metadata.py | 2 +- src/MEDS_transforms/extract/merge_to_MEDS_cohort.py | 2 +- src/MEDS_transforms/extract/shard_events.py | 2 +- src/MEDS_transforms/extract/split_and_shard_patients.py | 2 +- src/MEDS_transforms/filters/filter_measurements.py | 2 +- src/MEDS_transforms/filters/filter_patients.py | 2 +- src/MEDS_transforms/fit_vocabulary_indices.py | 2 +- src/MEDS_transforms/reshard_to_split.py | 2 +- src/MEDS_transforms/transforms/add_time_derived_measurements.py | 2 +- src/MEDS_transforms/transforms/extract_values.py | 2 +- src/MEDS_transforms/transforms/normalization.py | 2 +- src/MEDS_transforms/transforms/occlude_outliers.py | 2 +- src/MEDS_transforms/transforms/reorder_measurements.py | 2 +- src/MEDS_transforms/transforms/tensorization.py | 2 +- src/MEDS_transforms/transforms/tokenization.py | 2 +- 20 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 38f2ac9..c5aba54 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -6,7 +6,7 @@ __package_name__ = "MEDS_transforms" try: __version__ = version(__package_name__) -except PackageNotFoundError: +except PackageNotFoundError: # pragma: no cover __version__ = "unknown" PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 1f6828b..ddf2a4a 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -730,5 +730,5 @@ def main(cfg: DictConfig): run_map_reduce(cfg) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index ee4e9d7..dc6a6b0 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -744,5 +744,5 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: logger.info("Subsharded into converted events.") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py 
b/src/MEDS_transforms/extract/extract_code_metadata.py index e9133eb..959d2ff 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -449,5 +449,5 @@ def reducer_fn(*dfs): logger.info(f"Finished reduction in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_data.py b/src/MEDS_transforms/extract/finalize_MEDS_data.py index f9d6873..7920e4c 100644 --- a/src/MEDS_transforms/extract/finalize_MEDS_data.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_data.py @@ -134,5 +134,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=get_and_validate_data_schema, write_fn=pq.write_table) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index 366d89a..0b215f1 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -218,5 +218,5 @@ def main(cfg: DictConfig): pq.write_table(patient_splits_tbl, patient_splits_fp) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 49a45e1..f7de081 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -242,5 +242,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/shard_events.py b/src/MEDS_transforms/extract/shard_events.py index 18450bb..88f8438 100755 --- a/src/MEDS_transforms/extract/shard_events.py +++ b/src/MEDS_transforms/extract/shard_events.py @@ -429,5 +429,5 @@ def main(cfg: DictConfig): logger.info(f"Sub-sharding completed in {datetime.now() - start}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index a385c73..0cee836 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -276,5 +276,5 @@ def main(cfg: DictConfig): logger.info("Done writing sharded patients") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 36a6938..9f30185 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -147,5 +147,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=filter_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_patients.py index 36dc398..c5630b2 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_patients.py @@ -233,5 +233,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=filter_patients_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/fit_vocabulary_indices.py 
b/src/MEDS_transforms/fit_vocabulary_indices.py index 0fb249b..b5e4327 100644 --- a/src/MEDS_transforms/fit_vocabulary_indices.py +++ b/src/MEDS_transforms/fit_vocabulary_indices.py @@ -236,5 +236,5 @@ def main(cfg: DictConfig): logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index a4a0276..d74ba86 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -165,5 +165,5 @@ def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index 01ec7f9..c0423c2 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -398,5 +398,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=add_time_derived_measurements_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/extract_values.py b/src/MEDS_transforms/transforms/extract_values.py index a1d42b6..f99eb6d 100644 --- a/src/MEDS_transforms/transforms/extract_values.py +++ b/src/MEDS_transforms/transforms/extract_values.py @@ -130,5 +130,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=extract_values_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index fbad9ac..ef43dda 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -220,5 +220,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=normalize) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index 107407d..f9095e2 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -110,5 +110,5 @@ def main(cfg: DictConfig): map_over(cfg, compute_fn=occlude_outliers_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/reorder_measurements.py b/src/MEDS_transforms/transforms/reorder_measurements.py index 1205f77..32f2857 100644 --- a/src/MEDS_transforms/transforms/reorder_measurements.py +++ b/src/MEDS_transforms/transforms/reorder_measurements.py @@ -184,5 +184,5 @@ def main(cfg: DictConfig): map_over(cfg, reorder_by_code_fntr) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 0266ce2..bfadb7a 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -115,5 +115,5 @@ def main(cfg: DictConfig): ) -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index d6f5003..8965cf1 100644 --- 
a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -265,5 +265,5 @@ def main(cfg: DictConfig): logger.info(f"Done with {cfg.stage}") -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover main() From a2de58e88499f6492c3992459ba939886e6f8767 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:13:51 -0400 Subject: [PATCH 05/23] Added an error case test for tokenization. --- src/MEDS_transforms/transforms/tokenization.py | 5 ++--- tests/test_tokenization.py | 9 +++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index 8965cf1..0910982 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -229,11 +229,10 @@ def main(cfg: DictConfig): ) output_dir = Path(cfg.stage_cfg.output_dir) + if train_only := cfg.stage_cfg.get("train_only", False): + raise ValueError(f"train_only={train_only} is not supported for this stage.") shards_single_output, include_only_train = shard_iterator(cfg) - if include_only_train: - raise ValueError("Not supported for this stage.") - for in_fp, out_fp in shards_single_output: sharded_path = out_fp.relative_to(output_dir) diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py index 693add1..0945c7c 100644 --- a/tests/test_tokenization.py +++ b/tests/test_tokenization.py @@ -225,3 +225,12 @@ def test_tokenization(): input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, ) + + single_stage_transform_tester( + transform_script=TOKENIZATION_SCRIPT, + stage_name="tokenization", + transform_stage_kwargs={"train_only": True}, + input_shards=NORMALIZED_SHARDS, + want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + should_error=True, + ) From 214d0e9ca5496ebf5758df8ad2705855554b20ba Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:25:17 -0400 Subject: [PATCH 06/23] Added tests to filter patients --- .../filters/filter_patients.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/MEDS_transforms/filters/filter_patients.py b/src/MEDS_transforms/filters/filter_patients.py index c5630b2..0682257 100644 --- a/src/MEDS_transforms/filters/filter_patients.py +++ b/src/MEDS_transforms/filters/filter_patients.py @@ -13,7 +13,7 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_patient: int) -> pl.LazyFrame: - """Filters patients by the number of measurements they have. + """Filters patients by the number of dynamic (timestamp non-null) measurements they have. Args: df: The input DataFrame. @@ -24,11 +24,11 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p Examples: >>> df = pl.DataFrame({ - ... "patient_id": [1, 1, 1, 2, 2, 3], - ... "time": [1, 2, 1, 1, 2, 1], + ... "patient_id": [1, 1, 1, 2, 2, 3, 3, 4], + ... "time": [1, 2, 1, 1, 2, 1, None, None], ... }) >>> filter_patients_by_num_measurements(df, 1) - shape: (6, 2) + shape: (7, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ │ --- ┆ --- │ @@ -40,6 +40,7 @@ def filter_patients_by_num_measurements(df: pl.LazyFrame, min_measurements_per_p │ 2 ┆ 1 │ │ 2 ┆ 2 │ │ 3 ┆ 1 │ + │ 3 ┆ null │ └────────────┴──────┘ >>> filter_patients_by_num_measurements(df, 2) shape: (5, 2) @@ -102,7 +103,8 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) ... 
"patient_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4], ... "time": [1, 1, 1, 1, 2, 1, 1, 2, 3, None, None, 1, 2, 3], ... }) - >>> filter_patients_by_num_events(df, 1) + >>> with pl.Config(tbl_rows=15): + ... filter_patients_by_num_events(df, 1) shape: (14, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ @@ -124,7 +126,8 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) │ 4 ┆ 2 │ │ 4 ┆ 3 │ └────────────┴──────┘ - >>> filter_patients_by_num_events(df, 2) + >>> with pl.Config(tbl_rows=15): + ... filter_patients_by_num_events(df, 2) shape: (11, 2) ┌────────────┬──────┐ │ patient_id ┆ time │ @@ -195,6 +198,36 @@ def filter_patients_by_num_events(df: pl.LazyFrame, min_events_per_patient: int) def filter_patients_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Returns a function that filters patients by the number of measurements and events they have. + + Args: + stage_cfg: The stage configuration. Arguments include: min_measurements_per_patient, + min_events_per_patient, both of which should be integers or None which specify the minimum number + of measurements and events a patient must have to be included, respectively. + + Returns: + The function that filters patients by the number of measurements and/or events they have. + + Examples: + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + ... "time": [1, 1, 1, 1, 1, 1, 2, 3, None, None, 1, 2, 2, None, 1, 2, 3, 1], + ... }) + >>> stage_cfg = DictConfig({"min_measurements_per_patient": 4, "min_events_per_patient": 2}) + >>> filter_patients_fntr(stage_cfg)(df) + shape: (4, 2) + ┌────────────┬──────┐ + │ patient_id ┆ time │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞════════════╪══════╡ + │ 5 ┆ 1 │ + │ 5 ┆ 2 │ + │ 5 ┆ 3 │ + │ 5 ┆ 1 │ + └────────────┴──────┘ + """ + compute_fns = [] if stage_cfg.min_measurements_per_patient: logger.info( From 7d58386a408b4cb04efde4b39dd8021851d1191d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:27:19 -0400 Subject: [PATCH 07/23] Added tests to filter measurements --- .../filters/filter_measurements.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index 9f30185..2452956 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -97,6 +97,32 @@ def filter_measurements_fntr( ╞════════════╪══════╪═══════════╡ │ 2 ┆ A ┆ 2 │ └────────────┴──────┴───────────┘ + + This stage works even if the default row index column exists: + >>> code_metadata_df = pl.DataFrame({ + ... "code": ["A", "A", "B", "C"], + ... "modifier1": [1, 2, 1, 2], + ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_occurrences": [4, 5, 3, 2], + ... }) + >>> data = pl.DataFrame({ + ... "patient_id": [1, 1, 2, 2], + ... "code": ["A", "B", "A", "C"], + ... "modifier1": [1, 1, 2, 2], + ... "_row_idx": [1, 1, 1, 1], + ... 
}).lazy() + >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (2, 4) + ┌────────────┬──────┬───────────┬──────────┐ + │ patient_id ┆ code ┆ modifier1 ┆ _row_idx │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ i64 ┆ i64 │ + ╞════════════╪══════╪═══════════╪══════════╡ + │ 1 ┆ A ┆ 1 ┆ 1 │ + │ 1 ┆ B ┆ 1 ┆ 1 │ + └────────────┴──────┴───────────┴──────────┘ """ min_patients_per_code = stage_cfg.get("min_patients_per_code", None) From fde006713cc845404a37294f04ca89d36f1ec64f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 09:57:53 -0400 Subject: [PATCH 08/23] Corrected typo in filter measurements tests. --- src/MEDS_transforms/filters/filter_measurements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/MEDS_transforms/filters/filter_measurements.py b/src/MEDS_transforms/filters/filter_measurements.py index b0d9021..979dc0a 100644 --- a/src/MEDS_transforms/filters/filter_measurements.py +++ b/src/MEDS_transforms/filters/filter_measurements.py @@ -102,21 +102,21 @@ def filter_measurements_fntr( >>> code_metadata_df = pl.DataFrame({ ... "code": ["A", "A", "B", "C"], ... "modifier1": [1, 2, 1, 2], - ... "code/n_patients": [2, 1, 3, 2], + ... "code/n_subjects": [2, 1, 3, 2], ... "code/n_occurrences": [4, 5, 3, 2], ... }) >>> data = pl.DataFrame({ - ... "patient_id": [1, 1, 2, 2], + ... "subject_id": [1, 1, 2, 2], ... "code": ["A", "B", "A", "C"], ... "modifier1": [1, 1, 2, 2], ... "_row_idx": [1, 1, 1, 1], ... }).lazy() - >>> stage_cfg = DictConfig({"min_patients_per_code": 2, "min_occurrences_per_code": 3}) + >>> stage_cfg = DictConfig({"min_subjects_per_code": 2, "min_occurrences_per_code": 3}) >>> fn = filter_measurements_fntr(stage_cfg, code_metadata_df, ["modifier1"]) >>> fn(data).collect() shape: (2, 4) ┌────────────┬──────┬───────────┬──────────┐ - │ patient_id ┆ code ┆ modifier1 ┆ _row_idx │ + │ subject_id ┆ code ┆ modifier1 ┆ _row_idx │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ str ┆ i64 ┆ i64 │ ╞════════════╪══════╪═══════════╪══════════╡ From 6bfae894c1d46a1f65f6953a68c31c9bef97b0eb Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 15:07:17 -0400 Subject: [PATCH 09/23] Re-organized tests. 
---
 tests/MEDS_Extract/__init__.py                         |  0
 tests/{ => MEDS_Extract}/test_extract.py               |  2 +-
 tests/{ => MEDS_Extract}/test_extract_no_metadata.py   |  2 +-
 tests/MEDS_Transforms/__init__.py                      |  0
 .../test_add_time_derived_measurements.py              | 10 ++++++++--
 .../test_aggregate_code_metadata.py                    |  5 ++++-
 tests/{ => MEDS_Transforms}/test_extract_values.py     | 10 +++++++++-
 .../{ => MEDS_Transforms}/test_filter_measurements.py  | 10 ++++++++--
 tests/{ => MEDS_Transforms}/test_filter_subjects.py    |  7 +++++--
 .../test_fit_vocabulary_indices.py                     |  6 +++++-
 .../test_multi_stage_preprocess_pipeline.py            |  6 +++++-
 tests/{ => MEDS_Transforms}/test_normalization.py      |  7 +++++--
 tests/{ => MEDS_Transforms}/test_occlude_outliers.py   |  7 +++++--
 .../{ => MEDS_Transforms}/test_reorder_measurements.py | 10 ++++++++--
 tests/{ => MEDS_Transforms}/test_reshard_to_split.py   |  7 +++++--
 tests/{ => MEDS_Transforms}/test_tensorization.py      |  7 +++++--
 tests/{ => MEDS_Transforms}/test_tokenization.py       |  0
 tests/{ => MEDS_Transforms}/transform_tester_base.py   |  2 +-
 18 files changed, 75 insertions(+), 23 deletions(-)
 create mode 100644 tests/MEDS_Extract/__init__.py
 rename tests/{ => MEDS_Extract}/test_extract.py (99%)
 rename tests/{ => MEDS_Extract}/test_extract_no_metadata.py (99%)
 create mode 100644 tests/MEDS_Transforms/__init__.py
 rename tests/{ => MEDS_Transforms}/test_add_time_derived_measurements.py (96%)
 rename tests/{ => MEDS_Transforms}/test_aggregate_code_metadata.py (97%)
 rename tests/{ => MEDS_Transforms}/test_extract_values.py (94%)
 rename tests/{ => MEDS_Transforms}/test_filter_measurements.py (96%)
 rename tests/{ => MEDS_Transforms}/test_filter_subjects.py (91%)
 rename tests/{ => MEDS_Transforms}/test_fit_vocabulary_indices.py (91%)
 rename tests/{ => MEDS_Transforms}/test_multi_stage_preprocess_pipeline.py (99%)
 rename tests/{ => MEDS_Transforms}/test_normalization.py (96%)
 rename tests/{ => MEDS_Transforms}/test_occlude_outliers.py (95%)
 rename tests/{ => MEDS_Transforms}/test_reorder_measurements.py (92%)
 rename tests/{ => MEDS_Transforms}/test_reshard_to_split.py (96%)
 rename tests/{ => MEDS_Transforms}/test_tensorization.py (73%)
 rename tests/{ => MEDS_Transforms}/test_tokenization.py (100%)
 rename tests/{ => MEDS_Transforms}/transform_tester_base.py (99%)

diff --git a/tests/MEDS_Extract/__init__.py b/tests/MEDS_Extract/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_extract.py b/tests/MEDS_Extract/test_extract.py
similarity index 99%
rename from tests/test_extract.py
rename to tests/MEDS_Extract/test_extract.py
index 787d5d8..b1b50a3 100644
--- a/tests/test_extract.py
+++ b/tests/MEDS_Extract/test_extract.py
@@ -38,7 +38,7 @@
 import polars as pl
 from meds import __version__ as MEDS_VERSION

-from .utils import assert_df_equal, run_command
+from tests.utils import assert_df_equal, run_command

 # Test data (inputs)
diff --git a/tests/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py
similarity index 99%
rename from tests/test_extract_no_metadata.py
rename to tests/MEDS_Extract/test_extract_no_metadata.py
index 2391a97..ed783f4 100644
--- a/tests/test_extract_no_metadata.py
+++ b/tests/MEDS_Extract/test_extract_no_metadata.py
@@ -38,7 +38,7 @@
 import polars as pl
 from meds import __version__ as MEDS_VERSION

-from .utils import assert_df_equal, run_command
+from tests.utils import assert_df_equal, run_command

 # Test data (inputs)
diff --git a/tests/MEDS_Transforms/__init__.py b/tests/MEDS_Transforms/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git
a/tests/test_add_time_derived_measurements.py b/tests/MEDS_Transforms/test_add_time_derived_measurements.py similarity index 96% rename from tests/test_add_time_derived_measurements.py rename to tests/MEDS_Transforms/test_add_time_derived_measurements.py index 964cad9..ed7bbba 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/MEDS_Transforms/test_add_time_derived_measurements.py @@ -3,11 +3,17 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import ( + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs AGE_CALCULATION_STR = """ See `add_time_derived_measurements.py` for the source of the constant value. diff --git a/tests/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py similarity index 97% rename from tests/test_aggregate_code_metadata.py rename to tests/MEDS_Transforms/test_aggregate_code_metadata.py index 7d3d2a4..c62bb94 100644 --- a/tests/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -3,10 +3,13 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from .transform_tester_base import ( +from tests.MEDS_Transforms.transform_tester_base import ( AGGREGATE_CODE_METADATA_SCRIPT, MEDS_CODE_METADATA_SCHEMA, single_stage_transform_tester, diff --git a/tests/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py similarity index 94% rename from tests/test_extract_values.py rename to tests/MEDS_Transforms/test_extract_values.py index 0368e5b..d273c99 100644 --- a/tests/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -4,7 +4,15 @@ scripts. """ -from .transform_tester_base import EXTRACT_VALUES_SCRIPT, parse_shards_yaml, single_stage_transform_tester +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import ( + EXTRACT_VALUES_SCRIPT, + parse_shards_yaml, + single_stage_transform_tester, +) INPUT_SHARDS = parse_shards_yaml( """ diff --git a/tests/test_filter_measurements.py b/tests/MEDS_Transforms/test_filter_measurements.py similarity index 96% rename from tests/test_filter_measurements.py rename to tests/MEDS_Transforms/test_filter_measurements.py index 3a34835..a3e53d9 100644 --- a/tests/test_filter_measurements.py +++ b/tests/MEDS_Transforms/test_filter_measurements.py @@ -3,10 +3,16 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. 
""" +import rootutils +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) -from .transform_tester_base import FILTER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs + +from tests.MEDS_Transforms.transform_tester_base import ( + FILTER_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ diff --git a/tests/test_filter_subjects.py b/tests/MEDS_Transforms/test_filter_subjects.py similarity index 91% rename from tests/test_filter_subjects.py rename to tests/MEDS_Transforms/test_filter_subjects.py index 1defee4..4d4f2ca 100644 --- a/tests/test_filter_subjects.py +++ b/tests/MEDS_Transforms/test_filter_subjects.py @@ -4,10 +4,13 @@ scripts. """ +import rootutils from meds import subject_id_field -from .transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester +from tests.utils import parse_meds_csvs WANT_TRAIN_0 = f""" {subject_id_field},time,code,numeric_value diff --git a/tests/test_fit_vocabulary_indices.py b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py similarity index 91% rename from tests/test_fit_vocabulary_indices.py rename to tests/MEDS_Transforms/test_fit_vocabulary_indices.py index 607b41e..ea6c1c5 100644 --- a/tests/test_fit_vocabulary_indices.py +++ b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py @@ -4,8 +4,12 @@ scripts. """ +import rootutils -from .transform_tester_base import ( +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + + +from tests.MEDS_Transforms.transform_tester_base import ( FIT_VOCABULARY_INDICES_SCRIPT, parse_code_metadata_csv, single_stage_transform_tester, diff --git a/tests/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py similarity index 99% rename from tests/test_multi_stage_preprocess_pipeline.py rename to tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 32c4a27..15c4b96 100644 --- a/tests/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -17,13 +17,17 @@ The stage configuration arguments will be as given in the yaml block below: """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + from datetime import datetime import polars as pl from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .transform_tester_base import ( +from tests.MEDS_Transforms.transform_tester_base import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, FILTER_SUBJECTS_SCRIPT, diff --git a/tests/test_normalization.py b/tests/MEDS_Transforms/test_normalization.py similarity index 96% rename from tests/test_normalization.py rename to tests/MEDS_Transforms/test_normalization.py index 14207c4..b6f386f 100644 --- a/tests/test_normalization.py +++ b/tests/MEDS_Transforms/test_normalization.py @@ -5,9 +5,12 @@ """ import polars as pl +import rootutils -from .transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from 
tests.MEDS_Transforms.transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata file we'll use in this transform test. It is different than the default as we need # a code/vocab_index diff --git a/tests/test_occlude_outliers.py b/tests/MEDS_Transforms/test_occlude_outliers.py similarity index 95% rename from tests/test_occlude_outliers.py rename to tests/MEDS_Transforms/test_occlude_outliers.py index f13a4fa..ad3d321 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/MEDS_Transforms/test_occlude_outliers.py @@ -4,11 +4,14 @@ scripts. """ +import rootutils + +rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from .transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester -from .utils import MEDS_PL_SCHEMA, parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester +from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata # MEDS_CODE_METADATA_CSV = """ diff --git a/tests/test_reorder_measurements.py b/tests/MEDS_Transforms/test_reorder_measurements.py similarity index 92% rename from tests/test_reorder_measurements.py rename to tests/MEDS_Transforms/test_reorder_measurements.py index 7cc7aaa..c4a2a54 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/MEDS_Transforms/test_reorder_measurements.py @@ -4,9 +4,15 @@ scripts. """ +import rootutils -from .transform_tester_base import REORDER_MEASUREMENTS_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +from tests.MEDS_Transforms.transform_tester_base import ( + REORDER_MEASUREMENTS_SCRIPT, + single_stage_transform_tester, +) +from tests.utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ "ADMISSION.*", diff --git a/tests/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py similarity index 96% rename from tests/test_reshard_to_split.py rename to tests/MEDS_Transforms/test_reshard_to_split.py index b479e5a..19008bc 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -3,12 +3,15 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester -from .utils import parse_meds_csvs +from tests.MEDS_Transforms.transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester +from tests.utils import parse_meds_csvs IN_SHARDS_MAP = { "0": [68729, 1195293], diff --git a/tests/test_tensorization.py b/tests/MEDS_Transforms/test_tensorization.py similarity index 73% rename from tests/test_tensorization.py rename to tests/MEDS_Transforms/test_tensorization.py index 0337155..b648e6e 100644 --- a/tests/test_tensorization.py +++ b/tests/MEDS_Transforms/test_tensorization.py @@ -6,11 +6,14 @@ scripts. 
""" +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS -from .transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms.test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS +from tests.MEDS_Transforms.transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester WANT_NRTS = { f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( diff --git a/tests/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py similarity index 100% rename from tests/test_tokenization.py rename to tests/MEDS_Transforms/test_tokenization.py diff --git a/tests/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py similarity index 99% rename from tests/transform_tester_base.py rename to tests/MEDS_Transforms/transform_tester_base.py index 21fc231..9b1184f 100644 --- a/tests/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -25,7 +25,7 @@ from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from .utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command +from tests.utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) From 26665c9f686ec8e45b211dd1ed1f82ceba782e68 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:21:17 -0400 Subject: [PATCH 10/23] Moved single and multi-stage helpers out to be general in preparation for extraction test refactoring. 
--- .../test_aggregate_code_metadata.py | 1 + .../test_multi_stage_preprocess_pipeline.py | 1 - .../MEDS_Transforms/transform_tester_base.py | 328 ++++-------------- tests/utils.py | 264 ++++++++++++++ 4 files changed, 334 insertions(+), 260 deletions(-) diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index c62bb94..48ff79b 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -186,4 +186,5 @@ def test_aggregate_code_metadata(): want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, + assert_no_other_outputs=False, ) diff --git a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 15c4b96..0deade2 100644 --- a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -1089,6 +1089,5 @@ def test_pipeline(): **WANT_TOKENIZATION_EVENT_SEQS, **WANT_NRTs, }, - outputs_from_cohort_dir=True, input_code_metadata=MEDS_CODE_METADATA, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 9b1184f..1e692e4 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -11,21 +11,16 @@ except ImportError: from yaml import Loader -import json import os -import tempfile from collections import defaultdict -from contextlib import contextmanager from io import StringIO from pathlib import Path -import numpy as np import polars as pl import rootutils from meds import subject_id_field -from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from tests.utils import MEDS_PL_SCHEMA, assert_df_equal, parse_meds_csvs, run_command +from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) @@ -223,177 +218,45 @@ def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: MEDS_CODE_METADATA = parse_code_metadata_csv(MEDS_CODE_METADATA_CSV) -def check_NRT_output( - output_fp: Path, - want_nrt: JointNestedRaggedTensorDict, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." 
- - got_nrt = JointNestedRaggedTensorDict.load(output_fp) - - # assert got_nrt.schema == want_nrt.schema, ( - # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" - # f"Wanted:\n{want_nrt.schema}\n" - # f"Got:\n{got_nrt.schema}" - # ) - - want_tensors = want_nrt.tensors - got_tensors = got_nrt.tensors - - assert got_tensors.keys() == want_tensors.keys(), ( - f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{list(want_tensors.keys())}\n" - f"Got:\n{list(got_tensors.keys())}" - ) - - for k in want_tensors.keys(): - want_v = want_tensors[k] - got_v = got_tensors[k] - - assert type(want_v) is type(got_v), ( - f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Wanted:\n{type(want_v)}\n" - f"Got:\n{type(got_v)}" - ) - - if isinstance(want_v, list): - assert len(want_v) == len(got_v), ( - f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" - f"Wanted:\n{len(want_v)}\n" - f"Got:\n{len(got_v)}" - ) - for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): - assert np.array_equal(want_i, got_i, equal_nan=True), ( - f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_i}\n" - f"Got:\n{got_i}" - ) - else: - assert np.array_equal(want_v, got_v, equal_nan=True), ( - f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{want_v}\n" - f"Got:\n{got_v}" - ) - - -def check_df_output( - output_fp: Path, - want_df: pl.DataFrame, - check_column_order: bool = False, - check_row_order: bool = True, - **kwargs, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." - - got_df = pl.read_parquet(output_fp, glob=False) - assert_df_equal( - want_df, - got_df, - (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), - check_column_order=check_column_order, - check_row_order=check_row_order, - **kwargs, - ) - - -@contextmanager -def input_MEDS_dataset( +def remap_inputs_for_transform( input_code_metadata: pl.DataFrame | str | None = None, input_shards: dict[str, pl.DataFrame] | None = None, input_shards_map: dict[str, list[int]] | None = None, input_splits_map: dict[str, list[int]] | None = None, -): - with tempfile.TemporaryDirectory() as d: - MEDS_dir = Path(d) / "MEDS_cohort" - cohort_dir = Path(d) / "output_cohort" - - MEDS_data_dir = MEDS_dir / "data" - MEDS_metadata_dir = MEDS_dir / "metadata" - - # Create the directories - MEDS_data_dir.mkdir(parents=True) - MEDS_metadata_dir.mkdir(parents=True) - cohort_dir.mkdir(parents=True) - - # Write the shards map - if input_shards_map is None: - input_shards_map = SHARDS - - shards_fp = MEDS_metadata_dir / ".shards.json" - shards_fp.write_text(json.dumps(input_shards_map)) - - # Write the splits parquet file - if input_splits_map is None: - input_splits_map = SPLITS - input_splits_as_df = defaultdict(list) - for split_name, subject_ids in input_splits_map.items(): - input_splits_as_df[subject_id_field].extend(subject_ids) - input_splits_as_df["split"].extend([split_name] * len(subject_ids)) - input_splits_df = pl.DataFrame(input_splits_as_df) - input_splits_fp = MEDS_metadata_dir / "subject_splits.parquet" - input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) - - if input_shards is None: - input_shards = MEDS_SHARDS - - # Write the shards - for shard_name, df in input_shards.items(): - fp = MEDS_data_dir / f"{shard_name}.parquet" - fp.parent.mkdir(parents=True, exist_ok=True) - df.write_parquet(fp, 
use_pyarrow=True) - - code_metadata_fp = MEDS_metadata_dir / "codes.parquet" - if input_code_metadata is None: - input_code_metadata = MEDS_CODE_METADATA - elif isinstance(input_code_metadata, str): - input_code_metadata = parse_code_metadata_csv(input_code_metadata) - input_code_metadata.write_parquet(code_metadata_fp, use_pyarrow=True) - - yield MEDS_dir, cohort_dir - - -def check_outputs( - cohort_dir: Path, - want_data: dict[str, pl.DataFrame] | None = None, - want_metadata: dict[str, pl.DataFrame] | pl.DataFrame | None = None, - assert_no_other_outputs: bool = True, - outputs_from_cohort_dir: bool = False, -): - if want_metadata is not None: - if isinstance(want_metadata, pl.DataFrame): - want_metadata = {"codes.parquet": want_metadata} - metadata_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "metadata" - for shard_name, want in want_metadata.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - check_df_output(metadata_root / shard_name, want) +) -> dict[str, FILE_T]: + unified_inputs = {} - if want_data: - data_root = cohort_dir if outputs_from_cohort_dir else cohort_dir / "data" - all_file_suffixes = set() - for shard_name, want in want_data.items(): - if Path(shard_name).suffix == "": - shard_name = f"{shard_name}.parquet" - - file_suffix = Path(shard_name).suffix - all_file_suffixes.add(file_suffix) - - output_fp = data_root / f"{shard_name}" - if file_suffix == ".parquet": - check_df_output(output_fp, want) - elif file_suffix == ".nrt": - check_NRT_output(output_fp, want) - else: - raise ValueError(f"Unknown file suffix: {file_suffix}") - - if assert_no_other_outputs: - all_outputs = [] - for suffix in all_file_suffixes: - all_outputs.extend(list((data_root).glob(f"**/*{suffix}"))) - assert len(want_data) == len(all_outputs), ( - f"Want {len(want_data)} outputs, but found {len(all_outputs)}.\n" - f"Found outputs: {[fp.relative_to(data_root) for fp in all_outputs]}\n" - ) + if input_code_metadata is None: + input_code_metadata = MEDS_CODE_METADATA + elif isinstance(input_code_metadata, str): + input_code_metadata = parse_code_metadata_csv(input_code_metadata) + + unified_inputs["metadata/codes.parquet"] = input_code_metadata + + if input_shards is None: + input_shards = MEDS_SHARDS + + for shard_name, df in input_shards.items(): + unified_inputs[f"data/{shard_name}.parquet"] = df + + if input_shards_map is None: + input_shards_map = SHARDS + + unified_inputs["metadata/.shards.json"] = input_shards_map + + if input_splits_map is None: + input_splits_map = SPLITS + + input_splits_as_df = defaultdict(list) + for split_name, subject_ids in input_splits_map.items(): + input_splits_as_df[subject_id_field].extend(subject_ids) + input_splits_as_df["split"].extend([split_name] * len(subject_ids)) + + input_splits_df = pl.DataFrame(input_splits_as_df) + + unified_inputs["metadata/subject_splits.parquet"] = input_splits_df + + return unified_inputs def single_stage_transform_tester( @@ -408,43 +271,28 @@ def single_stage_transform_tester( should_error: bool = False, **input_data_kwargs, ): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": [stage_name], - "hydra.verbose": True, - } - - if transform_stage_kwargs: - pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs} - - run_command_kwargs = { - "script": transform_script, - "hydra_kwargs": pipeline_config_kwargs, - 
"test_name": f"Single stage transform: {stage_name}", - "should_error": should_error, - } - if do_use_config_yaml: - run_command_kwargs["do_use_config_yaml"] = True - run_command_kwargs["config_name"] = "preprocess" - if do_pass_stage_name: - run_command_kwargs["stage"] = stage_name - run_command_kwargs["do_pass_stage_name"] = True - - # Run the transform - stderr, stdout = run_command(**run_command_kwargs) - if should_error: - return - - try: - check_outputs(cohort_dir, want_data=want_data, want_metadata=want_metadata) - except Exception as e: - raise AssertionError( - f"Single stage transform {stage_name} failed.\n" - f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" - ) from e + base_kwargs = { + "script": transform_script, + "stage_name": stage_name, + "stage_kwargs": transform_stage_kwargs, + "do_pass_stage_name": do_pass_stage_name, + "do_use_config_yaml": do_use_config_yaml, + "assert_no_other_outputs": assert_no_other_outputs, + "should_error": should_error, + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + } + + want_outputs = {} + if want_data: + for data_fn, want in want_data.items(): + want_outputs[f"data/{data_fn}"] = want + if want_metadata is not None: + want_outputs["metadata/codes.parquet"] = want_metadata + + base_kwargs["want_outputs"] = want_outputs + + single_stage_tester(**base_kwargs) def multi_stage_transform_tester( @@ -454,55 +302,17 @@ def multi_stage_transform_tester( do_pass_stage_name: bool | dict[str, bool] = True, want_data: dict[str, pl.DataFrame] | None = None, want_metadata: pl.DataFrame | None = None, - outputs_from_cohort_dir: bool = True, **input_data_kwargs, ): - with input_MEDS_dataset(**input_data_kwargs) as (MEDS_dir, cohort_dir): - match stage_configs: - case None: - stage_configs = {} - case str(): - stage_configs = load_yaml(stage_configs, Loader=Loader) - case dict(): - pass - case _: - raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") - - match do_pass_stage_name: - case True: - do_pass_stage_name = {stage_name: True for stage_name in stage_names} - case False: - do_pass_stage_name = {stage_name: False for stage_name in stage_names} - case dict(): - pass - case _: - raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") - - pipeline_config_kwargs = { - "input_dir": str(MEDS_dir.resolve()), - "cohort_dir": str(cohort_dir.resolve()), - "stages": stage_names, - "stage_configs": stage_configs, - "hydra.verbose": True, - } - - script_outputs = {} - n_stages = len(stage_names) - for i, (stage, script) in enumerate(zip(stage_names, transform_scripts)): - script_outputs[stage] = run_command( - script=script, - hydra_kwargs=pipeline_config_kwargs, - do_use_config_yaml=True, - config_name="preprocess", - test_name=f"Multi stage transform {i}/{n_stages}: {stage}", - stage_name=stage, - do_pass_stage_name=do_pass_stage_name[stage], - ) - - check_outputs( - cohort_dir, - want_data=want_data, - want_metadata=want_metadata, - outputs_from_cohort_dir=outputs_from_cohort_dir, - assert_no_other_outputs=False, # this currently doesn't work due to metadata / data confusions. 
- ) + base_kwargs = { + "scripts": transform_scripts, + "stage_names": stage_names, + "stage_configs": stage_configs, + "do_pass_stage_name": do_pass_stage_name, + "assert_no_other_outputs": False, # TODO(mmd): eventually fix + "config_name": "preprocess", + "input_files": remap_inputs_for_transform(**input_data_kwargs), + "want_outputs": {**want_data, **want_metadata}, + } + + multi_stage_tester(**base_kwargs) diff --git a/tests/utils.py b/tests/utils.py index 7cb7c18..c258513 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,11 +1,22 @@ +import json import subprocess import tempfile +from contextlib import contextmanager from io import StringIO from pathlib import Path +from typing import Any +import numpy as np import polars as pl +from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from omegaconf import OmegaConf from polars.testing import assert_frame_equal +from yaml import load as load_yaml + +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S" @@ -192,3 +203,256 @@ def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kw print("got:") print(got) raise AssertionError(f"{msg}\n{e}") from e + + +def check_NRT_output( + output_fp: Path, + want_nrt: JointNestedRaggedTensorDict, +): + assert output_fp.is_file(), f"Expected {output_fp} to exist." + + got_nrt = JointNestedRaggedTensorDict.load(output_fp) + + # assert got_nrt.schema == want_nrt.schema, ( + # f"Expected the schema of the NRT at {output_fp} to be equal to the target.\n" + # f"Wanted:\n{want_nrt.schema}\n" + # f"Got:\n{got_nrt.schema}" + # ) + + want_tensors = want_nrt.tensors + got_tensors = got_nrt.tensors + + assert got_tensors.keys() == want_tensors.keys(), ( + f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{list(want_tensors.keys())}\n" + f"Got:\n{list(got_tensors.keys())}" + ) + + for k in want_tensors.keys(): + want_v = want_tensors[k] + got_v = got_tensors[k] + + assert type(want_v) is type(got_v), ( + f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" + f"Wanted:\n{type(want_v)}\n" + f"Got:\n{type(got_v)}" + ) + + if isinstance(want_v, list): + assert len(want_v) == len(got_v), ( + f"Expected list {k} of the NRT at {output_fp} to be of the same length as the target.\n" + f"Wanted:\n{len(want_v)}\n" + f"Got:\n{len(got_v)}" + ) + for i, (want_i, got_i) in enumerate(zip(want_v, got_v)): + assert np.array_equal(want_i, got_i, equal_nan=True), ( + f"Expected tensor {k}[{i}] of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want_i}\n" + f"Got:\n{got_i}" + ) + else: + assert np.array_equal(want_v, got_v, equal_nan=True), ( + f"Expected tensor {k} of the NRT at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want_v}\n" + f"Got:\n{got_v}" + ) + + +def check_df_output( + output_fp: Path, + want_df: pl.DataFrame, + check_column_order: bool = False, + check_row_order: bool = True, + **kwargs, +): + assert output_fp.is_file(), f"Expected {output_fp} to exist." 
+ + got_df = pl.read_parquet(output_fp, glob=False) + assert_df_equal( + want_df, + got_df, + (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), + check_column_order=check_column_order, + check_row_order=check_row_order, + **kwargs, + ) + + +FILE_T = pl.DataFrame | dict[str, Any] + + +@contextmanager +def input_dataset(input_files: dict[str, FILE_T] | None = None): + with tempfile.TemporaryDirectory() as d: + input_dir = Path(d) / "input_cohort" + cohort_dir = Path(d) / "output_cohort" + + for filename, data in input_files.items(): + fp = input_dir / filename + fp.parent.mkdir(parents=True, exist_ok=True) + + match data: + case pl.DataFrame() if fp.suffix == "": + data.write_parquet(fp.with_suffix(".parquet"), use_pyarrow=True) + case pl.DataFrame() if fp.suffix == ".parquet": + data.write_parquet(fp, use_pyarrow=True) + case dict() if fp.suffix == "": + fp.with_suffix(".json").write_text(json.dumps(data)) + case dict() if fp.suffix.endswith(".json"): + fp.write_text(json.dumps(data)) + case _: + raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}") + + yield input_dir, cohort_dir + + +def check_outputs( + cohort_dir: Path, + want_outputs: dict[str, pl.DataFrame], + assert_no_other_outputs: bool = True, +): + all_file_suffixes = set() + + for output_name, want in want_outputs.items(): + if Path(output_name).suffix == "": + output_name = f"{output_name}.parquet" + + file_suffix = Path(output_name).suffix + all_file_suffixes.add(file_suffix) + + output_fp = cohort_dir / output_name + + if not output_fp.is_file(): + raise AssertionError(f"Expected {output_fp} to exist.") + + match file_suffix: + case ".parquet": + check_df_output(output_fp, want) + case ".nrt": + check_NRT_output(output_fp, want) + case _: + raise ValueError(f"Unknown file suffix: {file_suffix}") + + if assert_no_other_outputs: + all_outputs = [] + for suffix in all_file_suffixes: + all_outputs.extend(list(cohort_dir.glob(f"**/*{suffix}"))) + assert len(want_outputs) == len(all_outputs), ( + f"Want {len(want_outputs)} outputs, but found {len(all_outputs)}.\n" + f"Found outputs: {[fp.relative_to(cohort_dir) for fp in all_outputs]}\n" + ) + + +def single_stage_tester( + script: str | Path, + stage_name: str, + stage_kwargs: dict[str, str] | None, + do_pass_stage_name: bool = False, + do_use_config_yaml: bool = False, + want_outputs: dict[str, pl.DataFrame] | None = None, + assert_no_other_outputs: bool = True, + should_error: bool = False, + config_name: str = "preprocess", + input_files: dict[str, FILE_T] | None = None, +): + with input_dataset(input_files) as (input_dir, cohort_dir): + pipeline_config_kwargs = { + "input_dir": str(input_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": [stage_name], + "hydra.verbose": True, + } + + if stage_kwargs: + pipeline_config_kwargs["stage_configs"] = {stage_name: stage_kwargs} + + run_command_kwargs = { + "script": script, + "hydra_kwargs": pipeline_config_kwargs, + "test_name": f"Single stage transform: {stage_name}", + "should_error": should_error, + "config_name": config_name, + } + if do_use_config_yaml: + run_command_kwargs["do_use_config_yaml"] = True + + if do_pass_stage_name: + run_command_kwargs["stage"] = stage_name + run_command_kwargs["do_pass_stage_name"] = True + + # Run the transform + stderr, stdout = run_command(**run_command_kwargs) + if should_error: + return + + try: + check_outputs( + cohort_dir, want_outputs=want_outputs, assert_no_other_outputs=assert_no_other_outputs + ) + except 
Exception as e: + raise AssertionError( + f"Single stage transform {stage_name} failed.\n" + f"Script stdout:\n{stdout}\n" + f"Script stderr:\n{stderr}" + ) from e + + +def multi_stage_tester( + scripts: list[str | Path], + stage_names: list[str], + stage_configs: dict[str, str] | str | None, + do_pass_stage_name: bool | dict[str, bool] = True, + want_outputs: dict[str, pl.DataFrame] | None = None, + assert_no_other_outputs: bool = False, + config_name: str = "preprocess", + input_files: dict[str, FILE_T] | None = None, + **pipeline_kwargs, +): + with input_dataset(input_files) as (input_dir, cohort_dir): + match stage_configs: + case None: + stage_configs = {} + case str(): + stage_configs = load_yaml(stage_configs, Loader=Loader) + case dict(): + pass + case _: + raise ValueError(f"Unknown stage_configs type: {type(stage_configs)}") + + match do_pass_stage_name: + case True: + do_pass_stage_name = {stage_name: True for stage_name in stage_names} + case False: + do_pass_stage_name = {stage_name: False for stage_name in stage_names} + case dict(): + pass + case _: + raise ValueError(f"Unknown do_pass_stage_name type: {type(do_pass_stage_name)}") + + pipeline_config_kwargs = { + "input_dir": str(input_dir.resolve()), + "cohort_dir": str(cohort_dir.resolve()), + "stages": stage_names, + "stage_configs": stage_configs, + "hydra.verbose": True, + **pipeline_kwargs, + } + + script_outputs = {} + n_stages = len(stage_names) + for i, (stage, script) in enumerate(zip(stage_names, scripts)): + script_outputs[stage] = run_command( + script=script, + hydra_kwargs=pipeline_config_kwargs, + do_use_config_yaml=True, + config_name=config_name, + test_name=f"Multi stage transform {i}/{n_stages}: {stage}", + stage_name=stage, + do_pass_stage_name=do_pass_stage_name[stage], + ) + + check_outputs( + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + ) From e111ef64ace283a2f07af69a6c7ca8d27a289581 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:39:15 -0400 Subject: [PATCH 11/23] Re-organized imports further to fix some issues. 
--- tests/MEDS_Extract/__init__.py | 24 ++++++++++ tests/MEDS_Extract/test_extract.py | 35 ++++---------- .../MEDS_Extract/test_extract_no_metadata.py | 35 ++++---------- tests/MEDS_Transforms/__init__.py | 46 +++++++++++++++++++ .../test_add_time_derived_measurements.py | 9 +--- .../test_aggregate_code_metadata.py | 5 +- tests/MEDS_Transforms/test_extract_values.py | 11 +---- .../test_filter_measurements.py | 10 +--- tests/MEDS_Transforms/test_filter_subjects.py | 6 +-- .../test_fit_vocabulary_indices.py | 11 +---- .../test_multi_stage_preprocess_pipeline.py | 8 +--- tests/MEDS_Transforms/test_normalization.py | 6 +-- .../MEDS_Transforms/test_occlude_outliers.py | 6 +-- .../test_reorder_measurements.py | 9 +--- .../MEDS_Transforms/test_reshard_to_split.py | 6 +-- tests/MEDS_Transforms/test_tensorization.py | 6 +-- tests/MEDS_Transforms/test_tokenization.py | 4 +- .../MEDS_Transforms/transform_tester_base.py | 45 ------------------ 18 files changed, 114 insertions(+), 168 deletions(-) diff --git a/tests/MEDS_Extract/__init__.py b/tests/MEDS_Extract/__init__.py index e69de29..14ddbce 100644 --- a/tests/MEDS_Extract/__init__.py +++ b/tests/MEDS_Extract/__init__.py @@ -0,0 +1,24 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +extraction_root = root / "src" / "MEDS_transforms" / "extract" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" + SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" + MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" + EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" + FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" + FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" +else: + SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" + SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" + CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" + MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" + EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" + FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" + FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index b1b50a3..be6b246 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -4,32 +4,6 @@ scripts. 
""" -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,6 +12,15 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/MEDS_Extract/test_extract_no_metadata.py b/tests/MEDS_Extract/test_extract_no_metadata.py index ed783f4..0fa8eec 100644 --- a/tests/MEDS_Extract/test_extract_no_metadata.py +++ b/tests/MEDS_Extract/test_extract_no_metadata.py @@ -4,32 +4,6 @@ scripts. 
""" -import os - -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -extraction_root = code_root / "extract" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - SHARD_EVENTS_SCRIPT = extraction_root / "shard_events.py" - SPLIT_AND_SHARD_SCRIPT = extraction_root / "split_and_shard_subjects.py" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = extraction_root / "convert_to_sharded_events.py" - MERGE_TO_MEDS_COHORT_SCRIPT = extraction_root / "merge_to_MEDS_cohort.py" - EXTRACT_CODE_METADATA_SCRIPT = extraction_root / "extract_code_metadata.py" - FINALIZE_DATA_SCRIPT = extraction_root / "finalize_MEDS_data.py" - FINALIZE_METADATA_SCRIPT = extraction_root / "finalize_MEDS_metadata.py" -else: - SHARD_EVENTS_SCRIPT = "MEDS_extract-shard_events" - SPLIT_AND_SHARD_SCRIPT = "MEDS_extract-split_and_shard_subjects" - CONVERT_TO_SHARDED_EVENTS_SCRIPT = "MEDS_extract-convert_to_sharded_events" - MERGE_TO_MEDS_COHORT_SCRIPT = "MEDS_extract-merge_to_MEDS_cohort" - EXTRACT_CODE_METADATA_SCRIPT = "MEDS_extract-extract_code_metadata" - FINALIZE_DATA_SCRIPT = "MEDS_extract-finalize_MEDS_data" - FINALIZE_METADATA_SCRIPT = "MEDS_extract-finalize_MEDS_metadata" - import json import tempfile from io import StringIO @@ -38,6 +12,15 @@ import polars as pl from meds import __version__ as MEDS_VERSION +from tests.MEDS_Extract import ( + CONVERT_TO_SHARDED_EVENTS_SCRIPT, + EXTRACT_CODE_METADATA_SCRIPT, + FINALIZE_DATA_SCRIPT, + FINALIZE_METADATA_SCRIPT, + MERGE_TO_MEDS_COHORT_SCRIPT, + SHARD_EVENTS_SCRIPT, + SPLIT_AND_SHARD_SCRIPT, +) from tests.utils import assert_df_equal, run_command # Test data (inputs) diff --git a/tests/MEDS_Transforms/__init__.py b/tests/MEDS_Transforms/__init__.py index e69de29..a2d3d56 100644 --- a/tests/MEDS_Transforms/__init__.py +++ b/tests/MEDS_Transforms/__init__.py @@ -0,0 +1,46 @@ +import os + +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) + +code_root = root / "src" / "MEDS_transforms" +transforms_root = code_root / "transforms" +filters_root = code_root / "filters" + +if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" + FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" + RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" + FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" + REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" + EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" + NORMALIZATION_SCRIPT = transforms_root / "normalization.py" + OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" + TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" + TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" +else: + # Root Source + AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" + FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" + RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" + + # Filters + FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" + FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" + + # Transforms + ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = 
"MEDS_transform-add_time_derived_measurements" + REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" + EXTRACT_VALUES_SCRIPT = "MEDS_transform-extract_values" + NORMALIZATION_SCRIPT = "MEDS_transform-normalization" + OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers" + TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" + TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" diff --git a/tests/MEDS_Transforms/test_add_time_derived_measurements.py b/tests/MEDS_Transforms/test_add_time_derived_measurements.py index ed7bbba..ff2131d 100644 --- a/tests/MEDS_Transforms/test_add_time_derived_measurements.py +++ b/tests/MEDS_Transforms/test_add_time_derived_measurements.py @@ -3,16 +3,11 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from tests.MEDS_Transforms.transform_tester_base import ( - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs AGE_CALCULATION_STR = """ diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index 48ff79b..acf0099 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -3,14 +3,11 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl +from tests.MEDS_Transforms import AGGREGATE_CODE_METADATA_SCRIPT from tests.MEDS_Transforms.transform_tester_base import ( - AGGREGATE_CODE_METADATA_SCRIPT, MEDS_CODE_METADATA_SCHEMA, single_stage_transform_tester, ) diff --git a/tests/MEDS_Transforms/test_extract_values.py b/tests/MEDS_Transforms/test_extract_values.py index d273c99..4114b3b 100644 --- a/tests/MEDS_Transforms/test_extract_values.py +++ b/tests/MEDS_Transforms/test_extract_values.py @@ -4,15 +4,8 @@ scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import ( - EXTRACT_VALUES_SCRIPT, - parse_shards_yaml, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import EXTRACT_VALUES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_shards_yaml, single_stage_transform_tester INPUT_SHARDS = parse_shards_yaml( """ diff --git a/tests/MEDS_Transforms/test_filter_measurements.py b/tests/MEDS_Transforms/test_filter_measurements.py index a3e53d9..9991a26 100644 --- a/tests/MEDS_Transforms/test_filter_measurements.py +++ b/tests/MEDS_Transforms/test_filter_measurements.py @@ -3,15 +3,9 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. 
""" -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - - -from tests.MEDS_Transforms.transform_tester_base import ( - FILTER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import FILTER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs # This is the code metadata diff --git a/tests/MEDS_Transforms/test_filter_subjects.py b/tests/MEDS_Transforms/test_filter_subjects.py index 4d4f2ca..83f4068 100644 --- a/tests/MEDS_Transforms/test_filter_subjects.py +++ b/tests/MEDS_Transforms/test_filter_subjects.py @@ -4,12 +4,10 @@ scripts. """ -import rootutils from meds import subject_id_field -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import FILTER_SUBJECTS_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import FILTER_SUBJECTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs WANT_TRAIN_0 = f""" diff --git a/tests/MEDS_Transforms/test_fit_vocabulary_indices.py b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py index ea6c1c5..78f637a 100644 --- a/tests/MEDS_Transforms/test_fit_vocabulary_indices.py +++ b/tests/MEDS_Transforms/test_fit_vocabulary_indices.py @@ -4,16 +4,9 @@ scripts. """ -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - - -from tests.MEDS_Transforms.transform_tester_base import ( - FIT_VOCABULARY_INDICES_SCRIPT, - parse_code_metadata_csv, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import FIT_VOCABULARY_INDICES_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import parse_code_metadata_csv, single_stage_transform_tester WANT_CSV = """ code,code/n_occurrences,code/n_subjects,values/n_occurrences,values/sum,values/sum_sqd,description,parent_codes,code/vocab_index diff --git a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py index 0deade2..6667313 100644 --- a/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py +++ b/tests/MEDS_Transforms/test_multi_stage_preprocess_pipeline.py @@ -17,9 +17,6 @@ The stage configuration arguments will be as given in the yaml block below: """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from datetime import datetime @@ -27,7 +24,7 @@ from meds import subject_id_field from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict -from tests.MEDS_Transforms.transform_tester_base import ( +from tests.MEDS_Transforms import ( ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, AGGREGATE_CODE_METADATA_SCRIPT, FILTER_SUBJECTS_SCRIPT, @@ -36,9 +33,8 @@ OCCLUDE_OUTLIERS_SCRIPT, TENSORIZATION_SCRIPT, TOKENIZATION_SCRIPT, - multi_stage_transform_tester, - parse_shards_yaml, ) +from tests.MEDS_Transforms.transform_tester_base import multi_stage_transform_tester, parse_shards_yaml MEDS_CODE_METADATA = pl.DataFrame( { diff --git a/tests/MEDS_Transforms/test_normalization.py b/tests/MEDS_Transforms/test_normalization.py index b6f386f..4cc21ae 100644 --- a/tests/MEDS_Transforms/test_normalization.py +++ b/tests/MEDS_Transforms/test_normalization.py @@ -5,11 +5,9 @@ """ import polars as pl -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from 
tests.MEDS_Transforms.transform_tester_base import NORMALIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import NORMALIZATION_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata file we'll use in this transform test. It is different than the default as we need diff --git a/tests/MEDS_Transforms/test_occlude_outliers.py b/tests/MEDS_Transforms/test_occlude_outliers.py index ad3d321..8e30db6 100644 --- a/tests/MEDS_Transforms/test_occlude_outliers.py +++ b/tests/MEDS_Transforms/test_occlude_outliers.py @@ -4,13 +4,11 @@ scripts. """ -import rootutils - -rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) import polars as pl -from tests.MEDS_Transforms.transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import OCCLUDE_OUTLIERS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata diff --git a/tests/MEDS_Transforms/test_reorder_measurements.py b/tests/MEDS_Transforms/test_reorder_measurements.py index c4a2a54..782c394 100644 --- a/tests/MEDS_Transforms/test_reorder_measurements.py +++ b/tests/MEDS_Transforms/test_reorder_measurements.py @@ -4,14 +4,9 @@ scripts. """ -import rootutils -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -from tests.MEDS_Transforms.transform_tester_base import ( - REORDER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from tests.MEDS_Transforms import REORDER_MEASUREMENTS_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ diff --git a/tests/MEDS_Transforms/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py index 19008bc..d0094a9 100644 --- a/tests/MEDS_Transforms/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -3,14 +3,12 @@ Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed scripts. """ -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from meds import subject_id_field -from tests.MEDS_Transforms.transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms import RESHARD_TO_SPLIT_SCRIPT +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester from tests.utils import parse_meds_csvs IN_SHARDS_MAP = { diff --git a/tests/MEDS_Transforms/test_tensorization.py b/tests/MEDS_Transforms/test_tensorization.py index b648e6e..f56d064 100644 --- a/tests/MEDS_Transforms/test_tensorization.py +++ b/tests/MEDS_Transforms/test_tensorization.py @@ -6,14 +6,12 @@ scripts. 
""" -import rootutils - -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict +from tests.MEDS_Transforms import TENSORIZATION_SCRIPT from tests.MEDS_Transforms.test_tokenization import WANT_EVENT_SEQS as TOKENIZED_SHARDS -from tests.MEDS_Transforms.transform_tester_base import TENSORIZATION_SCRIPT, single_stage_transform_tester +from tests.MEDS_Transforms.transform_tester_base import single_stage_transform_tester WANT_NRTS = { f'{k.replace("event_seqs/", "")}.nrt': JointNestedRaggedTensorDict( diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py index 811b1c9..d12cdb0 100644 --- a/tests/MEDS_Transforms/test_tokenization.py +++ b/tests/MEDS_Transforms/test_tokenization.py @@ -10,9 +10,11 @@ import polars as pl +from tests.MEDS_Transforms import TOKENIZATION_SCRIPT + from .test_normalization import NORMALIZED_MEDS_SCHEMA from .test_normalization import WANT_SHARDS as NORMALIZED_SHARDS -from .transform_tester_base import TOKENIZATION_SCRIPT, single_stage_transform_tester +from .transform_tester_base import single_stage_transform_tester SECONDS_PER_DAY = 60 * 60 * 24 diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 1e692e4..90b2ae7 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -11,60 +11,15 @@ except ImportError: from yaml import Loader -import os from collections import defaultdict from io import StringIO from pathlib import Path import polars as pl -import rootutils from meds import subject_id_field from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester -root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) - -code_root = root / "src" / "MEDS_transforms" -transforms_root = code_root / "transforms" -filters_root = code_root / "filters" - -if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = code_root / "aggregate_code_metadata.py" - FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" - RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" - FILTER_SUBJECTS_SCRIPT = filters_root / "filter_subjects.py" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = transforms_root / "add_time_derived_measurements.py" - REORDER_MEASUREMENTS_SCRIPT = transforms_root / "reorder_measurements.py" - EXTRACT_VALUES_SCRIPT = transforms_root / "extract_values.py" - NORMALIZATION_SCRIPT = transforms_root / "normalization.py" - OCCLUDE_OUTLIERS_SCRIPT = transforms_root / "occlude_outliers.py" - TENSORIZATION_SCRIPT = transforms_root / "tensorization.py" - TOKENIZATION_SCRIPT = transforms_root / "tokenization.py" -else: - # Root Source - AGGREGATE_CODE_METADATA_SCRIPT = "MEDS_transform-aggregate_code_metadata" - FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" - RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" - - # Filters - FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" - FILTER_SUBJECTS_SCRIPT = "MEDS_transform-filter_subjects" - - # Transforms - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT = "MEDS_transform-add_time_derived_measurements" - REORDER_MEASUREMENTS_SCRIPT = "MEDS_transform-reorder_measurements" - EXTRACT_VALUES_SCRIPT = 
"MEDS_transform-extract_values" - NORMALIZATION_SCRIPT = "MEDS_transform-normalization" - OCCLUDE_OUTLIERS_SCRIPT = "MEDS_transform-occlude_outliers" - TENSORIZATION_SCRIPT = "MEDS_transform-tensorization" - TOKENIZATION_SCRIPT = "MEDS_transform-tokenization" - # Test MEDS data (inputs) SHARDS = { From 7675292b7e2971fe595a6520c34972d12d6ddf8e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 16:56:36 -0400 Subject: [PATCH 12/23] Added a shard_events test. --- tests/MEDS_Extract/test_shard_events.py | 113 ++++++++++++++++++++++++ tests/utils.py | 15 +++- 2 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 tests/MEDS_Extract/test_shard_events.py diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py new file mode 100644 index 0000000..c6f0d0f --- /dev/null +++ b/tests/MEDS_Extract/test_shard_events.py @@ -0,0 +1,113 @@ +"""Tests the shard events stage in isolation. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SHARD_EVENTS_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - 
ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + + +def test_extraction(): + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + want_outputs={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], + }, + config_name="extract", + input_files={ + "subjects.csv": SUBJECTS_CSV, + "admit_vitals.csv": ADMIT_VITALS_CSV, + "admit_vitals.parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + ) diff --git a/tests/utils.py b/tests/utils.py index c258513..4a00eb4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -278,7 +278,7 @@ def check_df_output( ) -FILE_T = pl.DataFrame | dict[str, Any] +FILE_T = pl.DataFrame | dict[str, Any] | str @contextmanager @@ -296,10 +296,14 @@ def input_dataset(input_files: dict[str, FILE_T] | None = None): data.write_parquet(fp.with_suffix(".parquet"), use_pyarrow=True) case pl.DataFrame() if fp.suffix == ".parquet": data.write_parquet(fp, use_pyarrow=True) + case pl.DataFrame() if fp.suffix == ".csv": + data.write_csv(fp) case dict() if fp.suffix == "": fp.with_suffix(".json").write_text(json.dumps(data)) case dict() if fp.suffix.endswith(".json"): fp.write_text(json.dumps(data)) + case str(): + fp.write_text(data.strip()) case _: raise ValueError(f"Unknown data type {type(data)} for file {fp.relative_to(input_dir)}") @@ -354,13 +358,19 @@ def single_stage_tester( should_error: bool = False, config_name: str = "preprocess", input_files: dict[str, FILE_T] | None = None, + **pipeline_kwargs, ): with input_dataset(input_files) as (input_dir, cohort_dir): + for k, v in pipeline_kwargs.items(): + if type(v) is str and "{input_dir}" in v: + pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) + pipeline_config_kwargs = { "input_dir": str(input_dir.resolve()), "cohort_dir": str(cohort_dir.resolve()), "stages": [stage_name], "hydra.verbose": True, + **pipeline_kwargs, } if stage_kwargs: @@ -372,9 +382,8 @@ def single_stage_tester( "test_name": f"Single stage transform: {stage_name}", "should_error": should_error, "config_name": config_name, + "do_use_config_yaml": do_use_config_yaml, } - if do_use_config_yaml: - run_command_kwargs["do_use_config_yaml"] = True if do_pass_stage_name: run_command_kwargs["stage"] = stage_name From 32939b212cf2d8e76161d68aef3a3b74789525b0 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 17:08:18 -0400 Subject: [PATCH 13/23] Added a split_and_shard test. 
--- tests/MEDS_Extract/test_shard_events.py | 14 +- .../test_split_and_shard_subjects.py | 134 ++++++++++++++++++ tests/utils.py | 8 ++ 3 files changed, 149 insertions(+), 7 deletions(-) create mode 100644 tests/MEDS_Extract/test_split_and_shard_subjects.py diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py index c6f0d0f..88cd60d 100644 --- a/tests/MEDS_Extract/test_shard_events.py +++ b/tests/MEDS_Extract/test_shard_events.py @@ -92,16 +92,11 @@ """ -def test_extraction(): +def test_shard_events(): single_stage_tester( script=SHARD_EVENTS_SCRIPT, stage_name="shard_events", stage_kwargs={"row_chunksize": 10}, - want_outputs={ - "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), - "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], - "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], - }, config_name="extract", input_files={ "subjects.csv": SUBJECTS_CSV, @@ -109,5 +104,10 @@ def test_extraction(): "admit_vitals.parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV)), "event_cfgs.yaml": EVENT_CFGS_YAML, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + want_outputs={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], + }, ) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py new file mode 100644 index 0000000..0a216d2 --- /dev/null +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -0,0 +1,134 @@ +"""Tests the full end-to-end extraction process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import SPLIT_AND_SHARD_SCRIPT +from tests.utils import single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +INPUTS = { + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), +} + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +# Test data (expected outputs) -- ALL OF THIS MAY CHANGE 
IF THE SEED OR DATA CHANGES +EXPECTED_SPLITS = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +SUBJECT_SPLITS_DF = pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } +) + + +def test_split_and_shard(): + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + "split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + }, + config_name="extract", + input_files={**INPUTS, "event_cfgs.yaml": EVENT_CFGS_YAML}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, + ) diff --git a/tests/utils.py b/tests/utils.py index 4a00eb4..ea64c1c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -334,6 +334,14 @@ def check_outputs( check_df_output(output_fp, want) case ".nrt": check_NRT_output(output_fp, want) + case ".json": + with open(output_fp) as f: + got = json.load(f) + assert got == want, ( + f"Expected JSON at {output_fp} to be equal to the target.\n" + f"Wanted:\n{want}\n" + f"Got:\n{got}" + ) case _: raise ValueError(f"Unknown file suffix: {file_suffix}") From 1961478707b28517b7b6189a6f50e1dafb8d6d0d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 18:35:06 -0400 Subject: [PATCH 14/23] Added a split_and_shard test. --- .../MEDS_Extract/test_split_and_shard_subjects.py | 13 ++++++------- tests/MEDS_Transforms/transform_tester_base.py | 15 ++++++--------- tests/utils.py | 5 +++++ 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py index 0a216d2..b57cdd1 100644 --- a/tests/MEDS_Extract/test_split_and_shard_subjects.py +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -45,12 +45,6 @@ 1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 """ -INPUTS = { - "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), - "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), - "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), -} - EVENT_CFGS_YAML = """ subjects: subject_id_col: MRN @@ -128,7 +122,12 @@ def test_split_and_shard(): "n_subjects_per_shard": 2, }, config_name="extract", - input_files={**INPUTS, "event_cfgs.yaml": EVENT_CFGS_YAML}, + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, ) diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 90b2ae7..2599f11 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -4,12 +4,11 @@ scripts. 
""" -from yaml import load as load_yaml try: - from yaml import CLoader as Loader + pass except ImportError: - from yaml import Loader + pass from collections import defaultdict from io import StringIO @@ -18,7 +17,10 @@ import polars as pl from meds import subject_id_field -from tests.utils import FILE_T, MEDS_PL_SCHEMA, multi_stage_tester, parse_meds_csvs, single_stage_tester +from tests.utils import FILE_T, multi_stage_tester, parse_meds_csvs, parse_shards_yaml, single_stage_tester + +# So it can be imported from here +parse_shards_yaml = parse_shards_yaml # Test MEDS data (inputs) @@ -156,11 +158,6 @@ } -def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: - schema = {**MEDS_PL_SCHEMA, **schema_updates} - return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) - - def parse_code_metadata_csv(csv_str: str) -> pl.DataFrame: cols = csv_str.strip().split("\n")[0].split(",") schema = {col: dt for col, dt in MEDS_CODE_METADATA_SCHEMA.items() if col in cols} diff --git a/tests/utils.py b/tests/utils.py index ea64c1c..bf48137 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -55,6 +55,11 @@ def reader(csv_str: str) -> pl.DataFrame: return {k: reader(v) for k, v in csvs.items()} +def parse_shards_yaml(yaml_str: str, **schema_updates) -> pl.DataFrame: + schema = {**MEDS_PL_SCHEMA, **schema_updates} + return parse_meds_csvs(load_yaml(yaml_str, Loader=Loader), schema=schema) + + def dict_to_hydra_kwargs(d: dict[str, str]) -> str: """Converts a dictionary to a hydra kwargs string for testing purposes. From 57280f04cc8af184ccdf5eb0a2226b6b457ee032 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:08:33 -0400 Subject: [PATCH 15/23] Made tests specific to row order and column order where appropriate. 
--- tests/MEDS_Extract/test_shard_events.py | 1 + .../test_aggregate_code_metadata.py | 1 + tests/MEDS_Transforms/test_tokenization.py | 1 + .../MEDS_Transforms/transform_tester_base.py | 5 ++ tests/utils.py | 85 +++++++++---------- 5 files changed, 46 insertions(+), 47 deletions(-) diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py index 88cd60d..f19746e 100644 --- a/tests/MEDS_Extract/test_shard_events.py +++ b/tests/MEDS_Extract/test_shard_events.py @@ -110,4 +110,5 @@ def test_shard_events(): "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[:10], "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_CSV))[10:], }, + df_check_kwargs={"check_column_order": False}, ) diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index acf0099..a2abce5 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -184,4 +184,5 @@ def test_aggregate_code_metadata(): input_code_metadata=MEDS_CODE_METADATA_FILE, do_use_config_yaml=True, assert_no_other_outputs=False, + df_check_kwargs={"check_column_order": False}, ) diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py index d12cdb0..470b425 100644 --- a/tests/MEDS_Transforms/test_tokenization.py +++ b/tests/MEDS_Transforms/test_tokenization.py @@ -226,6 +226,7 @@ def test_tokenization(): transform_stage_kwargs=None, input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, + df_check_kwargs={"check_column_order": False}, ) single_stage_transform_tester( diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 2599f11..6a1d4ab 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -221,8 +221,12 @@ def single_stage_transform_tester( want_metadata: pl.DataFrame | None = None, assert_no_other_outputs: bool = True, should_error: bool = False, + df_check_kwargs: dict | None = None, **input_data_kwargs, ): + if df_check_kwargs is None: + df_check_kwargs = {} + base_kwargs = { "script": transform_script, "stage_name": stage_name, @@ -233,6 +237,7 @@ def single_stage_transform_tester( "should_error": should_error, "config_name": "preprocess", "input_files": remap_inputs_for_transform(**input_data_kwargs), + "df_check_kwargs": df_check_kwargs, } want_outputs = {} diff --git a/tests/utils.py b/tests/utils.py index bf48137..f029e89 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -203,19 +203,14 @@ def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kw assert_frame_equal(want, got, **kwargs) except AssertionError as e: pl.Config.set_tbl_rows(-1) - print(f"DFs are not equal: {msg}\nwant:") - print(want) - print("got:") - print(got) - raise AssertionError(f"{msg}\n{e}") from e + raise AssertionError(f"{msg}:\nWant:\n{want}\nGot:\n{got}\n{e}") from e def check_NRT_output( output_fp: Path, want_nrt: JointNestedRaggedTensorDict, + msg: str, ): - assert output_fp.is_file(), f"Expected {output_fp} to exist." 
- got_nrt = JointNestedRaggedTensorDict.load(output_fp) # assert got_nrt.schema == want_nrt.schema, ( @@ -228,20 +223,16 @@ def check_NRT_output( got_tensors = got_nrt.tensors assert got_tensors.keys() == want_tensors.keys(), ( - f"Expected the keys of the NRT at {output_fp} to be equal to the target.\n" - f"Wanted:\n{list(want_tensors.keys())}\n" - f"Got:\n{list(got_tensors.keys())}" + f"{msg}:\n" f"Wanted:\n{list(want_tensors.keys())}\n" f"Got:\n{list(got_tensors.keys())}" ) for k in want_tensors.keys(): want_v = want_tensors[k] got_v = got_tensors[k] - assert type(want_v) is type(got_v), ( - f"Expected tensor {k} of the NRT at {output_fp} to be of the same type as the target.\n" - f"Wanted:\n{type(want_v)}\n" - f"Got:\n{type(got_v)}" - ) + assert type(want_v) is type( + got_v + ), f"{msg}: Wanted {k} to be of type {type(want_v)}, got {type(got_v)}." if isinstance(want_v, list): assert len(want_v) == len(got_v), ( @@ -263,26 +254,6 @@ def check_NRT_output( ) -def check_df_output( - output_fp: Path, - want_df: pl.DataFrame, - check_column_order: bool = False, - check_row_order: bool = True, - **kwargs, -): - assert output_fp.is_file(), f"Expected {output_fp} to exist." - - got_df = pl.read_parquet(output_fp, glob=False) - assert_df_equal( - want_df, - got_df, - (f"Expected the dataframe at {output_fp} to be equal to the target.\n"), - check_column_order=check_column_order, - check_row_order=check_row_order, - **kwargs, - ) - - FILE_T = pl.DataFrame | dict[str, Any] | str @@ -319,6 +290,7 @@ def check_outputs( cohort_dir: Path, want_outputs: dict[str, pl.DataFrame], assert_no_other_outputs: bool = True, + **df_check_kwargs, ): all_file_suffixes = set() @@ -331,19 +303,27 @@ def check_outputs( output_fp = cohort_dir / output_name + files_found = [str(fp.relative_to(cohort_dir)) for fp in cohort_dir.glob("**/*{file_suffix}")] + if not output_fp.is_file(): - raise AssertionError(f"Expected {output_fp} to exist.") + raise AssertionError( + f"Wanted {output_fp.relative_to(cohort_dir)} to exist. 
" + f"{len(files_found)} {file_suffix} files found: {', '.join(files_found)}" + ) + + msg = f"Expected {output_fp.relative_to(cohort_dir)} to be equal to the target" match file_suffix: case ".parquet": - check_df_output(output_fp, want) + got_df = pl.read_parquet(output_fp, glob=False) + assert_df_equal(want, got_df, msg=msg, **df_check_kwargs) case ".nrt": - check_NRT_output(output_fp, want) + check_NRT_output(output_fp, want, msg=msg) case ".json": with open(output_fp) as f: got = json.load(f) assert got == want, ( - f"Expected JSON at {output_fp} to be equal to the target.\n" + f"Expected JSON at {output_fp.relative_to(cohort_dir)} to be equal to the target.\n" f"Wanted:\n{want}\n" f"Got:\n{got}" ) @@ -371,8 +351,12 @@ def single_stage_tester( should_error: bool = False, config_name: str = "preprocess", input_files: dict[str, FILE_T] | None = None, + df_check_kwargs: dict | None = None, **pipeline_kwargs, ): + if df_check_kwargs is None: + df_check_kwargs = {} + with input_dataset(input_files) as (input_dir, cohort_dir): for k, v in pipeline_kwargs.items(): if type(v) is str and "{input_dir}" in v: @@ -409,13 +393,16 @@ def single_stage_tester( try: check_outputs( - cohort_dir, want_outputs=want_outputs, assert_no_other_outputs=assert_no_other_outputs + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + **df_check_kwargs, ) except Exception as e: raise AssertionError( - f"Single stage transform {stage_name} failed.\n" + f"Single stage transform {stage_name} failed -- {e}:\n" f"Script stdout:\n{stdout}\n" - f"Script stderr:\n{stderr}" + f"Script stderr:\n{stderr}\n" ) from e @@ -473,8 +460,12 @@ def multi_stage_tester( do_pass_stage_name=do_pass_stage_name[stage], ) - check_outputs( - cohort_dir, - want_outputs=want_outputs, - assert_no_other_outputs=assert_no_other_outputs, - ) + try: + check_outputs( + cohort_dir, + want_outputs=want_outputs, + assert_no_other_outputs=assert_no_other_outputs, + check_column_order=False, + ) + except Exception as e: + raise AssertionError(f"{n_stages}-stage pipeline ({stage_names}) failed--{e}") from e From 3d0ad006700905992a39bdec895482287ec9a49b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:10:19 -0400 Subject: [PATCH 16/23] Added an event conversion test. --- .../test_convert_to_sharded_events.py | 229 ++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/MEDS_Extract/test_convert_to_sharded_events.py diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py new file mode 100644 index 0000000..9eae09f --- /dev/null +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -0,0 +1,229 @@ +"""Tests the convert to sharded events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from io import StringIO + +import polars as pl + +from tests.MEDS_Extract import CONVERT_TO_SHARDED_EVENTS_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +SUBJECTS_CSV = """ +MRN,dob,eye_color,height +1195293,06/20/1978,BLUE,164.6868838269085 +239684,12/28/1980,BROWN,175.271115221764 +1500733,07/20/1986,BROWN,158.60131573580904 +814703,03/28/1976,HAZEL,156.48559093209357 +754281,12/19/1988,BROWN,166.22261567137025 +68729,03/09/1978,HAZEL,160.3953106166676 +""" + +ADMIT_VITALS_0_10_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 +754281,"01/03/2010, 06:27:59","01/03/2010, 08:22:13",PULMONARY,"01/03/2010, 06:27:59",142.0,99.8 +814703,"02/05/2010, 05:55:39","02/05/2010, 07:02:30",ORTHOPEDIC,"02/05/2010, 05:55:39",170.2,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:25:35",113.4,95.8 +68729,"05/26/2010, 02:30:56","05/26/2010, 04:51:52",PULMONARY,"05/26/2010, 02:30:56",86.0,97.8 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:12:31",112.5,99.8 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 16:20:49",90.1,100.1 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:48:48",105.1,96.2 +239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 17:41:51",102.6,96.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:25:32",114.1,100.0 +""" + +ADMIT_VITALS_10_16_CSV = """ +subject_id,admit_date,disch_date,department,vitals_date,HR,temp +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 14:54:38",91.4,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:41:33",107.5,100.4 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 20:24:44",107.7,100.0 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:45:19",119.8,99.9 +1195293,"06/20/2010, 19:23:52","06/20/2010, 20:50:04",CARDIAC,"06/20/2010, 19:23:52",109.0,100.0 +1500733,"06/03/2010, 14:54:38","06/03/2010, 16:44:26",ORTHOPEDIC,"06/03/2010, 15:39:49",84.4,100.3 +""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 
239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16).parquet: |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 
15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, + stage_name="convert_to_sharded_events", + stage_kwargs=None, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, + ) From ab31a070a7215d647878b2c36ea1962d3e8c5cec Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 27 Aug 2024 19:49:36 -0400 Subject: [PATCH 17/23] Removing outdated comments. --- tests/MEDS_Extract/test_convert_to_sharded_events.py | 2 +- tests/MEDS_Extract/test_split_and_shard_subjects.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py index 9eae09f..074e897 100644 --- a/tests/MEDS_Extract/test_convert_to_sharded_events.py +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -222,7 +222,7 @@ def test_convert_to_sharded_events(): "event_cfgs.yaml": EVENT_CFGS_YAML, "metadata/.shards.json": SHARDS_JSON, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", shards_map_fp="{input_dir}/metadata/.shards.json", want_outputs=WANT_OUTPUTS, df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py index b57cdd1..db74896 100644 --- a/tests/MEDS_Extract/test_split_and_shard_subjects.py +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -128,6 +128,6 @@ def test_split_and_shard(): "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), "event_cfgs.yaml": EVENT_CFGS_YAML, }, - event_conversion_config_fp="{input_dir}/event_cfgs.yaml", # This makes the escape pass to hydra + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, ) From a812425467957a706607c63894f002e1c4c099b3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:01:48 -0400 Subject: [PATCH 18/23] Adding merge test. --- .../MEDS_Extract/test_merge_to_MEDS_cohort.py | 268 ++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 tests/MEDS_Extract/test_merge_to_MEDS_cohort.py diff --git a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py new file mode 100644 index 0000000..b9a21ff --- /dev/null +++ b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py @@ -0,0 +1,268 @@ +"""Tests the merge to MEDS events process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +from tests.MEDS_Extract import MERGE_TO_MEDS_COHORT_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + + data/train/1/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + + data/tuning/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + + data/held_out/0/subjects/[0-6): |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + + data/train/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + data/train/1/admit_vitals/[0-10): |-2 + 
subject_id,time,code,numeric_value + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/train/1/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/tuning/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/tuning/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + + data/held_out/0/admit_vitals/[0-10): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + + data/held_out/0/admit_vitals/[10-16): |-2 + subject_id,time,code,numeric_value + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 
00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=MERGE_TO_MEDS_COHORT_SCRIPT, + stage_name="merge_to_MEDS_cohort", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": False}, + ) From 5393d83bbfd382aa4ac363beb62a2905befa8588 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:10:28 -0400 Subject: [PATCH 19/23] Added data finalization test. --- tests/MEDS_Extract/test_finalize_MEDS_data.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/MEDS_Extract/test_finalize_MEDS_data.py diff --git a/tests/MEDS_Extract/test_finalize_MEDS_data.py b/tests/MEDS_Extract/test_finalize_MEDS_data.py new file mode 100644 index 0000000..d9a3e0a --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_data.py @@ -0,0 +1,111 @@ +"""Tests the finalize MEDS data process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + +import polars as pl + +from tests.MEDS_Extract import FINALIZE_DATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + +WANT_OUTPUTS = { + k: v.with_columns( + pl.col("subject_id").cast(pl.Int64), + pl.col("time").cast(pl.Datetime("us")), + pl.col("code").cast(pl.String), + pl.col("numeric_value").cast(pl.Float32), + ) + for k, v in INPUT_SHARDS.items() +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_DATA_SCRIPT, + stage_name="finalize_MEDS_data", + stage_kwargs=None, + config_name="extract", + input_files=INPUT_SHARDS, + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_column_order": True, "check_dtypes": True, "check_row_order": True}, + ) From 
886817883136ece9e7e466efecc9cc4ccba1be0a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:25:19 -0400 Subject: [PATCH 20/23] Added the ability to handle list columns to df checker. --- tests/MEDS_Extract/test_extract.py | 5 ----- tests/utils.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/MEDS_Extract/test_extract.py b/tests/MEDS_Extract/test_extract.py index be6b246..954ca21 100644 --- a/tests/MEDS_Extract/test_extract.py +++ b/tests/MEDS_Extract/test_extract.py @@ -525,14 +525,9 @@ def test_extraction(): got_df = pl.read_parquet(output_file, glob=False) want_df = pl.read_csv(source=StringIO(MEDS_OUTPUT_CODE_METADATA_FILE)).with_columns( - pl.col("code"), pl.col("parent_codes").cast(pl.List(pl.Utf8)), ) - # We collapse the list type as it throws an error in the assert_df_equal otherwise - got_df = got_df.with_columns(pl.col("parent_codes").list.join("||")) - want_df = want_df.with_columns(pl.col("parent_codes").list.join("||")) - assert_df_equal( want=want_df, got=got_df, diff --git a/tests/utils.py b/tests/utils.py index f029e89..450a86b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -200,6 +200,20 @@ def run_command( def assert_df_equal(want: pl.DataFrame, got: pl.DataFrame, msg: str = None, **kwargs): try: + update_exprs = {} + for k, v in want.schema.items(): + assert k in got.schema, f"missing column {k}." + if kwargs.get("check_dtypes", False): + assert v == got.schema[k], f"column {k} has different types." + if v == pl.List(pl.String) and got.schema[k] == pl.List(pl.String): + update_exprs[k] = pl.col(k).list.join("||") + if update_exprs: + want_cols = want.columns + got_cols = got.columns + + want = want.with_columns(**update_exprs).select(want_cols) + got = got.with_columns(**update_exprs).select(got_cols) + assert_frame_equal(want, got, **kwargs) except AssertionError as e: pl.Config.set_tbl_rows(-1) From 3b77247cd7e54347b73a367dafcb1d1b276f414e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:26:12 -0400 Subject: [PATCH 21/23] Added a metadata extractor test. --- .../test_extract_code_metadata.py | 208 ++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 tests/MEDS_Extract/test_extract_code_metadata.py diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py new file mode 100644 index 0000000..af307ce --- /dev/null +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -0,0 +1,208 @@ +"""Tests the extract code metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. 
+""" + + +import polars as pl + +from tests.MEDS_Extract import EXTRACT_CODE_METADATA_SCRIPT +from tests.utils import parse_shards_yaml, single_stage_tester + +INPUT_SHARDS = parse_shards_yaml( + """ + data/train/0: |-2 + subject_id,time,code,numeric_value + 239684,,EYE_COLOR//BROWN, + 239684,,HEIGHT,175.271115221764 + 239684,"12/28/1980, 00:00:00",DOB, + 239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, + 239684,"05/11/2010, 17:41:51",HR,102.6 + 239684,"05/11/2010, 17:41:51",TEMP,96.0 + 239684,"05/11/2010, 17:48:48",HR,105.1 + 239684,"05/11/2010, 17:48:48",TEMP,96.2 + 239684,"05/11/2010, 18:25:35",HR,113.4 + 239684,"05/11/2010, 18:25:35",TEMP,95.8 + 239684,"05/11/2010, 18:57:18",HR,112.6 + 239684,"05/11/2010, 18:57:18",TEMP,95.5 + 239684,"05/11/2010, 19:27:19",DISCHARGE, + 1195293,,EYE_COLOR//BLUE, + 1195293,,HEIGHT,164.6868838269085 + 1195293,"06/20/1978, 00:00:00",DOB, + 1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, + 1195293,"06/20/2010, 19:23:52",HR,109.0 + 1195293,"06/20/2010, 19:23:52",TEMP,100.0 + 1195293,"06/20/2010, 19:25:32",HR,114.1 + 1195293,"06/20/2010, 19:25:32",TEMP,100.0 + 1195293,"06/20/2010, 19:45:19",HR,119.8 + 1195293,"06/20/2010, 19:45:19",TEMP,99.9 + 1195293,"06/20/2010, 20:12:31",HR,112.5 + 1195293,"06/20/2010, 20:12:31",TEMP,99.8 + 1195293,"06/20/2010, 20:24:44",HR,107.7 + 1195293,"06/20/2010, 20:24:44",TEMP,100.0 + 1195293,"06/20/2010, 20:41:33",HR,107.5 + 1195293,"06/20/2010, 20:41:33",TEMP,100.4 + 1195293,"06/20/2010, 20:50:04",DISCHARGE, + + + data/train/1: |-2 + subject_id,time,code,numeric_value + 68729,,EYE_COLOR//HAZEL, + 68729,,HEIGHT,160.3953106166676 + 68729,"03/09/1978, 00:00:00",DOB, + 68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, + 68729,"05/26/2010, 02:30:56",HR,86.0 + 68729,"05/26/2010, 02:30:56",TEMP,97.8 + 68729,"05/26/2010, 04:51:52",DISCHARGE, + 814703,,EYE_COLOR//HAZEL, + 814703,,HEIGHT,156.48559093209357 + 814703,"03/28/1976, 00:00:00",DOB, + 814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, + 814703,"02/05/2010, 05:55:39",HR,170.2 + 814703,"02/05/2010, 05:55:39",TEMP,100.1 + 814703,"02/05/2010, 07:02:30",DISCHARGE, + + data/tuning/0: |-2 + subject_id,time,code,numeric_value + 754281,,EYE_COLOR//BROWN, + 754281,,HEIGHT,166.22261567137025 + 754281,"12/19/1988, 00:00:00",DOB, + 754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, + 754281,"01/03/2010, 06:27:59",HR,142.0 + 754281,"01/03/2010, 06:27:59",TEMP,99.8 + 754281,"01/03/2010, 08:22:13",DISCHARGE, + + data/held_out/0: |-2 + subject_id,time,code,numeric_value + 1500733,,EYE_COLOR//BROWN, + 1500733,,HEIGHT,158.60131573580904 + 1500733,"07/20/1986, 00:00:00",DOB, + 1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, + 1500733,"06/03/2010, 14:54:38",HR,91.4 + 1500733,"06/03/2010, 14:54:38",TEMP,100.0 + 1500733,"06/03/2010, 15:39:49",HR,84.4 + 1500733,"06/03/2010, 15:39:49",TEMP,100.3 + 1500733,"06/03/2010, 16:20:49",HR,90.1 + 1500733,"06/03/2010, 16:20:49",TEMP,100.1 + 1500733,"06/03/2010, 16:44:26",DISCHARGE, + """ +) + + +INPUT_METADATA_FILE = """ +lab_code,title,loinc +HR,Heart Rate,8867-4 +temp,Body Temperature,8310-5 +""" + +DEMO_METADATA_FILE = """ +eye_color,description +BROWN,"Brown Eyes. The most common eye color." +BLUE,"Blue Eyes. Less common than brown." +HAZEL,"Hazel eyes. These are uncommon" +GREEN,"Green eyes. These are rare." 
+""" + +EVENT_CFGS_YAML = """ +subjects: + subject_id_col: MRN + eye_color: + code: + - EYE_COLOR + - col(eye_color) + time: null + _metadata: + demo_metadata: + description: description + height: + code: HEIGHT + time: null + numeric_value: height + dob: + code: DOB + time: col(dob) + time_format: "%m/%d/%Y" +admit_vitals: + admissions: + code: + - ADMISSION + - col(department) + time: col(admit_date) + time_format: "%m/%d/%Y, %H:%M:%S" + discharge: + code: DISCHARGE + time: col(disch_date) + time_format: "%m/%d/%Y, %H:%M:%S" + HR: + code: HR + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: HR + _metadata: + input_metadata: + description: {"title": {"lab_code": "HR"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "HR"}} + temp: + code: TEMP + time: col(vitals_date) + time_format: "%m/%d/%Y, %H:%M:%S" + numeric_value: temp + _metadata: + input_metadata: + description: {"title": {"lab_code": "temp"}} + parent_codes: {"LOINC/{loinc}": {"lab_code": "temp"}} +""" + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +MEDS_OUTPUT_CODE_METADATA_FILE = """ +code,description,parent_codes +EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.", +EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.", +EYE_COLOR//HAZEL,"Hazel eyes. These are uncommon", +HR,"Heart Rate",LOINC/8867-4 +TEMP,"Body Temperature",LOINC/8310-5 +""" + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=EXTRACT_CODE_METADATA_SCRIPT, + stage_name="extract_code_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "demo_metadata.csv": DEMO_METADATA_FILE, + "input_metadata.csv": INPUT_METADATA_FILE, + "event_cfgs.yaml": EVENT_CFGS_YAML, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": True}, + assert_no_other_outputs=False, + ) From 04fc97af94189f88d8fbf974b01fadb0f2a5bea7 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:36:08 -0400 Subject: [PATCH 22/23] Removed unused constant in test. --- tests/MEDS_Extract/test_extract_code_metadata.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index af307ce..74a3d02 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -161,15 +161,6 @@ "held_out/0": [1500733], } -MEDS_OUTPUT_CODE_METADATA_FILE = """ -code,description,parent_codes -EYE_COLOR//BLUE,"Blue Eyes. Less common than brown.", -EYE_COLOR//BROWN,"Brown Eyes. The most common eye color.", -EYE_COLOR//HAZEL,"Hazel eyes. 
These are uncommon", -HR,"Heart Rate",LOINC/8867-4 -TEMP,"Body Temperature",LOINC/8310-5 -""" - WANT_OUTPUTS = { "metadata/codes": pl.DataFrame( { From 2d848091ea49f8ede703c52a4dce069646db8a5f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 28 Aug 2024 11:46:21 -0400 Subject: [PATCH 23/23] Added finalize MEDS metadata test. --- .../test_finalize_MEDS_metadata.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 tests/MEDS_Extract/test_finalize_MEDS_metadata.py diff --git a/tests/MEDS_Extract/test_finalize_MEDS_metadata.py b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py new file mode 100644 index 0000000..274997f --- /dev/null +++ b/tests/MEDS_Extract/test_finalize_MEDS_metadata.py @@ -0,0 +1,91 @@ +"""Tests the finalize MEDS metadata process. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + + +import polars as pl +from meds import __version__ as MEDS_VERSION + +from MEDS_transforms.utils import get_package_version as get_meds_transform_version +from tests.MEDS_Extract import FINALIZE_METADATA_SCRIPT +from tests.utils import single_stage_tester + +SHARDS_JSON = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_OUTPUTS = { + "metadata/codes": pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } + ), +} + +METADATA_DF = pl.DataFrame( + { + "code": ["EYE_COLOR//BLUE", "EYE_COLOR//BROWN", "EYE_COLOR//HAZEL", "HR", "TEMP"], + "description": [ + "Blue Eyes. Less common than brown.", + "Brown Eyes. The most common eye color.", + "Hazel eyes. These are uncommon", + "Heart Rate", + "Body Temperature", + ], + "parent_codes": [None, None, None, ["LOINC/8867-4"], ["LOINC/8310-5"]], + } +) + +WANT_OUTPUTS = { + "metadata/codes": ( + METADATA_DF.with_columns( + pl.col("code").cast(pl.String), + pl.col("description").cast(pl.String), + pl.col("parent_codes").cast(pl.List(pl.String)), + ).select(["code", "description", "parent_codes"]) + ), + "metadata/subject_splits": pl.DataFrame( + { + "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], + "split": ["train", "train", "train", "train", "tuning", "held_out"], + } + ), + "metadata/dataset.json": { + "dataset_name": "TEST", + "dataset_version": "1.0", + "etl_name": "MEDS_transforms", + "etl_version": get_meds_transform_version(), + "meds_version": MEDS_VERSION, + }, +} + + +def test_convert_to_sharded_events(): + single_stage_tester( + script=FINALIZE_METADATA_SCRIPT, + stage_name="finalize_MEDS_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + "metadata/codes": METADATA_DF, + "metadata/.shards.json": SHARDS_JSON, + }, + **{"etl_metadata.dataset_name": "TEST", "etl_metadata.dataset_version": "1.0"}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + want_outputs=WANT_OUTPUTS, + df_check_kwargs={"check_row_order": False, "check_column_order": True, "check_dtypes": True}, + )