From 1b2d0230a45d6e8077988ba284a1005636482e25 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 17:08:10 -0400 Subject: [PATCH] Improved compliance by removing creation of shards.json file and adding splits parquet file. --- .../transforms/tokenization.py | 19 +++++------- tests/transform_tester_base.py | 31 ++++++++++++++++--- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index 0ef5704..5e5e0e6 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -9,8 +9,6 @@ columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`. """ -import json -import random from pathlib import Path import hydra @@ -19,7 +17,7 @@ from omegaconf import DictConfig, OmegaConf from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.mapreduce.mapper import rwlock_wrap +from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator from MEDS_transforms.utils import hydra_loguru_init, write_lazyframe SECONDS_PER_MINUTE = 60.0 @@ -230,18 +228,17 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) + shards_single_output, include_only_train = shard_iterator(cfg) - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + if include_only_train: + raise ValueError("Not supported for this stage.") - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) + for in_fp, out_fp in shards_single_output: + sharded_path = out_fp.relative_to(output_dir) - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - schema_out_fp = output_dir / "schemas" / f"{sp}.parquet" - event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet" + schema_out_fp = output_dir / "schemas" / sharded_path + event_seq_out_fp = output_dir / "event_seqs" / sharded_path logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 80e4f4b..945a2f4 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -7,6 +7,7 @@ import json import os import tempfile +from collections import defaultdict from io import StringIO from pathlib import Path @@ -56,13 +57,19 @@ # Test MEDS data (inputs) -SPLITS = { +SHARDS = { "train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733], } +SPLITS = { + "train": [239684, 1195293, 68729, 814703], + "tuning": [754281], + "held_out": [1500733], +} + MEDS_TRAIN_0 = """ patient_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, @@ -292,6 +299,8 @@ def single_stage_transform_tester( do_pass_stage_name: bool = False, file_suffix: str = ".parquet", do_use_config_yaml: bool = False, + input_shards_map: dict[str, list[int]] | None = None, + input_splits_map: dict[str, list[int]] | None = None, ): with tempfile.TemporaryDirectory() as d: MEDS_dir = Path(d) / "MEDS_cohort" @@ -306,9 +315,23 @@ def single_stage_transform_tester( MEDS_metadata_dir.mkdir(parents=True) cohort_dir.mkdir(parents=True) - # Write the splits - splits_fp = MEDS_dir / "splits.json" - splits_fp.write_text(json.dumps(SPLITS)) + # Write the shards map + if input_shards_map is None: + input_shards_map = SHARDS + + shards_fp = MEDS_metadata_dir / ".shards.json" + shards_fp.write_text(json.dumps(input_shards_map)) + + # Write the splits parquet file + if input_splits_map is None: + input_splits_map = SPLITS + input_splits_as_df = defaultdict(list) + for split_name, patient_ids in input_splits_map.items(): + input_splits_as_df["patient_id"].extend(patient_ids) + input_splits_as_df["split"].extend([split_name] * len(patient_ids)) + input_splits_df = pl.DataFrame(input_splits_as_df) + input_splits_fp = MEDS_metadata_dir / "patient_splits.parquet" + input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) if input_shards is None: input_shards = MEDS_SHARDS