
Commit

Merge branch 'dev' into alt_51_raw_values
mmcdermott committed Aug 3, 2024
2 parents 4b3cfec + e97c587 commit 76ce942
Showing 8 changed files with 49 additions and 101 deletions.
3 changes: 0 additions & 3 deletions MIMIC-IV_Example/configs/event_configs.yaml
@@ -276,7 +276,6 @@ icu/inputevents:
description: ["omop_concept_name", "label"] # A list of strings names the columns to be collated
itemid: "itemid (omop_source_code)"
parent_codes: "{omop_vocabulary_id}/{omop_concept_code}"
rateuom: null # A null value means this column is needed when pulling from the metadata.
input_end:
code:
- INFUSION_END
@@ -296,8 +295,6 @@ icu/inputevents:
description: ["omop_concept_name", "label"] # A list of strings names the columns to be collated
itemid: "itemid (omop_source_code)"
parent_codes: "{omop_vocabulary_id}/{omop_concept_code}"
statusdescription: null
amountuom: null
patient_weight:
code:
- PATIENT_WEIGHT_AT_INFUSION
20 changes: 18 additions & 2 deletions MIMIC-IV_Example/joint_script.sh
@@ -51,13 +51,17 @@ MEDS_extract-shard_events \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="shard_events" \
stage_configs.shard_events.infer_schema_length=999999999 \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Splitting patients in serial"
MEDS_extract-split_and_shard_patients \
input_dir="$MIMICIV_PREMEDS_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="split_and_shard_patients" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel"
@@ -68,6 +72,8 @@ MEDS_extract-convert_to_sharded_events \
input_dir="$MIMICIV_PREMEDS_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="convert_to_sharded_events" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel"
@@ -77,7 +83,9 @@ MEDS_extract-merge_to_MEDS_cohort \
hydra/launcher=joblib \
input_dir="$MIMICIV_PREMEDS_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="merge_to_MEDS_cohort"
stage="merge_to_MEDS_cohort" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Aggregating initial code stats with $N_PARALLEL_WORKERS workers in parallel"
@@ -88,7 +96,9 @@ MEDS_transform-aggregate_code_metadata \
hydra/launcher=joblib \
input_dir="$MIMICIV_PREMEDS_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="aggregate_code_metadata"
stage="aggregate_code_metadata" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

# TODO -- make this the pre-meds dir and have the pre-meds script symlink
@@ -97,6 +107,8 @@ MEDS_extract-extract_code_metadata \
input_dir="$MIMICIV_RAW_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="extract_code_metadata" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Finalizing MEDS data with $N_PARALLEL_WORKERS workers in parallel"
@@ -107,11 +119,15 @@ MEDS_extract-finalize_MEDS_data \
input_dir="$MIMICIV_RAW_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="finalize_MEDS_data" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"

echo "Finalizing MEDS metadata in serial."
MEDS_extract-finalize_MEDS_metadata \
input_dir="$MIMICIV_RAW_DIR" \
cohort_dir="$MIMICIV_MEDS_DIR" \
stage="finalize_MEDS_metadata" \
etl_metadata.dataset_name="MIMIC-IV" \
etl_metadata.dataset_version="2.2" \
event_conversion_config_fp=./MIMIC-IV_Example/configs/event_configs.yaml "$@"
@@ -1,3 +1,3 @@
extract_code_metadata:
is_metadata: true
description_seperator: "\n"
description_separator: "\n"
14 changes: 6 additions & 8 deletions src/MEDS_transforms/extract/convert_to_sharded_events.py
@@ -332,7 +332,7 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐
│ patient_id ┆ code ┆ time ┆ categorical_value ┆ text_value │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u8 ┆ str ┆ datetime[μs] ┆ cat ┆ str │
│ u8 ┆ str ┆ datetime[μs] ┆ str ┆ str │
╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡
│ 1 ┆ DISCHARGE//Home ┆ 2021-01-01 11:23:45 ┆ AOx4 ┆ Home │
│ 1 ┆ DISCHARGE//SNF ┆ 2021-01-02 12:34:56 ┆ AO ┆ SNF │
@@ -476,13 +476,11 @@ def extract_event(df: pl.LazyFrame, event_cfg: dict[str, str | None]) -> pl.Lazy
case "text_value" if not df.schema[v] == pl.Utf8:
logger.warning(f"Converting text_value to string for {code_expr}")
col = col.cast(pl.Utf8, strict=False)
case "categorical_value" if not isinstance(df.schema[v], pl.Categorical):
logger.warning(f"Converting categorical_value to categorical for {code_expr}")
col = col.cast(pl.Utf8).cast(pl.Categorical)
case "categorical_value" if not is_str:
logger.warning(f"Converting categorical_value to string for {code_expr}")
col = col.cast(pl.Utf8)
case _ if is_str:
# TODO(mmd): Is this right? Is this always a good idea? It probably usually is, but maybe not
# always. Maybe a check on unique values first?
col = col.cast(pl.Categorical)
pass
case _ if not (is_numeric or is_str or is_cat):
raise ValueError(
f"Source column '{v}' for event column {k} is not numeric, string, or categorical! "
@@ -625,7 +623,7 @@ def convert_to_events(
┌────────────┬───────────┬─────────────────────┬────────────────┬───────────────────────┬────────────────────┬───────────┐
│ patient_id ┆ code ┆ time ┆ admission_type ┆ severity_on_admission ┆ discharge_location ┆ eye_color │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ u8 ┆ str ┆ datetime[μs] ┆ cat ┆ f64 ┆ cat ┆ cat │
│ u8 ┆ str ┆ datetime[μs] ┆ str ┆ f64 ┆ cat ┆ cat │
╞════════════╪═══════════╪═════════════════════╪════════════════╪═══════════════════════╪════════════════════╪═══════════╡
│ 1 ┆ ADMISSION ┆ 2021-01-01 00:00:00 ┆ A ┆ 1.0 ┆ null ┆ null │
│ 1 ┆ ADMISSION ┆ 2021-01-02 00:00:00 ┆ B ┆ 2.0 ┆ null ┆ null │
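For illustration, a minimal polars sketch (hypothetical data, not part of this commit) of the behavior the new case arm implements: categorical_value columns are normalized to plain strings rather than being cast to pl.Categorical.

import polars as pl

# Hypothetical frame whose categorical_value column arrives with a Categorical dtype.
df = pl.DataFrame(
    {"categorical_value": pl.Series(["AOx4", "AO", "AOx3"], dtype=pl.Categorical)}
)

col = pl.col("categorical_value")
if df.schema["categorical_value"] != pl.Utf8:
    # Mirrors the new case arm: store categorical_value as a plain string column.
    col = col.cast(pl.Utf8)

out = df.select(col)
print(out.schema)  # categorical_value is now a String/Utf8 column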
10 changes: 6 additions & 4 deletions src/MEDS_transforms/extract/extract_code_metadata.py
@@ -412,7 +412,8 @@ def reducer_fn(*dfs):

reduced = reducer_fn(*[pl.scan_parquet(fp, glob=False) for fp in all_out_fps])
join_cols = ["code", *cfg.get("code_modifier_cols", [])]
metadata_cols = [c for c in reduced.columns if c not in join_cols]
reduced_cols = reduced.collect_schema().names()
metadata_cols = [c for c in reduced_cols if c not in join_cols]

n_unique_obs = reduced.select(pl.n_unique(*join_cols)).collect().item()
n_rows = reduced.select(pl.count()).collect().item()
@@ -421,11 +422,12 @@ def reducer_fn(*dfs):
if n_unique_obs != n_rows:
aggs = {c: pl.col(c) for c in metadata_cols if c not in MEDS_METADATA_MANDATORY_TYPES}
if "description" in metadata_cols:
aggs["description"] = pl.col("description").list.join(cfg.stage_cfg.description_separator)
separator = cfg.stage_cfg.description_separator
aggs["description"] = pl.col("description").str.concat(separator)
if "parent_codes" in metadata_cols:
aggs["parent_codes"] = pl.col("parent_codes").explode().implode()
aggs["parent_codes"] = pl.col("parent_codes").explode()

reduced = reduced.group_by(join_cols).agg(*(pl.col(c) for c in metadata_cols))
reduced = reduced.group_by(join_cols).agg(**aggs)

reduced = reduced.collect()

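For illustration, a hedged sketch (hypothetical data; assumes a polars version where str.concat and collect_schema are available) of how the reducer now collapses duplicate metadata rows per code with named aggregations:

import polars as pl

# Hypothetical duplicated code-metadata rows; the reducer collapses them to one row per code.
reduced = pl.LazyFrame({
    "code": ["LAB//A", "LAB//A", "LAB//B"],
    "description": ["desc 1", "desc 2", "desc 3"],
    "parent_codes": [["LOINC/1"], ["LOINC/2"], ["LOINC/3"]],
})

join_cols = ["code"]
# collect_schema().names() lists the columns without triggering a schema-resolution warning.
reduced_cols = reduced.collect_schema().names()
metadata_cols = [c for c in reduced_cols if c not in join_cols]

aggs = {
    # Join duplicate descriptions with the configured separator (here "\n").
    "description": pl.col("description").str.concat("\n"),
    # Flatten the per-row lists of parent codes into a single list per code.
    "parent_codes": pl.col("parent_codes").explode(),
}
assert set(aggs) <= set(metadata_cols)

print(reduced.group_by(join_cols).agg(**aggs).sort("code").collect())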
8 changes: 2 additions & 6 deletions src/MEDS_transforms/mapreduce/mapper.py
@@ -15,7 +15,7 @@

DF_T = TypeVar("DF_T")

MAP_FN_T = Callable[[DF_T], DF_T] | tuple[Callable[[DF_T], DF_T]]
MAP_FN_T = Callable[[DF_T], DF_T]
SHARD_GEN_T = Generator[tuple[Path, Path], None, None]
SHARD_ITR_FNTR_T = Callable[[DictConfig], SHARD_GEN_T]

@@ -46,9 +46,6 @@ def map_over(
if compute_fn is None:
compute_fn = identity_fn

if not isinstance(compute_fn, tuple):
compute_fn = (compute_fn,)

process_split = cfg.stage_cfg.get("process_split", None)
split_fp = Path(cfg.input_dir) / "metadata" / "patient_split.parquet"
shards_map_fp = Path(cfg.shards_map_fp) if "shards_map_fp" in cfg else None
@@ -79,9 +76,8 @@ def map_over(
out_fp,
read_fn,
write_fn,
*compute_fn,
compute_fn,
do_return=False,
cache_intermediate=False,
do_overwrite=cfg.do_overwrite,
)
all_out_fps.append(out_fp)
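Since map_over now forwards a single compute_fn rather than a splatted tuple of transforms, callers that previously relied on the tuple form can compose their steps into one callable. A hedged sketch (the compose helper and the two transforms are hypothetical, not part of this commit):

from functools import reduce

import polars as pl

def compose(*fns):
    """Fold several per-shard transforms into the single compute_fn that map_over now expects."""
    return lambda df: reduce(lambda acc, fn: fn(acc), fns, df)

# Hypothetical per-shard transforms that would previously have been passed as a tuple.
double_c = lambda df: df.with_columns(pl.col("c") * 2)
keep_large = lambda df: df.filter(pl.col("c") > 4)

compute_fn = compose(double_c, keep_large)
print(compute_fn(pl.DataFrame({"c": [1, 3, 5]})))  # keeps rows where 2*c > 4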
91 changes: 16 additions & 75 deletions src/MEDS_transforms/mapreduce/utils.py
@@ -2,7 +2,6 @@

import json
import random
import shutil
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
@@ -91,9 +90,7 @@ def rwlock_wrap[
out_fp: Path,
read_fn: Callable[[Path], DF_T],
write_fn: Callable[[DF_T, Path], None],
*transform_fns: Callable[[DF_T], DF_T],
cache_intermediate: bool = True,
clear_cache_on_completion: bool = True,
compute_fn: Callable[[DF_T], DF_T],
do_overwrite: bool = False,
do_return: bool = False,
) -> tuple[bool, DF_T | None]:
@@ -111,21 +108,8 @@
loading to further accelerate unnecessary reads when resuming from intermediate cached steps.
write_fn: Function that writes the dataframe to a file. This must take as input a dataframe of
(generic) type DF_T and a Path object, and will write the dataframe to that file.
transform_fns: A series of functions that transform the dataframe. Each function must take as input
a dataframe of (generic) type DF_T and return a dataframe of (generic) type DF_T. The functions
will be applied in the passed order.
cache_intermediate: If True, intermediate outputs of the transformations will be cached in a hidden
directory in the same parent directory as `out_fp` of the form
`{out_fp.parent}/.{out_fp.stem}_cache`. This can be useful for debugging and resuming from
intermediate steps when nontrivial transformations are composed. Cached files will be named
`step_{i}.output` where `i` is the index of the transformation function in `transform_fns`. **Note
that if you change the order of the transformations, the cache will be no longer valid but the
system will _not_ automatically delete the cache!**. This is `True` by default.
If `do_overwrite=True`, any prior individual cache files that are detected during the run will be
deleted before their corresponding step is run. If `do_overwrite=False` and a cache file exists,
that step of the transformation will be skipped and the cache file will be read directly.
clear_cache_on_completion: If True, the cache directory will be deleted after the final output is
written. This is `True` by default.
compute_fn: A function that transforms the dataframe, which must take as input a dataframe of (generic)
type DF_T and return a dataframe of (generic) type DF_T.
do_overwrite: If True, the output file will be overwritten if it already exists. This is `False` by
default.
do_return: If True, the final dataframe will be returned. This is `False` by default.
@@ -147,51 +131,30 @@
>>> in_df.write_csv(in_fp)
>>> read_fn = pl.read_csv
>>> write_fn = pl.DataFrame.write_csv
>>> transform_fns = [
... lambda df: df.with_columns(pl.col("c") * 2),
... lambda df: df.filter(pl.col("c") > 4)
... ]
>>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns, do_return=False)
>>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4)
>>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn, do_return=False)
>>> assert result_computed
>>> print(out_fp.read_text())
a,b,c
1,2,6
3,5,12
<BLANKLINE>
>>> out_fp.unlink()
>>> cache_directory = root / f".output_cache"
>>> assert not cache_directory.is_dir()
>>> transform_fns = [
... lambda df: df.with_columns(pl.col("c") * 2),
... lambda df: df.filter(pl.col("d") > 4)
... ]
>>> rwlock_wrap(in_fp, out_fp, read_fn, write_fn, *transform_fns)
>>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("d") > 4)
>>> rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn)
Traceback (most recent call last):
...
polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"]
>>> assert cache_directory.is_dir()
>>> cache_fp = cache_directory / "step_0.output"
>>> pl.read_csv(cache_fp)
shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ 1   ┆ 2   ┆ 6   │
│ 3   ┆ 4   ┆ -2  │
│ 3   ┆ 5   ┆ 12  │
└─────┴─────┴─────┘
>>> shutil.rmtree(cache_directory)
>>> cache_directory = root / f".output_cache"
>>> lock_dir = cache_directory / "locks"
>>> assert not lock_dir.exists()
>>> assert not list(lock_dir.iterdir())
>>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame:
... print(f"Lock dir exists? {lock_dir.exists()}")
... print(f"Lock dir empty? {not (list(lock_dir.iterdir()))}")
... return df
>>> result_computed, out_df = rwlock_wrap(
... in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True
... )
Lock dir exists? True
Lock dir empty? False
>>> assert result_computed
>>> out_df
shape: (3, 3)
@@ -240,40 +203,18 @@ def rwlock_wrap[
logger.info("Read dataset")

try:
for i, transform_fn in enumerate(transform_fns):
cache_fp = cache_directory / f"step_{i}.output"

st_time_step = datetime.now()
if cache_fp.is_file():
if do_overwrite:
logger.info(
f"Deleting existing cached output for step {i} " f"as do_overwrite={do_overwrite}"
)
cache_fp.unlink()
else:
logger.info(f"Reading cached output for step {i}")
df = read_fn(cache_fp)
else:
df = transform_fn(df)

if cache_intermediate and i < len(transform_fns) - 1:
logger.info(f"Writing intermediate output for step {i} to {cache_fp}")
write_fn(df, cache_fp)
logger.info(f"Completed step {i} in {datetime.now() - st_time_step}")

df = compute_fn(df)
logger.info(f"Writing final output to {out_fp}")
write_fn(df, out_fp)
logger.info(f"Succeeded in {datetime.now() - st_time}")
if clear_cache_on_completion:
logger.info(f"Clearing cache directory {cache_directory}")
shutil.rmtree(cache_directory)
else:
logger.info(f"Leaving cache directory {cache_directory}, but clearing lock at {lock_fp}")
lock_fp.unlink()
logger.info(f"Leaving cache directory {cache_directory}, but clearing lock at {lock_fp}")
lock_fp.unlink()

if do_return:
return True, df
else:
return True

except Exception as e:
logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}")
lock_fp.unlink()
2 changes: 0 additions & 2 deletions src/MEDS_transforms/transforms/tokenization.py
Original file line number Diff line number Diff line change
@@ -252,7 +252,6 @@ def main(cfg: DictConfig):
write_lazyframe,
extract_statics_and_schema,
do_return=False,
cache_intermediate=False,
do_overwrite=cfg.do_overwrite,
)

@@ -265,7 +264,6 @@ def main(cfg: DictConfig):
write_lazyframe,
extract_seq_of_patient_events,
do_return=False,
cache_intermediate=False,
do_overwrite=cfg.do_overwrite,
)

