Merge pull request #143 from mmcdermott/139_fix_reshard_to_split_in_parallel_mode

Fixes brittle reduce stage checking in Reshard stage.
mmcdermott authored Aug 11, 2024
2 parents 2d1c4cf + 3fc909e commit cda4f06
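In short (an illustrative sketch, not part of the diff): a parquet shard written by a concurrent worker exists on disk from the moment its writer creates it, so a `Path.is_file()` readiness check can fire mid-write; parsing the footer with `pyarrow.parquet.ParquetFile` succeeds only once the file is complete. The helper names below are invented for illustration.

```python
from pathlib import Path

import pyarrow.parquet as pq


def looks_ready_naive(fp: Path) -> bool:
    # True the instant a writer creates the file, even mid-write.
    return fp.is_file()


def looks_ready_robust(fp: Path) -> bool:
    # ParquetFile parses the footer; a half-written file (no valid
    # footer yet) raises, so an incomplete shard reports False.
    try:
        pq.ParquetFile(fp)
        return True
    except Exception:
        return False


if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        fp = Path(tmp) / "shard.parquet"
        fp.write_bytes(b"PAR1")  # simulate a partially written shard
        print(looks_ready_naive(fp))   # True -- a reducer would start too early
        print(looks_ready_robust(fp))  # False -- waits for a complete file
```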
Showing 5 changed files with 147 additions and 94 deletions.
3 changes: 2 additions & 1 deletion src/MEDS_transforms/aggregate_code_metadata.py
@@ -16,6 +16,7 @@
 
 from MEDS_transforms import PREPROCESS_CONFIG_YAML
 from MEDS_transforms.mapreduce.mapper import map_over
+from MEDS_transforms.mapreduce.utils import is_complete_parquet_file
 from MEDS_transforms.utils import write_lazyframe
 
 
@@ -667,7 +668,7 @@ def run_map_reduce(cfg: DictConfig):
 
     logger.info("Starting reduction process")
 
-    while not all(fp.is_file() for fp in all_out_fps):
+    while not all(is_complete_parquet_file(fp) for fp in all_out_fps):
         logger.info("Waiting to begin reduction for all files to be written...")
         time.sleep(cfg.polling_time)
 
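An aside on cost (mine, not from the PR): `pq.ParquetFile` parses only the footer metadata, not row data, so polling with `is_complete_parquet_file` stays cheap even for large shards. For example:

```python
import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"a": [1, 2, 3]})
pq.write_table(table, "shard.parquet")

# Opening the file parses just the footer; row counts are available
# without materializing any column data. Raises if the footer is absent.
pf = pq.ParquetFile("shard.parquet")
print(pf.metadata.num_rows, pf.metadata.num_row_groups)  # 3 1
```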
1 change: 0 additions & 1 deletion src/MEDS_transforms/mapreduce/mapper.py
@@ -621,7 +621,6 @@ def map_over(
             read_fn,
             write_fn,
             compute_fn,
-            do_return=False,
             do_overwrite=cfg.do_overwrite,
         )
         all_out_fps.append(out_fp)
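With `do_return` gone, `rwlock_wrap` reports only whether this worker ran the computation; callers needing the data read the output file themselves. A runnable sketch of the simplified contract (paths and data invented, assuming this commit of the package):

```python
import tempfile
from pathlib import Path

import polars as pl

from MEDS_transforms.mapreduce.utils import rwlock_wrap

root = Path(tempfile.mkdtemp())
in_fp, out_fp = root / "in.csv", root / "out.csv"
pl.DataFrame({"a": [1, 2, 3]}).write_csv(in_fp)

# Returns a plain bool now, not a (bool, DataFrame | None) tuple.
did_compute = rwlock_wrap(
    in_fp,
    out_fp,
    pl.read_csv,                                  # read_fn
    pl.DataFrame.write_csv,                       # write_fn
    lambda df: df.with_columns(pl.col("a") * 2),  # compute_fn
)
assert did_compute
# A repeat call returns False: out.csv already exists and is complete.
assert not rwlock_wrap(in_fp, out_fp, pl.read_csv, pl.DataFrame.write_csv, lambda df: df)
```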
96 changes: 62 additions & 34 deletions src/MEDS_transforms/mapreduce/utils.py
@@ -6,12 +6,46 @@
 from datetime import datetime
 from pathlib import Path
 
+import pyarrow.parquet as pq
 from loguru import logger
 from omegaconf import DictConfig
 
 LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f"
 
 
+def is_complete_parquet_file(fp: Path) -> bool:
+    """Check if a parquet file is complete.
+
+    Args:
+        fp: The file path to the parquet file.
+
+    Returns:
+        True if the parquet file is complete, False otherwise.
+
+    Examples:
+        >>> import tempfile, polars as pl
+        >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
+        >>> with tempfile.NamedTemporaryFile() as tmp:
+        ...     df.write_parquet(tmp)
+        ...     is_complete_parquet_file(tmp)
+        True
+        >>> with tempfile.NamedTemporaryFile() as tmp:
+        ...     df.write_csv(tmp)
+        ...     is_complete_parquet_file(tmp)
+        False
+        >>> with tempfile.TemporaryDirectory() as tmp:
+        ...     tmp = Path(tmp)
+        ...     is_complete_parquet_file(tmp / "nonexistent.parquet")
+        False
+    """
+
+    try:
+        _ = pq.ParquetFile(fp)
+        return True
+    except Exception:
+        return False
+
+
 def get_earliest_lock(cache_directory: Path) -> datetime | None:
     """Returns the earliest start time of any lock file present in a cache directory, or None if none exist.
@@ -83,6 +117,13 @@ def register_lock(cache_directory: Path) -> tuple[datetime, Path]:
     return lock_time, lock_fp
 
 
+def default_file_checker(fp: Path) -> bool:
+    """Check if a file exists and is complete."""
+    if fp.suffix == ".parquet":
+        return is_complete_parquet_file(fp)
+    return fp.is_file()
+
+
 def rwlock_wrap[
     DF_T
 ](
@@ -92,8 +133,8 @@
     write_fn: Callable[[DF_T, Path], None],
     compute_fn: Callable[[DF_T], DF_T],
     do_overwrite: bool = False,
-    do_return: bool = False,
-) -> tuple[bool, DF_T | None]:
+    out_fp_checker: Callable[[Path], bool] = default_file_checker,
+) -> bool:
     """Wrap a series of file-in file-out map transformations on a dataframe with caching and locking.
 
     Args:
@@ -112,11 +153,9 @@
             type DF_T and return a dataframe of (generic) type DF_T.
         do_overwrite: If True, the output file will be overwritten if it already exists. This is `False` by
             default.
-        do_return: If True, the final dataframe will be returned. This is `False` by default.
 
     Returns:
-        The dataframe resulting from the transformations applied in sequence to the dataframe stored in
-        `in_fp`.
+        True if the computation was run, False otherwise.
 
     Examples:
         >>> import polars as pl
@@ -132,7 +171,7 @@
         >>> read_fn = pl.read_csv
         >>> write_fn = pl.DataFrame.write_csv
         >>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4)
-        >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn, do_return=False)
+        >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn)
         >>> assert result_computed
         >>> print(out_fp.read_text())
         a,b,c
@@ -145,49 +184,33 @@
         Traceback (most recent call last):
             ...
         polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"]
-        >>> cache_directory = root / f".output_cache"
+        >>> cache_directory = root / f".output.csv_cache"
         >>> lock_dir = cache_directory / "locks"
         >>> assert not list(lock_dir.iterdir())
         >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame:
         ...     print(f"Lock dir empty? {not (list(lock_dir.iterdir()))}")
         ...     return df
-        >>> result_computed, out_df = rwlock_wrap(
-        ...     in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True
-        ... )
+        >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn)
         Lock dir empty? False
         >>> assert result_computed
-        >>> out_df
-        shape: (3, 3)
-        ┌─────┬─────┬─────┐
-        │ a   ┆ b   ┆ c   │
-        │ --- ┆ --- ┆ --- │
-        │ i64 ┆ i64 ┆ i64 │
-        ╞═════╪═════╪═════╡
-        │ 1   ┆ 2   ┆ 3   │
-        │ 3   ┆ 4   ┆ -1  │
-        │ 3   ┆ 5   ┆ 6   │
-        └─────┴─────┴─────┘
         >>> directory.cleanup()
     """
 
-    if out_fp.is_file():
+    if out_fp_checker(out_fp):
         if do_overwrite:
             logger.info(f"Deleting existing {out_fp} as do_overwrite={do_overwrite}.")
             out_fp.unlink()
         else:
-            logger.info(f"{out_fp} exists; reading directly and returning.")
-            if do_return:
-                return True, read_fn(out_fp)
-            else:
-                return True, None
+            logger.info(f"{out_fp} exists; returning.")
+            return False
 
-    cache_directory = out_fp.parent / f".{out_fp.stem}_cache"
+    cache_directory = out_fp.parent / f".{out_fp.parts[-1]}_cache"
     cache_directory.mkdir(exist_ok=True, parents=True)
 
     earliest_lock_time = get_earliest_lock(cache_directory)
     if earliest_lock_time is not None:
         logger.info(f"{out_fp} is in progress as of {earliest_lock_time}. Returning.")
-        return False, None if do_return else False
+        return False
 
     st_time, lock_fp = register_lock(cache_directory)

@@ -196,12 +219,20 @@
     if earliest_lock_time < st_time:
         logger.info(f"Earlier lock found at {earliest_lock_time}. Deleting current lock and returning.")
         lock_fp.unlink()
-        return False, None if do_return else False
+        return False
 
     logger.info(f"Reading input dataframe from {in_fp}")
     df = read_fn(in_fp)
     logger.info("Read dataset")
 
+    earliest_lock_time = get_earliest_lock(cache_directory)
+    if earliest_lock_time < st_time:
+        logger.info(
+            f"Earlier lock found post read at {earliest_lock_time}. Deleting current lock and returning."
+        )
+        lock_fp.unlink()
+        return False
+
     try:
         df = compute_fn(df)
         logger.info(f"Writing final output to {out_fp}")
@@ -210,10 +241,7 @@
         logger.info(f"Leaving cache directory {cache_directory}, but clearing lock at {lock_fp}")
         lock_fp.unlink()
 
-        if do_return:
-            return True, df
-        else:
-            return True, None
+        return True
 
     except Exception as e:
         logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}")
