Improved compliance by removing creation of shards.json file and adding patient_splits.parquet file. #137

Merged 1 commit on Aug 10, 2024.
src/MEDS_transforms/transforms/tokenization.py (19 changes: 8 additions & 11 deletions)

```diff
@@ -9,8 +9,6 @@
 columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`.
 """
 
-import json
-import random
 from pathlib import Path
 
 import hydra
@@ -19,7 +17,7 @@
 from omegaconf import DictConfig, OmegaConf
 
 from MEDS_transforms import PREPROCESS_CONFIG_YAML
-from MEDS_transforms.mapreduce.mapper import rwlock_wrap
+from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator
 from MEDS_transforms.utils import hydra_loguru_init, write_lazyframe
 
 SECONDS_PER_MINUTE = 60.0
@@ -230,18 +228,17 @@ def main(cfg: DictConfig):
         f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}"
     )
 
-    input_dir = Path(cfg.stage_cfg.data_input_dir)
     output_dir = Path(cfg.stage_cfg.output_dir)
+    shards_single_output, include_only_train = shard_iterator(cfg)
 
-    shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text())
+    if include_only_train:
+        raise ValueError("Not supported for this stage.")
 
-    patient_splits = list(shards.keys())
-    random.shuffle(patient_splits)
+    for in_fp, out_fp in shards_single_output:
+        sharded_path = out_fp.relative_to(output_dir)
 
-    for sp in patient_splits:
-        in_fp = input_dir / f"{sp}.parquet"
-        schema_out_fp = output_dir / "schemas" / f"{sp}.parquet"
-        event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet"
+        schema_out_fp = output_dir / "schemas" / sharded_path
+        event_seq_out_fp = output_dir / "event_seqs" / sharded_path
 
         logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}")
```
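The rewritten loop depends only on `shard_iterator` returning `(input_path, output_path)` pairs plus a flag marking train-only stages, which this stage rejects. A rough sketch of that contract (an illustrative mimic, not the library implementation: `shard_iterator_sketch` and its `input_dir`/`output_dir`/`metadata_dir` parameters are hypothetical names, and the real helper derives everything from the hydra config):

```python
import json
from pathlib import Path


def shard_iterator_sketch(
    input_dir: Path, output_dir: Path, metadata_dir: Path
) -> tuple[list[tuple[Path, Path]], bool]:
    """Pair each input shard with a mirrored output path, one per shard map entry."""
    # Shard names like "train/0" are the keys of the .shards.json map
    # (see the test harness changes below).
    shard_names = json.loads((metadata_dir / ".shards.json").read_text())
    pairs = [
        (input_dir / f"{name}.parquet", output_dir / f"{name}.parquet")
        for name in shard_names
    ]
    # Tokenization consumes every split, so the train-only flag is False here.
    return pairs, False
```

With this shape, `out_fp.relative_to(output_dir)` recovers the `train/0.parquet`-style suffix that the schema and event-sequence outputs reuse.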
tests/transform_tester_base.py (31 changes: 27 additions & 4 deletions)

```diff
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+from collections import defaultdict
 from io import StringIO
 from pathlib import Path
 
@@ -56,13 +57,19 @@
 
 # Test MEDS data (inputs)
 
-SPLITS = {
+SHARDS = {
     "train/0": [239684, 1195293],
     "train/1": [68729, 814703],
     "tuning/0": [754281],
     "held_out/0": [1500733],
 }
 
+SPLITS = {
+    "train": [239684, 1195293, 68729, 814703],
+    "tuning": [754281],
+    "held_out": [1500733],
+}
+
 MEDS_TRAIN_0 = """
 patient_id,time,code,numeric_value
 239684,,EYE_COLOR//BROWN,
```
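`SHARDS` and `SPLITS` describe the same patients at two granularities: shard names are `<split>/<shard index>`, so the split-level map must equal the shard-level map grouped on the prefix. A small self-check of that invariant (illustrative only; the test file simply hard-codes both constants):

```python
from collections import defaultdict

SHARDS = {
    "train/0": [239684, 1195293],
    "train/1": [68729, 814703],
    "tuning/0": [754281],
    "held_out/0": [1500733],
}

# Group shard-level patient lists by the split prefix before the "/".
derived_splits: dict[str, list[int]] = defaultdict(list)
for shard_name, patient_ids in SHARDS.items():
    derived_splits[shard_name.rsplit("/", 1)[0]].extend(patient_ids)

assert dict(derived_splits) == {
    "train": [239684, 1195293, 68729, 814703],
    "tuning": [754281],
    "held_out": [1500733],
}
```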
```diff
@@ -292,6 +299,8 @@ def single_stage_transform_tester(
     do_pass_stage_name: bool = False,
     file_suffix: str = ".parquet",
     do_use_config_yaml: bool = False,
+    input_shards_map: dict[str, list[int]] | None = None,
+    input_splits_map: dict[str, list[int]] | None = None,
 ):
     with tempfile.TemporaryDirectory() as d:
         MEDS_dir = Path(d) / "MEDS_cohort"
@@ -306,9 +315,23 @@ def single_stage_transform_tester(
         MEDS_metadata_dir.mkdir(parents=True)
         cohort_dir.mkdir(parents=True)
 
-        # Write the splits
-        splits_fp = MEDS_dir / "splits.json"
-        splits_fp.write_text(json.dumps(SPLITS))
+        # Write the shards map
+        if input_shards_map is None:
+            input_shards_map = SHARDS
+
+        shards_fp = MEDS_metadata_dir / ".shards.json"
+        shards_fp.write_text(json.dumps(input_shards_map))
+
+        # Write the splits parquet file
+        if input_splits_map is None:
+            input_splits_map = SPLITS
+        input_splits_as_df = defaultdict(list)
+        for split_name, patient_ids in input_splits_map.items():
+            input_splits_as_df["patient_id"].extend(patient_ids)
+            input_splits_as_df["split"].extend([split_name] * len(patient_ids))
+        input_splits_df = pl.DataFrame(input_splits_as_df)
+        input_splits_fp = MEDS_metadata_dir / "patient_splits.parquet"
+        input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True)
 
         if input_shards is None:
             input_shards = MEDS_SHARDS
```
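Consumers can read the new parquet file back with polars instead of parsing JSON. A minimal round-trip sketch, assuming the two-column `patient_id`/`split` schema written above (the `metadata_dir` value is a stand-in for the test's temporary `MEDS_metadata_dir`):

```python
from pathlib import Path

import polars as pl

metadata_dir = Path("MEDS_cohort/metadata")  # stand-in for MEDS_metadata_dir

splits_df = pl.read_parquet(metadata_dir / "patient_splits.parquet")

# Rebuild the split -> [patient_id, ...] mapping the old splits.json carried.
splits_map = {
    split: splits_df.filter(pl.col("split") == split)["patient_id"].to_list()
    for split in splits_df["split"].unique().to_list()
}
```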