Merge pull request #46 from mmcdermott/order-as-in-event-config
Maintain the order of events from event_config while merging cohorts
mmcdermott authored Jul 26, 2024
2 parents 6b75771 + 53fcdd8 commit 5177cdf
Showing 1 changed file with 26 additions and 11 deletions.
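
At its core, the change replaces a recursive glob, whose file order depends on filesystem traversal, with a per-subset collection driven by the event conversion config. A minimal sketch of the idea, using a hypothetical helper name and paths (the real logic lives in `merge_subdirs_and_sort`, shown in the diff below):

from pathlib import Path

def ordered_parquet_files(sp_dir: Path, event_subsets: list[str]) -> list[Path]:
    # Visit each event subset in the order it appears in event_configs.yaml,
    # so downstream concatenation, de-duplication, and sorting can preserve it.
    return [fp for es in event_subsets for fp in (sp_dir / es).glob("*.parquet")]

# e.g. with event_subsets=["admissions", "labs"], files under sp_dir/admissions/
# are always read before files under sp_dir/labs/.
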
37 changes: 26 additions & 11 deletions src/MEDS_transforms/extract/merge_to_MEDS_cohort.py
@@ -5,7 +5,7 @@
 import hydra
 import polars as pl
 from loguru import logger
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 
 from MEDS_transforms.extract import CONFIG_YAML
 from MEDS_transforms.mapreduce.mapper import map_over, shard_iterator
@@ -14,12 +14,17 @@
 
 
 def merge_subdirs_and_sort(
-    sp_dir: Path, unique_by: list[str] | str | None, additional_sort_by: list[str] | None = None
+    sp_dir: Path,
+    event_subsets: list[str],
+    unique_by: list[str] | str | None,
+    additional_sort_by: list[str] | None = None,
 ) -> pl.LazyFrame:
     """This function reads all parquet files in subdirs of `sp_dir` and merges them into a single dataframe.
 
     Args:
         sp_dir: The directory containing the subdirs with parquet files to be merged.
+        event_subsets: The list of event table paths, passed to maintain the order from
+            event_configs.yaml while merging the events.
         unique_by: The list of columns that should be ensured to be unique after the dataframes are merged. If
             `None`, this is ignored. If `*`, all columns are used. If a list of strings, only the columns in
             the list are used. If a column is not found in the dataframe, it is omitted from the unique-by, a
@@ -61,10 +66,10 @@ def merge_subdirs_and_sort(
     ...     })
     >>> with TemporaryDirectory() as tmpdir:
     ...     sp_dir = Path(tmpdir)
-    ...     merge_subdirs_and_sort(sp_dir, unique_by=None)
+    ...     merge_subdirs_and_sort(sp_dir, event_subsets=[], unique_by=None)
     Traceback (most recent call last):
         ...
-    FileNotFoundError: No files found in ...
+    FileNotFoundError: No parquet files found in ...
     >>> with TemporaryDirectory() as tmpdir:
     ...     sp_dir = Path(tmpdir)
     ...     (sp_dir / "subdir1").mkdir()
@@ -74,6 +79,7 @@
     ...     df3.write_parquet(sp_dir / "subdir2" / "df.parquet")
     ...     merge_subdirs_and_sort(
     ...         sp_dir,
+    ...         event_subsets=["subdir1", "subdir2"],
     ...         unique_by=None,
     ...         additional_sort_by=["code", "numerical_value", "missing_col_will_not_error"]
     ...     ).collect()
@@ -101,6 +107,7 @@ def merge_subdirs_and_sort(
     ...     df3.write_parquet(sp_dir / "subdir2" / "df.parquet")
     ...     merge_subdirs_and_sort(
     ...         sp_dir,
+    ...         event_subsets=["subdir1", "subdir2"],
     ...         unique_by="*",
     ...         additional_sort_by=["code", "numerical_value"]
     ...     ).collect()
@@ -130,6 +137,7 @@ def merge_subdirs_and_sort(
     ...     # the unique-by constraint.
     ...     merge_subdirs_and_sort(
     ...         sp_dir,
+    ...         event_subsets=["subdir1", "subdir2"],
     ...         unique_by=["patient_id", "timestamp", "code"],
     ...         additional_sort_by=["code", "numerical_value"]
     ...     ).select("patient_id", "timestamp", "code").collect()
@@ -147,11 +155,15 @@ def merge_subdirs_and_sort(
     │ 3          ┆ 8         ┆ E    │
     └────────────┴───────────┴──────┘
     """
-
-    files_to_read = list(sp_dir.glob("**/*.parquet"))
-
+    files_to_read = [fp for es in event_subsets for fp in (sp_dir / es).glob("*.parquet")]
     if not files_to_read:
-        raise FileNotFoundError(f"No files found in {sp_dir}/**/*.parquet.")
+        raise FileNotFoundError(f"No parquet files found in {sp_dir}/**/*.parquet.")
+
+    if len(dirs_to_read := {fp.parent for fp in files_to_read}) != len(event_subsets):
+        raise RuntimeError(
+            "Number of found subsets ({}) does not match "
+            "number of subsets in event_config ({}): {}".format(len(dirs_to_read), len(event_subsets), sp_dir)
+        )
 
     file_strs = "\n".join(f" - {str(fp.resolve())}" for fp in files_to_read)
     logger.info(f"Reading {len(files_to_read)} files:\n{file_strs}")
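
The new RuntimeError is a fail-fast guard: if a subset named in the event config produced no parquet files, the mismatch is reported rather than silently merging fewer subsets. A standalone sketch of the check, with assumed values:

from pathlib import Path

files_to_read = [Path("shard0/subdir1/df.parquet")]  # assumed: only subdir1 produced output
event_subsets = ["subdir1", "subdir2"]               # but the config names two subsets

# Each subset should contribute exactly one parent directory of found files.
if len(dirs_to_read := {fp.parent for fp in files_to_read}) != len(event_subsets):
    raise RuntimeError(
        f"Number of found subsets ({len(dirs_to_read)}) does not match "
        f"number of subsets in event_config ({len(event_subsets)})"
    )  # raises here, since 1 != 2
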
@@ -165,15 +177,15 @@ def merge_subdirs_and_sort(
         case None:
             pass
         case "*":
-            df = df.unique(maintain_order=False)
+            df = df.unique(maintain_order=True)
         case list() if len(unique_by) > 0 and all(isinstance(u, str) for u in unique_by):
             subset = []
             for u in unique_by:
                 if u in df_columns:
                     subset.append(u)
                 else:
                     logger.warning(f"Column {u} not found in dataframe. Omitting from unique-by subset.")
-            df = df.unique(maintain_order=False, subset=subset)
+            df = df.unique(maintain_order=True, subset=subset)
         case _:
             raise ValueError(f"Invalid unique_by value: {unique_by}")
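
Flipping `maintain_order` to True in `unique` is what carries the newly established config order through de-duplication; with False, polars may return the unique rows in any order. A tiny demo on assumed data:

import polars as pl

df = pl.DataFrame({"code": ["A", "B", "A", "C"], "value": [1, 2, 1, 3]})
# With maintain_order=True the surviving rows keep their original relative
# order (A, B, C); with maintain_order=False that order is not guaranteed.
print(df.unique(maintain_order=True))
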

@@ -185,7 +197,7 @@ def merge_subdirs_and_sort(
         else:
             logger.warning(f"Column {s} not found in dataframe. Omitting from sort-by list.")
 
-    return df.sort(by=sort_by, multithreaded=False)
+    return df.sort(by=sort_by, maintain_order=True, multithreaded=False)
 
 
 @hydra.main(version_base=None, config_path=str(CONFIG_YAML.parent), config_name=CONFIG_YAML.stem)
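
Similarly, `maintain_order=True` makes the final sort stable, so rows that tie on the sort key keep their existing (config-driven) relative order, and `multithreaded=False` keeps the result deterministic. A small demo on assumed data:

import polars as pl

df = pl.DataFrame({
    "patient_id": [1, 1, 1],
    "time": [5, 5, 5],
    "code": ["ADMISSION//A", "LAB//B", "DX//C"],  # assumed codes, in config order
})
# Every row ties on the sort key, so a stable sort leaves the order untouched.
print(df.sort(by=["patient_id", "time"], maintain_order=True, multithreaded=False))
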
@@ -214,9 +226,12 @@ def main(cfg: DictConfig):
     Returns:
         Writes the merged dataframes to the shard-specific output filepath in the `cfg.stage_cfg.output_dir`.
     """
-
+    event_conversion_cfg = OmegaConf.load(cfg.event_conversion_config_fp)
+    event_conversion_cfg.pop("patient_id_col", None)
+
     read_fn = partial(
         merge_subdirs_and_sort,
+        event_subsets=list(event_conversion_cfg.keys()),
         unique_by=cfg.stage_cfg.get("unique_by", None),
         additional_sort_by=cfg.stage_cfg.get("additional_sort_by", None),
     )
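
The subset order itself is taken from the top-level keys of the event conversion YAML, which OmegaConf returns in file order; the special `patient_id_col` entry is dropped first since it names a column rather than an event table. A sketch with a hypothetical config path standing in for `cfg.event_conversion_config_fp`:

from omegaconf import OmegaConf

event_conversion_cfg = OmegaConf.load("event_configs.yaml")  # hypothetical path
event_conversion_cfg.pop("patient_id_col", None)  # a column setting, not an event table
event_subsets = list(event_conversion_cfg.keys())  # e.g. ["admissions", "labs", "vitals"]
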
