Separated the merge-shards shard iterator out for ease of import.
mmcdermott committed Aug 10, 2024
1 parent dd43d5a commit fd6e77f
Showing 2 changed files with 127 additions and 43 deletions.
32 changes: 2 additions & 30 deletions src/MEDS_transforms/extract/merge_to_MEDS_cohort.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
import json
import random
from functools import partial
from pathlib import Path

@@ -11,6 +9,7 @@

from MEDS_transforms.extract import CONFIG_YAML
from MEDS_transforms.mapreduce.mapper import map_over
from MEDS_transforms.mapreduce.utils import shard_iterator_by_shard_map


def merge_subdirs_and_sort(
@@ -236,37 +235,10 @@ def main(cfg: DictConfig):
additional_sort_by=cfg.stage_cfg.get("additional_sort_by", None),
)

shard_map_fp = Path(cfg.shards_map_fp)
if not shard_map_fp.exists():
raise FileNotFoundError(f"Shard map file not found at {str(shard_map_fp.resolve())}")

shards = list(json.loads(shard_map_fp.read_text()).keys())

def shard_iterator(cfg: DictConfig) -> tuple[list[str], bool]:
input_dir = Path(cfg.stage_cfg.data_input_dir)
output_dir = Path(cfg.stage_cfg.output_dir)

if cfg.stage_cfg.get("train_only", None):
raise ValueError("train_only is not supported for this stage.")

if "worker" in cfg:
random.seed(cfg.worker)
random.shuffle(shards)

logger.info(f"Mapping computation over a maximum of {len(shards)} shards")

out = []
for sh in shards:
in_fp = input_dir / sh
out_fp = output_dir / f"{sh}.parquet"
out.append((in_fp, out_fp))

return out, False

map_over(
cfg,
read_fn=read_fn,
shard_iterator_fntr=shard_iterator,
shard_iterator_fntr=shard_iterator_by_shard_map,
)


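With the iterator factored out into `MEDS_transforms.mapreduce.utils`, other stage scripts that key off the shard map can import it rather than redefining it inline. Below is a minimal sketch of such a stage, assuming only the `map_over` keyword arguments visible in this diff; the placeholder `read_fn` lambda is hypothetical and any other `map_over` arguments are omitted:

#!/usr/bin/env python
"""Hypothetical stage script (sketch only): reuses the shared shard-map iterator."""
from omegaconf import DictConfig

from MEDS_transforms.mapreduce.mapper import map_over
from MEDS_transforms.mapreduce.utils import shard_iterator_by_shard_map


def main(cfg: DictConfig):
    # The shard iteration no longer needs to be re-implemented per stage; only the
    # stage-specific reader changes.
    map_over(
        cfg,
        read_fn=lambda fp: fp,  # assumed placeholder for this stage's real reader
        shard_iterator_fntr=shard_iterator_by_shard_map,
    )
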
138 changes: 125 additions & 13 deletions src/MEDS_transforms/mapreduce/utils.py
@@ -221,6 +221,44 @@ def rwlock_wrap[
raise e


def shuffle_shards(shards: list[str], cfg: DictConfig) -> list[str]:
"""Shuffle the shards in a deterministic, pseudo-random way based on the worker ID in the configuration.

Args:
shards: The list of shards to shuffle.
cfg: The configuration dictionary for the overall pipeline. Should (possibly) contain the following
keys (some are optional, as marked below):
- `worker` (optional): The worker ID for the MR worker; this is also used to seed the
randomization process. If not provided, the randomization process will be unseeded.

Returns:
The shuffled list of shards.

Examples:
>>> cfg = DictConfig({"worker": 1})
>>> shards = ["train/0", "train/1", "tuning", "held_out"]
>>> shuffle_shards(shards, cfg)
['train/1', 'held_out', 'tuning', 'train/0']
>>> cfg = DictConfig({"worker": 2})
>>> shuffle_shards(shards, cfg)
['tuning', 'held_out', 'train/1', 'train/0']
"""

if "worker" in cfg:
add_str = str(cfg["worker"])
else:
add_str = str(datetime.now())

    shard_keys = []
    for shard in shards:
        # Convert the digest to an int before the collision check so the comparison matches
        # the integer keys actually stored in `shard_keys`.
        shard_hash = int(hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest(), 16)
        if shard_hash in shard_keys:
            raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!")
        shard_keys.append(shard_hash)

    return [shard for _, shard in sorted(zip(shard_keys, shards))]
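
The shuffle is "seeded" by sorting on a hash of the worker ID concatenated with each shard name, so every worker gets a stable but distinct ordering without touching global random state. A small self-contained sketch of the same idea, reusing the shard names and worker IDs from the doctests above; `hash_order` is an illustrative helper, not part of the library:

import hashlib

def hash_order(names: list[str], seed: str) -> list[str]:
    """Order names by SHA-256 of seed + name: stable for a given seed, varying across seeds."""
    return sorted(names, key=lambda n: hashlib.sha256((seed + n).encode("utf-8")).hexdigest())

shards = ["train/0", "train/1", "tuning", "held_out"]
assert hash_order(shards, "1") == hash_order(shards, "1")  # same worker ID -> same order
print(hash_order(shards, "1"))  # different worker IDs will (almost always) give different orders
print(hash_order(shards, "2"))
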


def shard_iterator(
cfg: DictConfig,
out_suffix: str = ".parquet",
@@ -422,19 +460,7 @@ def shard_iterator(
"and relying on `patient_splits.parquet` for filtering."
)

if "worker" in cfg:
add_str = str(cfg.worker)
else:
add_str = str(datetime.now())

shard_keys = []
for shard in shards:
shard_hash = hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest()
if shard_hash in shard_keys:
raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!")
shard_keys.append(int(shard_hash, 16))

shards = [shard for _, shard in sorted(zip(shard_keys, shards))]
shards = shuffle_shards(shards, cfg)

logger.info(f"Mapping computation over a maximum of {len(shards)} shards")

@@ -446,3 +472,89 @@ def shard_iterator(
out.append((in_fp, out_fp))

return out, includes_only_train


def shard_iterator_by_shard_map(cfg: DictConfig) -> tuple[list[str], bool]:
"""Returns an iterator over shard paths and output paths based on a shard map file, not files on disk.
Args:
cfg: The configuration dictionary for the overall pipeline. Should contain the following keys:
- `shards_map_fp` (mandatory): The file path to the shards map file.
- `stage_cfg.data_input_dir` (mandatory): The directory containing the input data.
- `stage_cfg.output_dir` (mandatory): The directory to write the output data.
- `worker` (optional): The worker ID for the MR worker; this is also used to seed the
Returns:
A list of pairs of input and output file paths for each shard, as well as a boolean indicating
whether the shards are only train shards.
Raises:
ValueError: If the `shards_map_fp` key is not present in the configuration.
FileNotFoundError: If the shard map file is not found at the path specified in the configuration.
ValueError: If the `train_only` key is present in the configuration.
Examples:
>>> from tempfile import NamedTemporaryFile, TemporaryDirectory
>>> import json
>>> shard_iterator_by_shard_map(DictConfig({}))
Traceback (most recent call last):
...
ValueError: shards_map_fp must be present in the configuration for a map-based shard iterator.
>>> with NamedTemporaryFile() as tmp:
... cfg = DictConfig({"shards_map_fp": tmp.name, "stage_cfg": {"train_only": True}})
... shard_iterator_by_shard_map(cfg)
Traceback (most recent call last):
...
ValueError: train_only is not supported for this stage.
>>> with TemporaryDirectory() as tmp:
... tmp = Path(tmp)
... shards_map_fp = tmp / "shards_map.json"
... cfg = DictConfig({"shards_map_fp": shards_map_fp, "stage_cfg": {"train_only": False}})
... shard_iterator_by_shard_map(cfg)
Traceback (most recent call last):
...
FileNotFoundError: Shard map file not found at ...shards_map.json
>>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]}
>>> with NamedTemporaryFile() as tmp:
... _ = Path(tmp.name).write_text(json.dumps(shards))
... cfg = DictConfig({
... "shards_map_fp": tmp.name,
... "worker": 1,
... "stage_cfg": {"data_input_dir": "data", "output_dir": "output"},
... })
... fps, includes_only_train = shard_iterator_by_shard_map(cfg)
>>> fps # doctest: +NORMALIZE_WHITESPACE
[(PosixPath('data/train/1'), PosixPath('output/train/1.parquet')),
(PosixPath('data/held_out'), PosixPath('output/held_out.parquet')),
(PosixPath('data/tuning'), PosixPath('output/tuning.parquet')),
(PosixPath('data/train/0'), PosixPath('output/train/0.parquet'))]
>>> includes_only_train
False
"""

if "shards_map_fp" not in cfg:
raise ValueError("shards_map_fp must be present in the configuration for a map-based shard iterator.")

if cfg.stage_cfg.get("train_only", None):
raise ValueError("train_only is not supported for this stage.")

shard_map_fp = Path(cfg.shards_map_fp)
if not shard_map_fp.exists():
raise FileNotFoundError(f"Shard map file not found at {str(shard_map_fp.resolve())}")

shards = list(json.loads(shard_map_fp.read_text()).keys())

input_dir = Path(cfg.stage_cfg.data_input_dir)
output_dir = Path(cfg.stage_cfg.output_dir)

shards = shuffle_shards(shards, cfg)

logger.info(f"Mapping computation over a maximum of {len(shards)} shards")

out = []
for sh in shards:
in_fp = input_dir / sh
out_fp = output_dir / f"{sh}.parquet"
out.append((in_fp, out_fp))

return out, False
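
For completeness, a sketch of exercising the new iterator directly outside a doctest. The shard names, subject IDs, and directory layout below are invented for illustration; only the keys of the shard map matter to the iterator:

import json
from pathlib import Path
from tempfile import TemporaryDirectory

from omegaconf import DictConfig

from MEDS_transforms.mapreduce.utils import shard_iterator_by_shard_map

with TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    shards_map_fp = tmp / "shards_map.json"
    # The values (subject IDs) are ignored by the iterator; only the shard names are used.
    shards_map_fp.write_text(json.dumps({"train/0": [1, 2], "train/1": [3], "held_out": [4]}))

    cfg = DictConfig({
        "shards_map_fp": str(shards_map_fp),
        "worker": 0,
        "stage_cfg": {"data_input_dir": str(tmp / "in"), "output_dir": str(tmp / "out")},
    })
    fps, includes_only_train = shard_iterator_by_shard_map(cfg)
    for in_fp, out_fp in fps:
        print(in_fp, "->", out_fp)  # e.g. .../in/train/1 -> .../out/train/1.parquet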
