From f9d7a0e777769202bc889a6bdfddff2429f68802 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 29 Jul 2024 17:33:25 -0400 Subject: [PATCH 01/53] Added a start at a quantile implementation in the existing MR framework. --- .../aggregate_code_metadata.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 85ac23c..aa405b9 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -62,6 +62,7 @@ class METADATA_FN(StrEnum): the code "values/min": Collects the minimum non-null, non-nan numerical_value value for the code & modifiers "values/max": Collects the maximum non-null, non-nan numerical_value value for the code & modifiers + "values/quantiles": TODO """ CODE_N_PATIENTS = "code/n_patients" @@ -73,6 +74,7 @@ class METADATA_FN(StrEnum): VALUES_SUM_SQD = "values/sum_sqd" VALUES_MIN = "values/min" VALUES_MAX = "values/max" + VALUES_QUANTILES = "values/quantiles" # TODO: Figure out parametrizing number of quantiles class MapReducePair(NamedTuple): @@ -99,6 +101,18 @@ class MapReducePair(NamedTuple): reducer: Callable[[pl.Expr | Sequence[pl.Expr] | cs._selector_proxy_], pl.Expr] +def quantile_reducer(cols: cs._selector_proxy_) -> pl.Expr: + """TODO.""" + + vals = pl.concat_list(cols).explode() + + quantile_keys = [0.25, 0.5, 0.75] + quantile_cols = [f"values/quantile/{key}" for key in quantile_keys] + quantiles = {col: vals.quantile(key).alias(col) for col, key in zip(quantile_cols, quantile_keys)} + + return pl.struct(**quantiles) + + VAL_PRESENT: pl.Expr = pl.col("numerical_value").is_not_null() & pl.col("numerical_value").is_not_nan() IS_INT: pl.Expr = pl.col("numerical_value").round() == pl.col("numerical_value") @@ -126,6 +140,10 @@ class MapReducePair(NamedTuple): METADATA_FN.VALUES_MAX: MapReducePair( pl.col("numerical_value").filter(VAL_PRESENT).max(), pl.max_horizontal ), + METADATA_FN.VALUES_QUANTILES: MapReducePair( + pl.col("numerical_value").filter(VAL_PRESENT), + quantile_reducer, + ), } @@ -163,7 +181,7 @@ def validate_args_and_get_code_cols( ... ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are: code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, - values/sum, values/sum_sqd, values/min, values/max + values/sum, values/sum_sqd, values/min, values/max, values/quantiles >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) >>> validate_args_and_get_code_cols(valid_cfg, 33) Traceback (most recent call last): @@ -365,6 +383,21 @@ def mapper_fntr( │ C ┆ 1 ┆ 81.25 ┆ 5.0 ┆ 7.5 │ │ D ┆ null ┆ 0.0 ┆ null ┆ null │ └──────┴───────────┴────────────────┴────────────┴────────────┘ + >>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]}) + >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) + >>> mapper(df.lazy()).collect().select("code", "modifier1", pl.col("values/quantiles")) + shape: (5, 3) + ┌──────┬───────────┬──────────────────┐ + │ code ┆ modifier1 ┆ values/quantiles │ + │ --- ┆ --- ┆ --- │ + │ cat ┆ i64 ┆ list[f64] │ + ╞══════╪═══════════╪══════════════════╡ + │ A ┆ 1 ┆ [1.1, 1.1] │ + │ A ┆ 2 ┆ [6.0] │ + │ B ┆ 2 ┆ [2.0, 4.0] │ + │ C ┆ 1 ┆ [5.0, 7.5] │ + │ D ┆ null ┆ [] │ + └──────┴───────────┴──────────────────┘ """ code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifier_columns) @@ -426,6 +459,7 @@ def reducer_fntr( ... 
"values/sum_sqd": [21.3, 2.42, 36.0, 84.0, 81.25], ... "values/min": [-1, 0, -1, 2, 2], ... "values/max": [8.0, 1.1, 6.0, 8.0, 7.5], + ... "values/quantiles": [[1.1, 1.1], [6.0], [6.0], [5.0, 7.5], []], ... }) >>> df_2 = pl.DataFrame({ ... "code": pl.Series(["A", "A", "B", "C"], dtype=pl.Categorical), @@ -439,6 +473,7 @@ def reducer_fntr( ... "values/sum_sqd": [0., 103.2, 84.0, 81.25], ... "values/min": [None, -1., 0.2, -2.], ... "values/max": [None, 6.2, 1.0, 1.5], + ... "values/quantiles": [[1.3, -1.1, 2.0], [6.0, 1.2], [3.0, 2.5], [11.1, 12.]], ... }) >>> df_3 = pl.DataFrame({ ... "code": pl.Series(["D"], dtype=pl.Categorical), @@ -452,6 +487,7 @@ def reducer_fntr( ... "values/sum_sqd": [4], ... "values/min": [0], ... "values/max": [2], + ... "values/quantiles": [[]], ... }) >>> code_modifier_columns = ["modifier1"] >>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) @@ -532,6 +568,9 @@ def reducer_fntr( │ C ┆ 2 ┆ 81.25 ┆ 2.0 ┆ 7.5 │ │ D ┆ 1 ┆ 4.0 ┆ 0.0 ┆ 2.0 │ └──────┴───────────┴────────────────┴────────────┴────────────┘ + >>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]}) + >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) + >>> reducer(df_1, df_2, df_3) >>> reducer(df_1.drop("values/min"), df_2, df_3) Traceback (most recent call last): ... From 9c3b497edf96b2bce17d133b6eaca71246c11597 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 30 Jul 2024 23:44:40 -0400 Subject: [PATCH 02/53] Got the reducer to return reasonable numbers at least. --- .../aggregate_code_metadata.py | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index aa405b9..1e30688 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -104,7 +104,7 @@ class MapReducePair(NamedTuple): def quantile_reducer(cols: cs._selector_proxy_) -> pl.Expr: """TODO.""" - vals = pl.concat_list(cols).explode() + vals = pl.concat_list(cols.fill_null([])).explode() quantile_keys = [0.25, 0.5, 0.75] quantile_cols = [f"values/quantile/{key}" for key in quantile_keys] @@ -570,7 +570,21 @@ def reducer_fntr( └──────┴───────────┴────────────────┴────────────┴────────────┘ >>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]}) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) - >>> reducer(df_1, df_2, df_3) + >>> reducer(df_1, df_2, df_3).unnest("values/quantiles") + shape: (7, 5) + ┌──────┬───────────┬──────────────────────┬─────────────────────┬──────────────────────┐ + │ code ┆ modifier1 ┆ values/quantile/0.25 ┆ values/quantile/0.5 ┆ values/quantile/0.75 │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ cat ┆ i64 ┆ f64 ┆ f64 ┆ f64 │ + ╞══════╪═══════════╪══════════════════════╪═════════════════════╪══════════════════════╡ + │ null ┆ null ┆ 1.1 ┆ 1.1 ┆ 1.1 │ + │ A ┆ 1 ┆ 1.3 ┆ 2.0 ┆ 2.0 │ + │ A ┆ 2 ┆ 6.0 ┆ 6.0 ┆ 6.0 │ + │ B ┆ 1 ┆ 3.0 ┆ 5.0 ┆ 5.0 │ + │ C ┆ null ┆ 11.1 ┆ 12.0 ┆ 12.0 │ + │ C ┆ 2 ┆ null ┆ null ┆ null │ + │ D ┆ 1 ┆ null ┆ null ┆ null │ + └──────┴───────────┴──────────────────────┴─────────────────────┴──────────────────────┘ >>> reducer(df_1.drop("values/min"), df_2, df_3) Traceback (most recent call last): ... 
@@ -581,7 +595,8 @@ def reducer_fntr( aggregations = stage_cfg.aggregations agg_operations = { - agg: CODE_METADATA_AGGREGATIONS[agg].reducer(cs.matches(f"{agg}/shard_\\d+")) for agg in aggregations + agg: CODE_METADATA_AGGREGATIONS[agg].reducer(cs.matches(f"{agg}/shard_\\d+")).over(*code_key_columns) + for agg in aggregations } def reducer(*dfs: Sequence[pl.LazyFrame]) -> pl.LazyFrame: From 4b3cfecc79d0b99fab311773732661fb7fb62a26 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 1 Aug 2024 22:40:30 -0400 Subject: [PATCH 03/53] Added doctests and parametrization. --- .../aggregate_code_metadata.py | 98 ++++++++++++++----- 1 file changed, 76 insertions(+), 22 deletions(-) diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 6d5e3c4..cd7c1d9 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -102,16 +102,52 @@ class MapReducePair(NamedTuple): reducer: Callable[[pl.Expr | Sequence[pl.Expr] | cs._selector_proxy_], pl.Expr] -def quantile_reducer(cols: cs._selector_proxy_) -> pl.Expr: - """TODO.""" +def quantile_reducer(cols: cs._selector_proxy_, quantiles: list[float]) -> pl.Expr: + """Calculates the specified quantiles for the combined set of all numerical values in `cols`. + + Args: + cols: A polars selector that selects the column(s) containing the numerical values for which the + quantiles should be calculated. + quantiles: A list of floats specifying the quantiles that should be calculated. + + Returns: + A polars expression that calculates the specified quantiles for the combined set of all numerical + values in `cols`. + + Examples: + >>> df = pl.DataFrame({ + ... "key": [1, 2], + ... "vals/shard1": [[1, 2, float('nan')], [None, 3]], + ... "vals/shard2": [[3.0, 4], [30]], + ... }, strict=False) + >>> expr = quantile_reducer(cs.starts_with("vals/"), [0.01, 0.5, 0.75]) + >>> df.select(expr) + shape: (1, 1) + ┌──────────────────┐ + │ values/quantiles │ + │ --- │ + │ struct[3] │ + ╞══════════════════╡ + │ {1.0,3.0,30.0} │ + └──────────────────┘ + >>> df.select("key", expr.over("key")) + shape: (2, 2) + ┌─────┬──────────────────┐ + │ key ┆ values/quantiles │ + │ --- ┆ --- │ + │ i64 ┆ struct[3] │ + ╞═════╪══════════════════╡ + │ 1 ┆ {1.0,3.0,4.0} │ + │ 2 ┆ {3.0,30.0,30.0} │ + └─────┴──────────────────┘ + """ vals = pl.concat_list(cols.fill_null([])).explode() - quantile_keys = [0.25, 0.5, 0.75] - quantile_cols = [f"values/quantile/{key}" for key in quantile_keys] - quantiles = {col: vals.quantile(key).alias(col) for col, key in zip(quantile_cols, quantile_keys)} + quantile_cols = [f"values/quantile/{q}" for q in quantiles] + quantiles_struct = {col: vals.quantile(q).alias(col) for col, q in zip(quantile_cols, quantiles)} - return pl.struct(**quantiles) + return pl.struct(**quantiles_struct).alias(METADATA_FN.VALUES_QUANTILES) VAL = pl.col("numeric_value") @@ -170,7 +206,7 @@ def validate_args_and_get_code_cols( ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are: code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints, values/sum, values/sum_sqd, values/min, values/max, values/quantiles - >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]}) + >>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", {"name": "values/n_ints"}]}) >>> validate_args_and_get_code_cols(valid_cfg, 33) Traceback (most recent call last): ... 
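With this change an aggregation entry may be either a bare string or a mapping with a "name" key plus extra keyword arguments (as in the parametrized values/quantiles case). A small sketch of that normalization, mirroring the parsing logic added in this patch — the helper name below is illustrative, not part of the package API:

```python
from omegaconf import DictConfig


def split_agg_spec(agg):
    """Return (name, kwargs) for a string or {"name": ..., **kwargs} aggregation spec."""
    if isinstance(agg, (dict, DictConfig)):
        return agg["name"], {k: v for k, v in agg.items() if k != "name"}
    return agg, {}


print(split_agg_spec("values/n_ints"))
print(split_agg_spec({"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}))
```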
@@ -192,6 +228,8 @@ def validate_args_and_get_code_cols( aggregations = stage_cfg.aggregations for agg in aggregations: + if isinstance(agg, (dict, DictConfig)): + agg = agg.get("name", None) if agg not in METADATA_FN: raise ValueError( f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: " @@ -378,7 +416,7 @@ def mapper_fntr( ┌──────┬───────────┬──────────────────┐ │ code ┆ modifier1 ┆ values/quantiles │ │ --- ┆ --- ┆ --- │ - │ cat ┆ i64 ┆ list[f64] │ + │ str ┆ i64 ┆ list[f64] │ ╞══════╪═══════════╪══════════════════╡ │ A ┆ 1 ┆ [1.1, 1.1] │ │ A ┆ 2 ┆ [6.0] │ @@ -391,7 +429,10 @@ def mapper_fntr( code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifier_columns) aggregations = stage_cfg.aggregations - agg_operations = {agg: CODE_METADATA_AGGREGATIONS[agg].mapper for agg in aggregations} + agg_operations = {} + for agg in aggregations: + agg_name = agg if isinstance(agg, str) else agg["name"] + agg_operations[agg_name] = CODE_METADATA_AGGREGATIONS[agg_name].mapper def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame: return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns) @@ -556,14 +597,20 @@ def reducer_fntr( │ C ┆ 2 ┆ 81.25 ┆ 2.0 ┆ 7.5 │ │ D ┆ 1 ┆ 4.0 ┆ 0.0 ┆ 2.0 │ └──────┴───────────┴────────────────┴────────────┴────────────┘ - >>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]}) + >>> reducer(df_1.drop("values/min"), df_2, df_3) + Traceback (most recent call last): + ... + KeyError: 'Column values/min not found in DataFrame 0 for reduction.' + >>> stage_cfg = DictConfig({ + ... "aggregations": [{"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}], + ... }) >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) >>> reducer(df_1, df_2, df_3).unnest("values/quantiles") shape: (7, 5) ┌──────┬───────────┬──────────────────────┬─────────────────────┬──────────────────────┐ │ code ┆ modifier1 ┆ values/quantile/0.25 ┆ values/quantile/0.5 ┆ values/quantile/0.75 │ │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ cat ┆ i64 ┆ f64 ┆ f64 ┆ f64 │ + │ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 │ ╞══════╪═══════════╪══════════════════════╪═════════════════════╪══════════════════════╡ │ null ┆ null ┆ 1.1 ┆ 1.1 ┆ 1.1 │ │ A ┆ 1 ┆ 1.3 ┆ 2.0 ┆ 2.0 │ @@ -573,30 +620,37 @@ def reducer_fntr( │ C ┆ 2 ┆ null ┆ null ┆ null │ │ D ┆ 1 ┆ null ┆ null ┆ null │ └──────┴───────────┴──────────────────────┴─────────────────────┴──────────────────────┘ - >>> reducer(df_1.drop("values/min"), df_2, df_3) - Traceback (most recent call last): - ... - KeyError: 'Column values/min not found in DataFrame 0 for reduction.' 
""" code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifier_columns) aggregations = stage_cfg.aggregations - agg_operations = { - agg: CODE_METADATA_AGGREGATIONS[agg].reducer(cs.matches(f"{agg}/shard_\\d+")).over(*code_key_columns) - for agg in aggregations - } + agg_operations = {} + for agg in aggregations: + if isinstance(agg, (dict, DictConfig)): + agg_name = agg["name"] + agg_kwargs = {k: v for k, v in agg.items() if k != "name"} + else: + agg_name = agg + agg_kwargs = {} + agg_operations[agg_name] = ( + CODE_METADATA_AGGREGATIONS[agg_name] + .reducer(cs.matches(f"{agg_name}/shard_\\d+"), **agg_kwargs) + .over(*code_key_columns) + ) def reducer(*dfs: Sequence[pl.LazyFrame]) -> pl.LazyFrame: renamed_dfs = [] for i, df in enumerate(dfs): + agg_selectors = [] for agg in aggregations: + if isinstance(agg, (dict, DictConfig)): + agg = agg["name"] if agg not in df.columns: raise KeyError(f"Column {agg} not found in DataFrame {i} for reduction.") + agg_selectors.append(pl.col(agg).alias(f"{agg}/shard_{i}")) - renamed_dfs.append( - df.select(*code_key_columns, *[pl.col(agg).alias(f"{agg}/shard_{i}") for agg in aggregations]) - ) + renamed_dfs.append(df.select(*code_key_columns, *agg_selectors)) df = renamed_dfs[0] for rdf in renamed_dfs[1:]: From e72150678d7eea42a8b57cde6d52a8701d6a5d06 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 2 Aug 2024 11:54:32 -0400 Subject: [PATCH 04/53] VERY preliminary code. Just committing so as to not lose changes. Nothing works yet. --- src/MEDS_transforms/mapreduce/mapper.py | 64 +++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 67605d2..033ec23 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Generator from datetime import datetime -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Any, TypeVar @@ -10,6 +10,7 @@ from loguru import logger from omegaconf import DictConfig +from ..extract.parser import is_matcher, matcher_to_expr from ..utils import stage_init, write_lazyframe from .utils import rwlock_wrap, shard_iterator @@ -24,14 +25,67 @@ def identity_fn(df: Any) -> Any: return df -def read_and_filter_fntr(patients: list[int], read_fn: Callable[[Path], DF_T]) -> Callable[[Path], DF_T]: +def read_and_filter_fntr(filter_expr: pl.Expr, read_fn: Callable[[Path], DF_T]) -> Callable[[Path], DF_T]: def read_and_filter(in_fp: Path) -> DF_T: - df = read_fn(in_fp) - return df.filter(pl.col("patient_id").isin(patients)) + return read_fn(in_fp).filter(filter_expr) return read_and_filter +MATCH_REVISE_KEY = "_match_revise" +MATCHER_KEY = "_matcher" + + +def is_match_revise(stage_cfg: DictConfig) -> bool: + return stage_cfg.get(MATCH_REVISE_KEY, False) + + +def validate_match_revise(stage_cfg: DictConfig): + match_revise_options = stage_cfg[MATCH_REVISE_KEY] + if not isinstance(match_revise_options, (list, ListConfig)): + raise ValueError(f"Match revise options must be a list, got {type(match_revise_options)}") + + for match_revise_cfg in match_revise_options: + if not isinstance(match_revise_cfg, (dict, DictConfig)): + raise ValueError(f"Match revise config must be a dict, got {type(match_revise_cfg)}") + + if MATCHER_KEY not in match_revise_cfg: + raise ValueError(f"Match revise config must contain a {MATCHER_KEY} key") + + if not 
is_matcher(match_revise_cfg[MATCHER_KEY]): + raise ValueError(f"Match revise config must contain a valid matcher in {MATCHER_KEY}") + + +def match_revise_fntr(matcher_expr: pl.Expr, compute_fn: Callable[[DF_T], DF_T]) -> Callable[[DF_T], DF_T]: + @wraps(compute_fn) + def match_revise_fn(df: DF_T) -> DF_T: + cols = df.collect_schema().names + idx_col = "_row_idx" + while idx_col in cols: + idx_col = f"_{idx_col}" + + df = df.with_row_index(idx_col) + + matches = df.filter(matcher_expr) + revised = compute_fn(matches) + return compute_fn(df.filter(matcher_expr)) + + return match_revise_fn + + +def get_match_revise_compute_fn(stage_cfg: DictConfig, compute_fn: MAP_FN_T) -> MAP_FN_T: + if not is_match_revise(stage_cfg): + return compute_fn + + validate_match_revise(stage_cfg) + + match_revise_options = stage_cfg[MATCH_REVISE_KEY] + out_compute_fn = [] + for match_revise_cfg in match_revise_options: + matcher = matcher_to_expr(match_revise_cfg[MATCHER_KEY]) + compute_fn = (partial(pl.filter, matcher),) + compute_fn + + def map_over( cfg: DictConfig, compute_fn: MAP_FN_T | None = None, @@ -60,7 +114,7 @@ def map_over( .collect() .to_list() ) - read_fn = read_and_filter_fntr(split_patients, read_fn) + read_fn = read_and_filter_fntr(pl.col("patient_id").isin(split_patients), read_fn) elif process_split and shards_map_fp and shards_map_fp.exists(): logger.warning( f"Split {process_split} requested, but no patient split file found at {str(split_fp)}. " From bb88c1c681b22d16f428d91362a49aa7611a6338 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 3 Aug 2024 22:32:13 -0400 Subject: [PATCH 05/53] Renamed to 'code_modifiers' --- src/MEDS_transforms/aggregate_code_metadata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 3d45fe2..91d62a1 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -406,7 +406,7 @@ def mapper_fntr( │ D ┆ null ┆ 0.0 ┆ null ┆ null │ └──────┴───────────┴────────────────┴────────────┴────────────┘ >>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]}) - >>> mapper = mapper_fntr(stage_cfg, code_modifier_columns) + >>> mapper = mapper_fntr(stage_cfg, code_modifiers) >>> mapper(df.lazy()).collect().select("code", "modifier1", pl.col("values/quantiles")) shape: (5, 3) ┌──────┬───────────┬──────────────────┐ @@ -600,7 +600,7 @@ def reducer_fntr( >>> stage_cfg = DictConfig({ ... "aggregations": [{"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}], ... 
}) - >>> reducer = reducer_fntr(stage_cfg, code_modifier_columns) + >>> reducer = reducer_fntr(stage_cfg, code_modifiers) >>> reducer(df_1, df_2, df_3).unnest("values/quantiles") shape: (7, 5) ┌──────┬───────────┬──────────────────────┬─────────────────────┬──────────────────────┐ From 8e1a78d7f649dc3febd83e440b1faa9137a668b0 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 5 Aug 2024 14:49:29 -0400 Subject: [PATCH 06/53] Very preliminary implementation of match and revise --- src/MEDS_transforms/mapreduce/mapper.py | 135 ++++++++++++++++-------- 1 file changed, 93 insertions(+), 42 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 77f47e5..2b5f649 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -1,5 +1,6 @@ """Basic utilities for parallelizable map operations on sharded MEDS datasets with caching and locking.""" +import copy import inspect from collections.abc import Callable, Generator from datetime import datetime @@ -10,7 +11,7 @@ import polars as pl from loguru import logger -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from ..extract.parser import is_matcher, matcher_to_expr from ..utils import stage_init, write_lazyframe @@ -170,7 +171,31 @@ def validate_match_revise(stage_cfg: DictConfig): raise ValueError(f"Match revise config must contain a valid matcher in {MATCHER_KEY}") -def match_revise_fntr(matcher_expr: pl.Expr, compute_fn: Callable[[DF_T], DF_T]) -> Callable[[DF_T], DF_T]: +def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMPUTE_FN_T) -> COMPUTE_FN_T: + """TODO. + + Args: + + Returns: + + Examples: + >>> raise NotImplementedError("TODO: Add examples") + """ + stage_cfg = copy.deepcopy(stage_cfg) + + if not is_match_revise(stage_cfg): + return bind_compute_fn(cfg, stage_cfg, compute_fn) + + validate_match_revise(stage_cfg) + + matchers_and_fns = [] + for match_revise_cfg in stage_cfg.pop(MATCH_REVISE_KEY): + matcher = matcher_to_expr(match_revise_cfg.pop(MATCHER_KEY)) + local_stage_cfg = DictConfig({**stage_cfg, **match_revise_cfg}) + local_compute_fn = bind_compute_fn(cfg, local_stage_cfg, compute_fn) + + matchers_and_fns.append((matcher, local_compute_fn)) + @wraps(compute_fn) def match_revise_fn(df: DF_T) -> DF_T: cols = df.collect_schema().names @@ -180,24 +205,77 @@ def match_revise_fn(df: DF_T) -> DF_T: df = df.with_row_index(idx_col) - matches = df.filter(matcher_expr) - revised = compute_fn(matches) - return compute_fn(df.filter(matcher_expr)) + unmatched_df = df + + revision_parts = [] + for matcher_expr, local_compute_fn in matchers_and_fns: + matched_df = unmatched_df.filter(matcher_expr).with_columns( + pl.col(idx_col).fill_null("forward").name.keep() + ) + unmatched_df = unmatched_df.filter(~matcher_expr) + + revision_parts = local_compute_fn(matched_df) + + revision_parts.append(unmatched_df) + return ( + pl.concat(revision_parts, how="vertical") + .sort(["patient_id", "time", idx_col], maintain_order=True) + .drop(idx_col) + ) return match_revise_fn -def get_match_revise_compute_fn(stage_cfg: DictConfig, compute_fn: MAP_FN_T) -> MAP_FN_T: - if not is_match_revise(stage_cfg): - return compute_fn +def bind_compute_fn(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMPUTE_FN_T) -> COMPUTE_FN_T: + """Bind the compute function to the appropriate parameters based on the type of the compute function. 
- validate_match_revise(stage_cfg) + Args: + cfg: The DictConfig configuration object. + stage_cfg: The DictConfig stage configuration object. This is separated from the ``cfg`` argument + because in some cases, such as under the match and revise paradigm, the stage config may be + modified dynamically under different matcher conditions to yield different compute functions. + compute_fn: The compute function to bind. - match_revise_options = stage_cfg[MATCH_REVISE_KEY] - out_compute_fn = [] - for match_revise_cfg in match_revise_options: - matcher = matcher_to_expr(match_revise_cfg[MATCHER_KEY]) - compute_fn = (partial(pl.filter, matcher),) + compute_fn + Returns: + The compute function bound to the appropriate parameters. + + Raises: + ValueError: If the compute function is not a valid compute function. + + Examples: + >>> raise NotImplementedError("TODO: Add examples") + """ + + def fntr_params(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnArgs: + compute_fn_params = inspect.signature(compute_fn).parameters + kwargs = ComputeFnArgs() + + if "cfg" in compute_fn_params: + kwargs["cfg"] = cfg + if "stage_cfg" in compute_fn_params: + kwargs["stage_cfg"] = stage_cfg + if "code_modifiers" in compute_fn_params: + code_modifiers = cfg.get("code_modifiers", None) + kwargs["code_modifiers"] = code_modifiers + if "code_metadata" in compute_fn_params: + kwargs["code_metadata"] = pl.read_parquet( + Path(stage_cfg.metadata_input_dir) / "codes.parquet", use_pyarrow=True + ) + return kwargs + + if compute_fn is None: + return identity_fn + match compute_fn_type(compute_fn): + case ComputeFnType.DIRECT: + pass + case ComputeFnType.UNBOUND: + compute_fn = partial(compute_fn, **fntr_params(compute_fn)) + case ComputeFnType.FUNCTOR: + compute_fn = compute_fn(**fntr_params(compute_fn)) + case _: + raise ValueError("Invalid compute function") + + return compute_fn def map_over( @@ -233,34 +311,7 @@ def map_over( f"Split {process_split} requested, but no patient split file found at {str(split_fp)}." 
) - def fntr_params(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnArgs: - compute_fn_params = inspect.signature(compute_fn).parameters - kwargs = ComputeFnArgs() - - if "cfg" in compute_fn_params: - kwargs["cfg"] = cfg - if "stage_cfg" in compute_fn_params: - kwargs["stage_cfg"] = cfg.stage_cfg - if "code_modifiers" in compute_fn_params: - code_modifiers = cfg.get("code_modifiers", None) - kwargs["code_modifiers"] = code_modifiers - if "code_metadata" in compute_fn_params: - kwargs["code_metadata"] = pl.read_parquet( - Path(cfg.stage_cfg.metadata_input_dir) / "codes.parquet", use_pyarrow=True - ) - return kwargs - - if compute_fn is None: - compute_fn = identity_fn - match compute_fn_type(compute_fn): - case ComputeFnType.DIRECT: - pass - case ComputeFnType.UNBOUND: - compute_fn = partial(compute_fn, **fntr_params(compute_fn)) - case ComputeFnType.FUNCTOR: - compute_fn = compute_fn(**fntr_params(compute_fn)) - case _: - raise ValueError("Invalid compute function") + compute_fn = match_revise_fntr(cfg, cfg.stage_cfg, compute_fn) all_out_fps = [] for in_fp, out_fp in shard_iterator_fntr(cfg): From e260088154953b9683ba6f2f1bd1c4c2da3e2322 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 5 Aug 2024 17:24:00 -0400 Subject: [PATCH 07/53] Started adding documentation and tests; not yet complete --- src/MEDS_transforms/mapreduce/mapper.py | 233 +++++++++++++++++++++--- 1 file changed, 205 insertions(+), 28 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 2b5f649..6ce490e 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -137,10 +137,84 @@ def compute_fn_type(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnType | None: def identity_fn(df: Any) -> Any: + """A "null" compute function that returns the input DataFrame as is. + + Args: + df: The input DataFrame. + + Returns: + The input DataFrame. + + Examples: + >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> (identity_fn(df) == df).select(pl.all_horizontal(pl.all().all())).item() + True + """ + return df def read_and_filter_fntr(filter_expr: pl.Expr, read_fn: Callable[[Path], DF_T]) -> Callable[[Path], DF_T]: + """Create a function that reads a DataFrame from a file and filters it based on a given expression. + + This is specified as a functor in this way to allow it to modify arbitrary other read functions for use in + different mapreduce pipelines. + + Args: + filter_expr: The filter expression to apply to the DataFrame. + read_fn: The read function to use to read the DataFrame. + + Returns: + A function that reads a DataFrame from a file and filters it based on the given expression. + + Examples: + >>> dfs = { + ... "df1": pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), + ... "df2": pl.DataFrame({"a": [4, 5, 6], "b": [7, 8, 9]}) + ... 
} + >>> read_fn = lambda key: dfs[key] + >>> fn = read_and_filter_fntr((pl.col("a") % 2) == 0, read_fn) + >>> fn("df1") + shape: (1, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 2 ┆ 5 │ + └─────┴─────┘ + >>> fn("df2") + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 4 ┆ 7 │ + │ 6 ┆ 9 │ + └─────┴─────┘ + >>> fn = read_and_filter_fntr((pl.col("b") % 2) == 0, read_fn) + >>> fn("df1") + shape: (2, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> fn("df2") + shape: (1, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 5 ┆ 8 │ + └─────┴─────┘ + """ + def read_and_filter(in_fp: Path) -> DF_T: return read_fn(in_fp).filter(filter_expr) @@ -152,10 +226,24 @@ def read_and_filter(in_fp: Path) -> DF_T: def is_match_revise(stage_cfg: DictConfig) -> bool: + """Check if the stage configuration is in a match and revise format. + + Examples: + >>> raise NotImplementedError + """ return stage_cfg.get(MATCH_REVISE_KEY, False) def validate_match_revise(stage_cfg: DictConfig): + """Validate that the stage configuration is in a match and revise format. + + Examples: + >>> raise NotImplementedError + """ + + if MATCH_REVISE_KEY not in stage_cfg: + raise ValueError(f"Stage configuration must contain a {MATCH_REVISE_KEY} key") + match_revise_options = stage_cfg[MATCH_REVISE_KEY] if not isinstance(match_revise_options, (list, ListConfig)): raise ValueError(f"Match revise options must be a list, got {type(match_revise_options)}") @@ -172,56 +260,142 @@ def validate_match_revise(stage_cfg: DictConfig): def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMPUTE_FN_T) -> COMPUTE_FN_T: - """TODO. + """A functor that creates a match & revise compute function based on the given configuration. + + Stage configurations for match & revise must be in a match and revise format. Consider the below example, + showing the ``stage_cfg`` object in ``yaml`` format: + + ```yaml + global_arg_1: "foo" + _match_revise: + - _matcher: {code: "CODE//BAR"} + local_arg_1: "bar" + - _matcher: {code: "CODE//BAZ"} + local_arg_1: "baz" + ``` + + This configuration will create a match & revise compute function that will filter the input DataFrame for + rows that match the ``CODE//BAR`` code and apply the compute function with the ``local_arg_1=bar`` + parameter, and then filter the input DataFrame for rows that match the ``CODE//BAZ`` code and apply the + compute function with the ``local_arg_1=baz`` parameter. Both of these local compute functions will be + applied to the input DataFrame in sequence, and the resulting DataFrames will be concatenated alongside + any of the dataframe that matches no matcher (which will be left unmodified) and merged in a sorted way + that respects the ``patient_id``, ``time`` ordering first, then the order of the match & revise blocks + themselves, then the order of the rows in each match & revise block output. Each local compute function + will also use the ``global_arg_1=foo`` parameter. Args: + cfg: The DictConfig configuration object. + stage_cfg: The DictConfig stage configuration object. This stage configuration must be in a match and + revise format, meaning it must have a key ``"_match_revise"`` that contains a list of local match + & revise configurations. 
Each local match & revise configuration must contain a key ``"_matcher"`` + which links to the matcher configuration to use to filter the input DataFrame for the local + compute execution, and all other keys are local configuration parameters to be used in the local + compute execution. + compute_fn: The compute function to bind to the match & revise configuration local arguments. Returns: + A function that applies the match & revise compute function to the input DataFrame. + + Raises: + ValueError: If the stage configuration is not in a match and revise format. Examples: - >>> raise NotImplementedError("TODO: Add examples") + >>> df = pl.DataFrame({ + ... "patient_id": [1, 1, 1, 2, 2, 2], + ... "time": [1, 2, 2, 1, 1, 2], + ... "initial_idx": [0, 1, 2, 3, 4, 5], + ... "code": ["FINAL", "CODE//TEMP_2", "CODE//TEMP_1", "FINAL", "CODE//TEMP_2", "CODE//TEMP_1"] + ... }) + >>> def compute_fn(df: pl.DataFrame, stage_cfg: DictConfig) -> pl.DataFrame: + ... return df.with_columns( + ... pl.col("code").str.slice(0, len("CODE//")) + + ... stage_cfg.local_code_mid + "//" + stage_cfg.global_code_end + ... ) + >>> stage_cfg = DictConfig({ + ... "global_code_end": "foo", + ... "_match_revise": [ + ... {"_matcher": {"code": "CODE//TEMP_1"}, "local_code_mid": "bar"}, + ... {"_matcher": {"code": "CODE//TEMP_2"}, "local_code_mid": "baz"} + ... ] + ... }) + >>> cfg = DictConfig({"stage_cfg": stage_cfg}) + >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) + >>> match_revise_fn(df.lazy()).collect() + shape: (6, 4) + ┌────────────┬──────┬─────────────┬────────────────┐ + │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ str │ + ╞════════════╪══════╪═════════════╪════════════════╡ + │ 1 ┆ 1 ┆ 0 ┆ FINAL │ + │ 1 ┆ 2 ┆ 2 ┆ CODE//bar//foo │ + │ 1 ┆ 2 ┆ 1 ┆ CODE//baz//foo │ + │ 2 ┆ 1 ┆ 4 ┆ CODE//baz//foo │ + │ 2 ┆ 1 ┆ 3 ┆ FINAL │ + │ 2 ┆ 2 ┆ 5 ┆ CODE//bar//foo │ + └────────────┴──────┴─────────────┴────────────────┘ + >>> stage_cfg = DictConfig({ + ... "global_code_end": "foo", + ... "_match_revise": [ + ... {"_matcher": {"code": "CODE//TEMP_2"}, "local_code_mid": "bizz"}, + ... {"_matcher": {"code": "CODE//TEMP_1"}, "local_code_mid": "foo", "global_code_end": "bar"}, + ... ] + ... }) + >>> cfg = DictConfig({"stage_cfg": stage_cfg}) + >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) + >>> match_revise_fn(df.lazy()).collect() + >>> stage_cfg = DictConfig({ + ... "global_code_end": "foo", "_match_revise": [{"_matcher": {"missing": "CODE//TEMP_2"}}] + ... }) + >>> cfg = DictConfig({"stage_cfg": stage_cfg}) + >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) + >>> match_revise_fn(df.lazy()).collect() + Traceback (most recent call last): + ... + ValueError: Missing needed columns {'code'} for local matcher 0: + >>> stage_cfg = DictConfig({"global_code_end": "foo"}) + >>> cfg = DictConfig({"stage_cfg": stage_cfg}) + >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) + >>> match_revise_fn(df.lazy()).collect() + Traceback (most recent call last): + ... + ValueError: Invalid match and revise configuration... 
""" stage_cfg = copy.deepcopy(stage_cfg) - if not is_match_revise(stage_cfg): - return bind_compute_fn(cfg, stage_cfg, compute_fn) - - validate_match_revise(stage_cfg) + try: + validate_match_revise(stage_cfg) + except ValueError as e: + raise ValueError("Invalid match and revise configuration") from e matchers_and_fns = [] for match_revise_cfg in stage_cfg.pop(MATCH_REVISE_KEY): - matcher = matcher_to_expr(match_revise_cfg.pop(MATCHER_KEY)) + matcher, cols = matcher_to_expr(match_revise_cfg.pop(MATCHER_KEY)) local_stage_cfg = DictConfig({**stage_cfg, **match_revise_cfg}) local_compute_fn = bind_compute_fn(cfg, local_stage_cfg, compute_fn) - matchers_and_fns.append((matcher, local_compute_fn)) + matchers_and_fns.append((matcher, cols, local_compute_fn)) @wraps(compute_fn) def match_revise_fn(df: DF_T) -> DF_T: - cols = df.collect_schema().names - idx_col = "_row_idx" - while idx_col in cols: - idx_col = f"_{idx_col}" - - df = df.with_row_index(idx_col) - unmatched_df = df + cols = set(df.collect_schema().names()) revision_parts = [] - for matcher_expr, local_compute_fn in matchers_and_fns: - matched_df = unmatched_df.filter(matcher_expr).with_columns( - pl.col(idx_col).fill_null("forward").name.keep() - ) + for i, (matcher_expr, need_cols, local_compute_fn) in enumerate(matchers_and_fns): + if not need_cols.issubset(cols): + raise ValueError( + f"Missing needed columns {need_cols - cols} for local matcher {i}: " + f"{matcher_expr}\nColumns available: {cols}" + ) + matched_df = unmatched_df.filter(matcher_expr) unmatched_df = unmatched_df.filter(~matcher_expr) - revision_parts = local_compute_fn(matched_df) + revision_parts.append(local_compute_fn(matched_df)) revision_parts.append(unmatched_df) - return ( - pl.concat(revision_parts, how="vertical") - .sort(["patient_id", "time", idx_col], maintain_order=True) - .drop(idx_col) - ) + return pl.concat(revision_parts, how="vertical").sort(["patient_id", "time"], maintain_order=True) return match_revise_fn @@ -232,8 +406,8 @@ def bind_compute_fn(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMP Args: cfg: The DictConfig configuration object. stage_cfg: The DictConfig stage configuration object. This is separated from the ``cfg`` argument - because in some cases, such as under the match and revise paradigm, the stage config may be - modified dynamically under different matcher conditions to yield different compute functions. + because in some cases, such as under the match & revise paradigm, the stage config may be modified + dynamically under different matcher conditions to yield different compute functions. compute_fn: The compute function to bind. Returns: @@ -311,7 +485,10 @@ def map_over( f"Split {process_split} requested, but no patient split file found at {str(split_fp)}." ) - compute_fn = match_revise_fntr(cfg, cfg.stage_cfg, compute_fn) + if is_match_revise(cfg.stage_cfg): + compute_fn = match_revise_fntr(cfg, cfg.stage_cfg, compute_fn) + else: + compute_fn = bind_compute_fn(cfg, cfg.stage_cfg, compute_fn) all_out_fps = [] for in_fp, out_fp in shard_iterator_fntr(cfg): From 8750b933fa40ace7ea32b6c149a5b491b8225817 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 5 Aug 2024 17:35:03 -0400 Subject: [PATCH 08/53] Tests added for everything except bind compute function. 
--- src/MEDS_transforms/mapreduce/mapper.py | 66 ++++++++++++++++++++----- 1 file changed, 55 insertions(+), 11 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 6ce490e..53dd4d5 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -229,16 +229,45 @@ def is_match_revise(stage_cfg: DictConfig) -> bool: """Check if the stage configuration is in a match and revise format. Examples: - >>> raise NotImplementedError + >>> is_match_revise(DictConfig({"_match_revise": []})) + False + >>> is_match_revise(DictConfig({"_match_revise": [{"_matcher": {"code": "CODE//TEMP"}}]})) + True + >>> is_match_revise(DictConfig({"foo": "bar"})) + False """ - return stage_cfg.get(MATCH_REVISE_KEY, False) + return bool(stage_cfg.get(MATCH_REVISE_KEY, False)) def validate_match_revise(stage_cfg: DictConfig): """Validate that the stage configuration is in a match and revise format. Examples: - >>> raise NotImplementedError + >>> validate_match_revise(DictConfig({"foo": []})) + Traceback (most recent call last): + ... + ValueError: Stage configuration must contain a _match_revise key + >>> validate_match_revise(DictConfig({"_match_revise": "foo"})) + Traceback (most recent call last): + ... + ValueError: Match revise options must be a list, got + >>> validate_match_revise(DictConfig({"_match_revise": [1]})) + Traceback (most recent call last): + ... + ValueError: Match revise config 0 must be a dict, got + >>> validate_match_revise(DictConfig({"_match_revise": [{"_matcher": {"foo": "bar"}}, 1]})) + Traceback (most recent call last): + ... + ValueError: Match revise config 1 must be a dict, got + >>> validate_match_revise(DictConfig({"_match_revise": [{"foo": "bar"}]})) + Traceback (most recent call last): + ... + ValueError: Match revise config 0 must contain a _matcher key + >>> validate_match_revise(DictConfig({"_match_revise": [{"_matcher": {32: "bar"}}]})) + Traceback (most recent call last): + ... 
+ ValueError: Match revise config 0 must contain a valid matcher in _matcher + >>> validate_match_revise(DictConfig({"_match_revise": [{"_matcher": {"code": "CODE//TEMP"}}]})) """ if MATCH_REVISE_KEY not in stage_cfg: @@ -248,15 +277,15 @@ def validate_match_revise(stage_cfg: DictConfig): if not isinstance(match_revise_options, (list, ListConfig)): raise ValueError(f"Match revise options must be a list, got {type(match_revise_options)}") - for match_revise_cfg in match_revise_options: + for i, match_revise_cfg in enumerate(match_revise_options): if not isinstance(match_revise_cfg, (dict, DictConfig)): - raise ValueError(f"Match revise config must be a dict, got {type(match_revise_cfg)}") + raise ValueError(f"Match revise config {i} must be a dict, got {type(match_revise_cfg)}") if MATCHER_KEY not in match_revise_cfg: - raise ValueError(f"Match revise config must contain a {MATCHER_KEY} key") + raise ValueError(f"Match revise config {i} must contain a {MATCHER_KEY} key") if not is_matcher(match_revise_cfg[MATCHER_KEY]): - raise ValueError(f"Match revise config must contain a valid matcher in {MATCHER_KEY}") + raise ValueError(f"Match revise config {i} must contain a valid matcher in {MATCHER_KEY}") def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMPUTE_FN_T) -> COMPUTE_FN_T: @@ -345,19 +374,33 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) >>> match_revise_fn(df.lazy()).collect() + shape: (6, 4) + ┌────────────┬──────┬─────────────┬─────────────────┐ + │ patient_id ┆ time ┆ initial_idx ┆ code │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ str │ + ╞════════════╪══════╪═════════════╪═════════════════╡ + │ 1 ┆ 1 ┆ 0 ┆ FINAL │ + │ 1 ┆ 2 ┆ 1 ┆ CODE//bizz//foo │ + │ 1 ┆ 2 ┆ 2 ┆ CODE//foo//bar │ + │ 2 ┆ 1 ┆ 4 ┆ CODE//bizz//foo │ + │ 2 ┆ 1 ┆ 3 ┆ FINAL │ + │ 2 ┆ 2 ┆ 5 ┆ CODE//foo//bar │ + └────────────┴──────┴─────────────┴─────────────────┘ >>> stage_cfg = DictConfig({ ... "global_code_end": "foo", "_match_revise": [{"_matcher": {"missing": "CODE//TEMP_2"}}] ... }) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) - >>> match_revise_fn(df.lazy()).collect() + >>> match_revise_fn(df.lazy()).collect() # doctest: +NORMALIZE_WHITESPACE Traceback (most recent call last): ... - ValueError: Missing needed columns {'code'} for local matcher 0: + ValueError: Missing needed columns {'missing'} for local matcher 0: + [(col("missing")) == (String(CODE//TEMP_2))].all_horizontal() + Columns available: 'code', 'initial_idx', 'patient_id', 'time' >>> stage_cfg = DictConfig({"global_code_end": "foo"}) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) - >>> match_revise_fn(df.lazy()).collect() Traceback (most recent call last): ... ValueError: Invalid match and revise configuration... 
@@ -385,9 +428,10 @@ def match_revise_fn(df: DF_T) -> DF_T: revision_parts = [] for i, (matcher_expr, need_cols, local_compute_fn) in enumerate(matchers_and_fns): if not need_cols.issubset(cols): + cols_str = "', '".join(x for x in sorted(cols)) raise ValueError( f"Missing needed columns {need_cols - cols} for local matcher {i}: " - f"{matcher_expr}\nColumns available: {cols}" + f"{matcher_expr}\nColumns available: '{cols_str}'" ) matched_df = unmatched_df.filter(matcher_expr) unmatched_df = unmatched_df.filter(~matcher_expr) From 8e418ab544780d7065731cfe51fe40d7650697f8 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 5 Aug 2024 22:52:28 -0400 Subject: [PATCH 09/53] Added all doctests. --- src/MEDS_transforms/mapreduce/mapper.py | 74 ++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 53dd4d5..81158c6 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -461,7 +461,79 @@ def bind_compute_fn(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_COMP ValueError: If the compute function is not a valid compute function. Examples: - >>> raise NotImplementedError("TODO: Add examples") + >>> compute_fn = bind_compute_fn(DictConfig({}), DictConfig({}), None) + >>> compute_fn("foobar") + 'foobar' + >>> def compute_fntr(df: pl.DataFrame, cfg: DictConfig) -> pl.DataFrame: + ... return df.with_columns(pl.lit(cfg.val).alias("val")) + >>> compute_fn = bind_compute_fn(DictConfig({"val": "foo"}), None, compute_fntr) + >>> compute_fn(pl.DataFrame({"a": [1, 2, 3]})) + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ val │ + │ --- ┆ --- │ + │ i64 ┆ str │ + ╞═════╪═════╡ + │ 1 ┆ foo │ + │ 2 ┆ foo │ + │ 3 ┆ foo │ + └─────┴─────┘ + >>> def compute_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: + ... return lambda df: df.with_columns(pl.lit(cfg.val).alias("val")) + >>> compute_fn = bind_compute_fn(DictConfig({"val": "foo"}), None, compute_fntr) + >>> compute_fn(pl.DataFrame({"a": [1, 2, 3]})) + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ val │ + │ --- ┆ --- │ + │ i64 ┆ str │ + ╞═════╪═════╡ + │ 1 ┆ foo │ + │ 2 ┆ foo │ + │ 3 ┆ foo │ + └─────┴─────┘ + >>> def compute_fntr(stage_cfg, cfg) -> Callable[[pl.DataFrame], pl.DataFrame]: + ... return lambda df: df.with_columns( + ... pl.lit(stage_cfg.val).alias("stage_val"), pl.lit(cfg.val).alias("cfg_val") + ... ) + >>> compute_fn = bind_compute_fn(DictConfig({"val": "quo"}), DictConfig({"val": "bar"}), compute_fntr) + >>> compute_fn(pl.DataFrame({"a": [1, 2, 3]})) + shape: (3, 3) + ┌─────┬───────────┬─────────┐ + │ a ┆ stage_val ┆ cfg_val │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ str │ + ╞═════╪═══════════╪═════════╡ + │ 1 ┆ bar ┆ quo │ + │ 2 ┆ bar ┆ quo │ + │ 3 ┆ bar ┆ quo │ + └─────┴───────────┴─────────┘ + >>> def compute_fntr(df, code_metadata): + ... return df.join(code_metadata, on="a") + >>> code_metadata_df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> from tempfile import TemporaryDirectory + >>> with TemporaryDirectory() as tmpdir: + ... code_metadata_fp = Path(tmpdir) / "codes.parquet" + ... code_metadata_df.write_parquet(code_metadata_fp) + ... stage_cfg = DictConfig({"metadata_input_dir": tmpdir}) + ... compute_fn = bind_compute_fn(DictConfig({}), stage_cfg, compute_fntr) + ... 
compute_fn(pl.DataFrame({"a": [1, 2, 3]})) + shape: (3, 2) + ┌─────┬─────┐ + │ a ┆ b │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪═════╡ + │ 1 ┆ 4 │ + │ 2 ┆ 5 │ + │ 3 ┆ 6 │ + └─────┴─────┘ + >>> def compute_fntr(df: pl.DataFrame, cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: + ... return lambda df: df + >>> bind_compute_fn(DictConfig({}), DictConfig({}), compute_fntr) + Traceback (most recent call last): + ... + ValueError: Invalid compute function """ def fntr_params(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnArgs: From f19ecc7d65790521e46cdfe9a6f28f1dee46b8fb Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Mon, 5 Aug 2024 23:24:15 -0400 Subject: [PATCH 10/53] Added an integration test via the filter_measurement transform; also an example of loading MEDS_transforms configs as defaults dynamically. --- src/MEDS_transforms/mapreduce/mapper.py | 6 +- tests/test_filter_measurements.py | 111 ++++++++++++++++++++++++ tests/transform_tester_base.py | 16 ++-- tests/utils.py | 51 +++++++++-- 4 files changed, 169 insertions(+), 15 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 81158c6..29039fa 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -1,6 +1,5 @@ """Basic utilities for parallelizable map operations on sharded MEDS datasets with caching and locking.""" -import copy import inspect from collections.abc import Callable, Generator from datetime import datetime @@ -9,6 +8,7 @@ from pathlib import Path from typing import Any, NotRequired, TypedDict, TypeVar +import hydra import polars as pl from loguru import logger from omegaconf import DictConfig, ListConfig @@ -405,13 +405,13 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO ... ValueError: Invalid match and revise configuration... """ - stage_cfg = copy.deepcopy(stage_cfg) - try: validate_match_revise(stage_cfg) except ValueError as e: raise ValueError("Invalid match and revise configuration") from e + stage_cfg = hydra.utils.instantiate(stage_cfg) + matchers_and_fns = [] for match_revise_cfg in stage_cfg.pop(MATCH_REVISE_KEY): matcher, cols = matcher_to_expr(match_revise_cfg.pop(MATCHER_KEY)) diff --git a/tests/test_filter_measurements.py b/tests/test_filter_measurements.py index 8e0ff7e..8fc901e 100644 --- a/tests/test_filter_measurements.py +++ b/tests/test_filter_measurements.py @@ -118,3 +118,114 @@ def test_filter_measurements(): transform_stage_kwargs={"min_patients_per_code": 2}, want_outputs=WANT_SHARDS, ) + + +# This is the code metadata +# MEDS_CODE_METADATA_CSV = """ +# code,code/n_occurrences,code/n_patients,values/n_occurrences,values/sum,values/sum_sqd,description,parent_code +# ,44,4,28,3198.8389005974336,382968.28937288234,, +# ADMISSION//CARDIAC,2,2,0,,,, +# ADMISSION//ORTHOPEDIC,1,1,0,,,, +# ADMISSION//PULMONARY,1,1,0,,,, +# DISCHARGE,4,4,0,,,, +# DOB,4,4,0,,,, +# EYE_COLOR//BLUE,1,1,0,,,"Blue Eyes. Less common than brown.", +# EYE_COLOR//BROWN,1,1,0,,,"Brown Eyes. The most common eye color.", +# EYE_COLOR//HAZEL,2,2,0,,,"Hazel eyes. 
These are uncommon", +# HEIGHT,4,4,4,656.8389005974336,108056.12937288235,, +# HR,12,4,12,1360.5000000000002,158538.77,"Heart Rate",LOINC/8867-4 +# TEMP,12,4,12,1181.4999999999998,116373.38999999998,"Body Temperature",LOINC/8310-5 +# """ +# +# In the test that applies to the match and revise framework, we'll filter codes in the following manner: +# - Codes that start with ADMISSION// will be filtered to occur at least 2 times, which are: +# ADMISSION//CARDIAC +# - Codes in [HR] will be filtered to occur at least 15 times, which are: +# (no codes) +# - Codes that start with EYE_COLOR// will be filtered to occur at least 4 times, which are: +# (no codes) +# - Other codes won't be filtered, so we will retain HEIGHT, DISCHARGE, DOB, TEMP + +MR_WANT_TRAIN_0 = """ +patient_id,time,code,numeric_value +239684,,HEIGHT,175.271115221764 +239684,"12/28/1980, 00:00:00",DOB, +239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, +239684,"05/11/2010, 17:41:51",TEMP,96.0 +239684,"05/11/2010, 17:48:48",TEMP,96.2 +239684,"05/11/2010, 18:25:35",TEMP,95.8 +239684,"05/11/2010, 18:57:18",TEMP,95.5 +239684,"05/11/2010, 19:27:19",DISCHARGE, +1195293,,HEIGHT,164.6868838269085 +1195293,"06/20/1978, 00:00:00",DOB, +1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, +1195293,"06/20/2010, 19:23:52",TEMP,100.0 +1195293,"06/20/2010, 19:25:32",TEMP,100.0 +1195293,"06/20/2010, 19:45:19",TEMP,99.9 +1195293,"06/20/2010, 20:12:31",TEMP,99.8 +1195293,"06/20/2010, 20:24:44",TEMP,100.0 +1195293,"06/20/2010, 20:41:33",TEMP,100.4 +1195293,"06/20/2010, 20:50:04",DISCHARGE, +""" + +MR_WANT_TRAIN_1 = """ +patient_id,time,code,numeric_value +68729,,HEIGHT,160.3953106166676 +68729,"03/09/1978, 00:00:00",DOB, +68729,"05/26/2010, 02:30:56",TEMP,97.8 +68729,"05/26/2010, 04:51:52",DISCHARGE, +814703,,HEIGHT,156.48559093209357 +814703,"03/28/1976, 00:00:00",DOB, +814703,"02/05/2010, 05:55:39",TEMP,100.1 +814703,"02/05/2010, 07:02:30",DISCHARGE, +""" + +MR_WANT_TUNING_0 = """ +patient_id,time,code,numeric_value +754281,,HEIGHT,166.22261567137025 +754281,"12/19/1988, 00:00:00",DOB, +754281,"01/03/2010, 06:27:59",TEMP,99.8 +754281,"01/03/2010, 08:22:13",DISCHARGE, +""" + +MR_WANT_HELD_OUT_0 = """ +patient_id,time,code,numeric_value +1500733,,HEIGHT,158.60131573580904 +1500733,"07/20/1986, 00:00:00",DOB, +1500733,"06/03/2010, 14:54:38",TEMP,100.0 +1500733,"06/03/2010, 15:39:49",TEMP,100.3 +1500733,"06/03/2010, 16:20:49",TEMP,100.1 +1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" + +MR_WANT_SHARDS = parse_meds_csvs( + { + "train/0": MR_WANT_TRAIN_0, + "train/1": MR_WANT_TRAIN_1, + "tuning/0": MR_WANT_TUNING_0, + "held_out/0": MR_WANT_HELD_OUT_0, + } +) + +MATCH_REVISE_KEY = "_match_revise" +MATCHER_KEY = "_matcher" + + +def test_match_revise_filter_measurements(): + single_stage_transform_tester( + transform_script=FILTER_MEASUREMENTS_SCRIPT, + stage_name="filter_measurements", + transform_stage_kwargs={ + "_match_revise": [ + {"_matcher": {"code": "ADMISSION//CARDIAC"}, "min_patients_per_code": 2}, + {"_matcher": {"code": "ADMISSION//ORTHOPEDIC"}, "min_patients_per_code": 2}, + {"_matcher": {"code": "ADMISSION//PULMONARY"}, "min_patients_per_code": 2}, + {"_matcher": {"code": "HR"}, "min_patients_per_code": 15}, + {"_matcher": {"code": "EYE_COLOR//BLUE"}, "min_patients_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//BROWN"}, "min_patients_per_code": 4}, + {"_matcher": {"code": "EYE_COLOR//HAZEL"}, "min_patients_per_code": 4}, + ], + }, + want_outputs=MR_WANT_SHARDS, + do_use_config_yaml=True, + ) diff --git a/tests/transform_tester_base.py 
b/tests/transform_tester_base.py index ce2b8c3..80e4f4b 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -291,6 +291,7 @@ def single_stage_transform_tester( input_shards: dict[str, pl.DataFrame] | None = None, do_pass_stage_name: bool = False, file_suffix: str = ".parquet", + do_use_config_yaml: bool = False, ): with tempfile.TemporaryDirectory() as d: MEDS_dir = Path(d) / "MEDS_cohort" @@ -337,12 +338,17 @@ def single_stage_transform_tester( if transform_stage_kwargs: pipeline_config_kwargs["stage_configs"] = {stage_name: transform_stage_kwargs} + run_command_kwargs = { + "script": transform_script, + "hydra_kwargs": pipeline_config_kwargs, + "test_name": f"Single stage transform: {stage_name}", + } + if do_use_config_yaml: + run_command_kwargs["do_use_config_yaml"] = True + run_command_kwargs["config_name"] = "preprocess" + # Run the transform - stderr, stdout = run_command( - transform_script, - pipeline_config_kwargs, - f"Single stage transform: {stage_name}", - ) + stderr, stdout = run_command(**run_command_kwargs) # Check the output if isinstance(want_outputs, pl.DataFrame): diff --git a/tests/utils.py b/tests/utils.py index 2f383f0..d6ac438 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,10 @@ import subprocess +import tempfile from io import StringIO from pathlib import Path import polars as pl +from omegaconf import OmegaConf from polars.testing import assert_frame_equal DEFAULT_CSV_TS_FORMAT = "%m/%d/%Y, %H:%M:%S" @@ -102,31 +104,66 @@ def run_command( test_name: str, config_name: str | None = None, should_error: bool = False, + do_use_config_yaml: bool = False, ): script = ["python", str(script.resolve())] if isinstance(script, Path) else [script] command_parts = script - if config_name is not None: - command_parts.append(f"--config-name={config_name}") - command_parts.append(" ".join(dict_to_hydra_kwargs(hydra_kwargs))) + + err_cmd_lines = [] + + if do_use_config_yaml: + if config_name is None: + raise ValueError("config_name must be provided if do_use_config_yaml is True.") + + conf = OmegaConf.create( + { + "defaults": [config_name], + **hydra_kwargs, + } + ) + + conf_dir = tempfile.TemporaryDirectory() + conf_path = Path(conf_dir.name) / "config.yaml" + OmegaConf.save(conf, conf_path) + + command_parts.extend( + [ + f"--config-path={str(conf_path.parent.resolve())}", + "--config-name=config", + "'hydra.searchpath=[pkg://MEDS_transforms.configs]'", + ] + ) + err_cmd_lines.append(f"Using config yaml:\n{OmegaConf.to_yaml(conf)}") + else: + if config_name is not None: + command_parts.append(f"--config-name={config_name}") + command_parts.append(" ".join(dict_to_hydra_kwargs(hydra_kwargs))) full_cmd = " ".join(command_parts) + err_cmd_lines.append(f"Running command: {full_cmd}") command_out = subprocess.run(full_cmd, shell=True, capture_output=True) command_errored = command_out.returncode != 0 stderr = command_out.stderr.decode() + err_cmd_lines.append(f"stderr:\n{stderr}") stdout = command_out.stdout.decode() + err_cmd_lines.append(f"stdout:\n{stdout}") if should_error and not command_errored: + if do_use_config_yaml: + conf_dir.cleanup() raise AssertionError( - f"{test_name} failed as command did not error when expected!\n" - f"command:{full_cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}" + f"{test_name} failed as command did not error when expected!\n" + "\n".join(err_cmd_lines) ) elif not should_error and command_errored: + if do_use_config_yaml: + conf_dir.cleanup() raise AssertionError( - f"{test_name} failed as command errored 
when not expected!" - f"\ncommand:{full_cmd}\nstdout:\n{stdout}\nstderr:\n{stderr}" + f"{test_name} failed as command errored when not expected!\n" + "\n".join(err_cmd_lines) ) + if do_use_config_yaml: + conf_dir.cleanup() return stderr, stdout From 39aebbc674c17df5428b32a4655af210eddb876d Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 6 Aug 2024 16:58:40 -0400 Subject: [PATCH 11/53] Attempted to update test workflow, add codecov badge, and add a workflow for pushing to pypi. The pypi push workflow requires local tagging via 'git tag 0.0.3' then 'git push origin 0.0.3' --- .github/workflows/python-build.yaml | 95 +++++++++++++++++++++++++++++ .github/workflows/tests.yaml | 7 ++- README.md | 2 + 3 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/python-build.yaml diff --git a/.github/workflows/python-build.yaml b/.github/workflows/python-build.yaml new file mode 100644 index 0000000..19c21e5 --- /dev/null +++ b/.github/workflows/python-build.yaml @@ -0,0 +1,95 @@ +name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI + +on: push + +jobs: + build: + name: Build distribution 📦 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/MEDS-transforms # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v2.1.1 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. 
+ run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 908adc5..34443b6 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,9 +33,14 @@ jobs: #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov=src -s + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 with: token: ${{ secrets.CODECOV_TOKEN }} + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index 9838043..35cad6d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # MEDS Transforms +[![codecov](https://codecov.io/gh/mmcdermott/MEDS_transforms/graph/badge.svg?token=5RORKQOZF9)](https://codecov.io/gh/mmcdermott/MEDS_transforms) + This repository contains a set of functions and scripts for extraction to and transformation/pre-processing of MEDS-formatted data. From c9c581f617063c585efb40335f0af75160ca7b95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Wed, 7 Aug 2024 20:23:19 -0400 Subject: [PATCH 12/53] Fix split_and_shard_patients when the full split definition is provided --- .../extract/split_and_shard_patients.py | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 3a4e1d5..0cdb38e 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -99,31 +99,36 @@ def shard_patients[ is_in_external_split = np.isin(patients, list(all_external_splits)) patient_ids_to_split = patients[~is_in_external_split] - n_patients = len(patient_ids_to_split) - - rng = np.random.default_rng(seed) - split_names_idx = rng.permutation(len(split_fracs_dict)) - split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] - split_fracs = np.array([split_fracs_dict[k] for k in split_names]) - split_lens = np.round(split_fracs[:-1] * n_patients).astype(int) - split_lens = np.append(split_lens, n_patients - split_lens.sum()) - - if split_lens.min() == 0: - logger.warning( - "Some splits are empty. Adjusting splits to ensure all splits have at least 1 patient." - ) - max_split = split_lens.argmax() - split_lens[max_split] -= 1 - split_lens[split_lens.argmin()] += 1 + splits = external_splits + + if n_patients := len(patient_ids_to_split): + rng = np.random.default_rng(seed) + split_names_idx = rng.permutation(len(split_fracs_dict)) + split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] + split_fracs = np.array([split_fracs_dict[k] for k in split_names]) + split_lens = np.round(split_fracs[:-1] * n_patients).astype(int) + split_lens = np.append(split_lens, n_patients - split_lens.sum()) + + if split_lens.min() == 0: + logger.warning( + "Some splits are empty. Adjusting splits to ensure all splits have at least 1 patient." 
+ ) + max_split = split_lens.argmax() + split_lens[max_split] -= 1 + split_lens[split_lens.argmin()] += 1 - if split_lens.min() == 0: - raise ValueError("Unable to adjust splits to ensure all splits have at least 1 patient.") + if split_lens.min() == 0: + raise ValueError("Unable to adjust splits to ensure all splits have at least 1 patient.") - patients = rng.permutation(patient_ids_to_split) - patients_per_split = np.split(patients, split_lens.cumsum()) + patients = rng.permutation(patient_ids_to_split) + patients_per_split = np.split(patients, split_lens.cumsum()) - splits = {k: v for k, v in zip(split_names, patients_per_split)} - splits = {**splits, **external_splits} + splits = {**splits, **{k: v for k, v in zip(split_names, patients_per_split)}} + else: + logger.info( + "The external split definition covered all patients. No need to perform an " + "additional patient split." + ) # Sharding final_shards = {} @@ -231,7 +236,7 @@ def main(cfg: DictConfig): if not external_splits_json_fp.exists(): raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}") - logger.info(f"Reading external splits from {cfg.external_splits_json_fp}") + logger.info(f"Reading external splits from {cfg.stage_cfg.external_splits_json_fp}") external_splits = json.loads(external_splits_json_fp.read_text()) size_strs = ", ".join(f"{k}: {len(v)}" for k, v in external_splits.items()) From 0b924b3bbfe830359b0fa6b5c02ae0a0ccbd88f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Wed, 7 Aug 2024 20:30:31 -0400 Subject: [PATCH 13/53] External splits go after internal splits --- src/MEDS_transforms/extract/split_and_shard_patients.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 0cdb38e..f9c2ee3 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -123,7 +123,7 @@ def shard_patients[ patients = rng.permutation(patient_ids_to_split) patients_per_split = np.split(patients, split_lens.cumsum()) - splits = {**splits, **{k: v for k, v in zip(split_names, patients_per_split)}} + splits = {**{k: v for k, v in zip(split_names, patients_per_split)}, **splits} else: logger.info( "The external split definition covered all patients. No need to perform an " From 3c69775a5dca43612eec639a7d1cecee05939473 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 13:41:34 -0400 Subject: [PATCH 14/53] Added mkdocs documentation starter code. 
--- docs/gen_ref_pages.py | 35 +++++++++++++++++++++++++++++++++++ docs/index.md | 3 +++ docs/javascripts/mathjax.js | 19 +++++++++++++++++++ mkdocs.yml | 32 ++++++++++++++++++++++++++++++++ pyproject.toml | 4 ++++ 5 files changed, 93 insertions(+) create mode 100644 docs/gen_ref_pages.py create mode 100644 docs/index.md create mode 100644 docs/javascripts/mathjax.js create mode 100644 mkdocs.yml diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py new file mode 100644 index 0000000..d38acd7 --- /dev/null +++ b/docs/gen_ref_pages.py @@ -0,0 +1,35 @@ +"""Generate the code reference pages.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent +src = root / "src" + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") + doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = "api" / doc_path + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1] == "__main__": + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("api/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..ee359f4 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,3 @@ +# Welcome! + +MEDS-Transforms is a library. More details incoming. diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 0000000..7e48906 --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,19 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.startup.output.clearCache() + MathJax.typesetClear() + MathJax.texReset() + MathJax.typesetPromise() +}) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..a49156d --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,32 @@ +site_name: MEDS-Transforms +repo_url: https://github.com/mmcdermott/MEDS_transforms +site_description: Documentation for the MEDS Transforms package +site_author: Matthew McDermott + +nav: + - Home: index.md + - API: api/ + - Issues: https://github.com/mmcdermott/MEDS_transforms/issues + +theme: + name: material + locale: en + +markdown_extensions: + - smarty + # - pymdownx.arithmatex: + # generic: true + +extra_javascript: + - javascripts/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +plugins: + - search + - gen-files: + scripts: + - docs/gen_ref_pages.py + - literate-nav: + nav_file: SUMMARY.md + - section-index + - mkdocstrings diff --git a/pyproject.toml b/pyproject.toml index e512e7e..55c1583 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,10 @@ dev = ["pre-commit"] tests = ["pytest", "pytest-cov", "rootutils"] local_parallelism = ["hydra-joblib-launcher"] slurm_parallelism = ["hydra-submitit-launcher"] +docs = [ + "mkdocs==1.6.0", "mkdocs-material==9.5.31", "mkdocstrings[python,shell]==0.25.2", "mkdocs-gen-files==0.5.0", + "mkdocs-literate-nav==0.6.1", "mkdocs-section-index==0.3.9" +] [project.scripts] 
# MEDS_extract From ebabad7b6ec816a52271a7de163462ca8f3d5919 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 13:55:08 -0400 Subject: [PATCH 15/53] Added readthedocs file. --- .readthedocs.yaml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .readthedocs.yaml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..b51491f --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,21 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + +python: + install: + - method: pip + path: . + extra_requirements: + - docs + +mkdocs: + configuration: mkdocs.yml From 38403852a3289ec246f3d27696a4789984399b4e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 13:56:08 -0400 Subject: [PATCH 16/53] Added docs badge. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 35cad6d..57c5c9f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # MEDS Transforms [![codecov](https://codecov.io/gh/mmcdermott/MEDS_transforms/graph/badge.svg?token=5RORKQOZF9)](https://codecov.io/gh/mmcdermott/MEDS_transforms) +[![Documentation Status](https://readthedocs.org/projects/meds-transforms/badge/?version=latest)](https://meds-transforms.readthedocs.io/en/latest/?badge=latest) This repository contains a set of functions and scripts for extraction to and transformation/pre-processing of MEDS-formatted data. From 807596315e9646f613e48cf4943e84fe3df8b026 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 14:02:05 -0400 Subject: [PATCH 17/53] Updated tests workflow to ignore docs. --- .github/workflows/tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 34443b6..90c9ed7 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -33,7 +33,7 @@ jobs: #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s + pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 From 1dedfad63c474423ab71ac452c7403555725ed33 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 14:19:45 -0400 Subject: [PATCH 18/53] Updated docs to move basic docs into readthedocs. --- docs/index.md | 4 +--- .../pipeline_configuration.md | 0 .../preprocessing_operation_prototypes.md | 0 terminology.md => docs/terminology.md | 0 mkdocs.yml | 15 +++++++++++++-- 5 files changed, 14 insertions(+), 5 deletions(-) rename pipeline_configuration.md => docs/pipeline_configuration.md (100%) rename preprocessing_operation_prototypes.md => docs/preprocessing_operation_prototypes.md (100%) rename terminology.md => docs/terminology.md (100%) diff --git a/docs/index.md b/docs/index.md index ee359f4..ed62cb0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1 @@ -# Welcome! - -MEDS-Transforms is a library. More details incoming. 
+--8\<-- "README.md" diff --git a/pipeline_configuration.md b/docs/pipeline_configuration.md similarity index 100% rename from pipeline_configuration.md rename to docs/pipeline_configuration.md diff --git a/preprocessing_operation_prototypes.md b/docs/preprocessing_operation_prototypes.md similarity index 100% rename from preprocessing_operation_prototypes.md rename to docs/preprocessing_operation_prototypes.md diff --git a/terminology.md b/docs/terminology.md similarity index 100% rename from terminology.md rename to docs/terminology.md diff --git a/mkdocs.yml b/mkdocs.yml index a49156d..02bd3d6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,6 +5,9 @@ site_author: Matthew McDermott nav: - Home: index.md + - "Pipeline Configuration": pipeline_configuration.md + - "Pre-processing Operations": preprocessing_operation_prototypes.md + - "Terminology": terminology.md - API: api/ - Issues: https://github.com/mmcdermott/MEDS_transforms/issues @@ -14,8 +17,16 @@ theme: markdown_extensions: - smarty - # - pymdownx.arithmatex: - # generic: true + - pymdownx.arithmatex: + generic: true + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.smartsymbols + - pymdownx.snippets + - pymdownx.tabbed: + alternate_style: true + - pymdownx.superfences extra_javascript: - javascripts/mathjax.js From 9c782137736abf07f08e7b57d4b3f955eaa3489b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 14:34:45 -0400 Subject: [PATCH 19/53] Correct mdformat issue --- .pre-commit-config.yaml | 6 ++++-- docs/index.md | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2160db6..3036cc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ default_language_version: python: python3.12 -exclude: "sample_data|docs/MIMIC_IV_tutorial/wandb_reports" +exclude: "docs/index.md" repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -95,10 +95,12 @@ repos: - mdformat-gfm - mdformat-tables - mdformat_frontmatter - - mdformat-myst - mdformat-black - mdformat-config - mdformat-shfmt + - mdformat-mkdocs + - mdformat-toc + - mdformat-admon # word spelling linter - repo: https://github.com/codespell-project/codespell diff --git a/docs/index.md b/docs/index.md index ed62cb0..612c7a5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1 +1 @@ ---8\<-- "README.md" +--8<-- "README.md" From e0d9fea2b88d9d52d2158899b3fa45e5eac73946 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 14:39:43 -0400 Subject: [PATCH 20/53] linted files. --- .pre-commit-config.yaml | 2 +- MIMIC-IV_Example/README.md | 32 +-- README.md | 254 ++++++++++----------- docs/pipeline_configuration.md | 100 ++++---- docs/preprocessing_operation_prototypes.md | 36 +-- eICU_Example/README.md | 46 ++-- src/MEDS_transforms/extract/README.md | 72 +++--- 7 files changed, 271 insertions(+), 271 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3036cc2..b188f48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ default_language_version: python: python3.12 -exclude: "docs/index.md" +exclude: "docs/index.md|MIMIC-IV_Example/README.md|eICU_Example/README.md" repos: - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index ca61588..45630b3 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -72,14 +72,14 @@ several steps: This is a step in a few parts: 1. 
Join a few tables by `hadm_id` to get the right times in the right rows for processing. In - particular, we need to join: - - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each - `hadm_id`. - - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. + particular, we need to join: + - the `hosp/diagnoses_icd` table with the `hosp/admissions` table to get the `dischtime` for each + `hadm_id`. + - the `hosp/drgcodes` table with the `hosp/admissions` table to get the `dischtime` for each `hadm_id`. 2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and - `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and + `anchor_offset` fields. + - Merge the patient's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$MIMICIV_PREMEDS_DIR` to denote this @@ -104,24 +104,24 @@ subdirectories of the same root directory). This is a step in 4 parts: 1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers - performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. + performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. - This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected - format of the command. + This step uses the `./scripts/extraction/shard_events.py` script. See `joint_script*.sh` for the expected + format of the command. 2. Extract and form the patient splits and sub-shards. The `./scripts/extraction/split_and_shard_patients.py` - script is used for this step. See `joint_script*.sh` for the expected format of the command. + script is used for this step. See `joint_script*.sh` for the expected format of the command. 3. Extract patient sub-shards and convert to MEDS events. The - `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for - the expected format of the command. + `./scripts/extraction/convert_to_sharded_events.py` script is used for this step. See `joint_script*.sh` for + the expected format of the command. 4. Merge the MEDS events into a single file per patient sub-shard. The - `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the - expected format of the command. + `./scripts/extraction/merge_to_MEDS_cohort.py` script is used for this step. See `joint_script*.sh` for the + expected format of the command. 5. (Optional) Generate preliminary code statistics and merge to external metadata. This is not performed - currently in the `joint_script*.sh` scripts. + currently in the `joint_script*.sh` scripts. ## Limitations / TO-DOs: diff --git a/README.md b/README.md index 57c5c9f..2f3e8fe 100644 --- a/README.md +++ b/README.md @@ -22,60 +22,60 @@ directories. - For a pypi installation, install with `pip install MEDS-transforms`. - For a local installation, clone this repository and run `pip install .` from the repository root. 
- For running the MIMIC-IV example, install the optional MIMIC dependencies as well with - `pip install MEDS-transforms[examples]`. + `pip install MEDS-transforms[examples]`. - To support same-machine, process-based parallelism, install the optional joblib dependencies with - `pip install MEDS-transforms[local_parallelism]`. + `pip install MEDS-transforms[local_parallelism]`. - To support cluster-based parallelism, install the optional submitit dependencies with - `pip install MEDS-transforms[slurm_parallelism]`. + `pip install MEDS-transforms[slurm_parallelism]`. - For working on development, install the optional development dependencies with - `pip install .[dev,tests]`. + `pip install .[dev,tests]`. - Optional dependencies can be mutually installed by combining the optional dependency names with commas in - the square brackets, e.g., `pip install MEDS-transforms[examples,local_parallelism]`. + the square brackets, e.g., `pip install MEDS-transforms[examples,local_parallelism]`. ## Design Philosophy The fundamental design philosophy of this repository can be summarized as follows: 1. _(The MEDS Assumption)_: All structured electronic health record (EHR) data can be represented as a - series of events, each of which is associated with a patient, a time, and a set of codes and - numeric values. This representation is the Medical Event Data Standard (MEDS) format, and in this - repository we use it in the "flat" format, where data is organized as rows of `patient_id`, - `time`, `code`, `numeric_value` columns. + series of events, each of which is associated with a patient, a time, and a set of codes and + numeric values. This representation is the Medical Event Data Standard (MEDS) format, and in this + repository we use it in the "flat" format, where data is organized as rows of `patient_id`, + `time`, `code`, `numeric_value` columns. 2. _Easy Efficiency through Sharding_: MEDS datasets in this repository are sharded into smaller, more - manageable pieces (organized as separate files) at the patient level (and, during the raw-data extraction - process, the event level). This enables users to scale up their processing capabilities ad nauseum by - leveraging more workers to process these shards in parallel. This parallelization is seamlessly enabled - with the configuration schema used in the scripts in this repository. This style of parallelization - does not require complex packages to manage, complex systems of parallelization support, and can be - employed on single machines or across clusters. Through this style of parallelism, the MIMIC-IV ETL - included in this repository has been run end to end in under ten minutes with suitable parallelization. + manageable pieces (organized as separate files) at the patient level (and, during the raw-data extraction + process, the event level). This enables users to scale up their processing capabilities ad nauseum by + leveraging more workers to process these shards in parallel. This parallelization is seamlessly enabled + with the configuration schema used in the scripts in this repository. This style of parallelization + does not require complex packages to manage, complex systems of parallelization support, and can be + employed on single machines or across clusters. Through this style of parallelism, the MIMIC-IV ETL + included in this repository has been run end to end in under ten minutes with suitable parallelization. 3. 
_Simple, Modular, and Testable_: Each stage of the pipelines demonstrated in this repository is designed - to be simple, modular, and testable. Each operation is a single script that can be run independently of - the others, and each stage is designed to do a small amount of work and be easily testable in isolation. - This design philosophy ensures that the pipeline is robust to changes, easy to debug, and easy to extend. - In particular, to add new operations specific to a given model or dataset, the user need only write - simple functions that take in a flat MEDS dataframe (representing a single patient level shard) and - return a new flat MEDS dataframe, and then wrap that function in a script by following the examples - provided in this repository. These individual functions can use the same configuration schema as other - stages in the pipeline or include a separate, stage-specific configuration, and can use whatever - dataframe or data-processing tool desired (e.g., Pandas, Polars, DuckDB, FEMR, etc.), though the examples - in this repository leverage Polars. + to be simple, modular, and testable. Each operation is a single script that can be run independently of + the others, and each stage is designed to do a small amount of work and be easily testable in isolation. + This design philosophy ensures that the pipeline is robust to changes, easy to debug, and easy to extend. + In particular, to add new operations specific to a given model or dataset, the user need only write + simple functions that take in a flat MEDS dataframe (representing a single patient level shard) and + return a new flat MEDS dataframe, and then wrap that function in a script by following the examples + provided in this repository. These individual functions can use the same configuration schema as other + stages in the pipeline or include a separate, stage-specific configuration, and can use whatever + dataframe or data-processing tool desired (e.g., Pandas, Polars, DuckDB, FEMR, etc.), though the examples + in this repository leverage Polars. 4. _Configuration Extensibility through Hydra_: We rely heavily on Hydra and OmegaConf in this repository to - simplify configuration management within and across stages for a single pipeline. This design enables - easy choice of parallelization by leveraging distinct Hydra launcher plugins for local or cluster-driven - parallelism, natural capturing of logs and outputs for each stage, easy incorporation of documentation - and help-text for both overall pipelines and individual stages, and extensibility beyond default patterns - for more complex use-cases. + simplify configuration management within and across stages for a single pipeline. This design enables + easy choice of parallelization by leveraging distinct Hydra launcher plugins for local or cluster-driven + parallelism, natural capturing of logs and outputs for each stage, easy incorporation of documentation + and help-text for both overall pipelines and individual stages, and extensibility beyond default patterns + for more complex use-cases. 5. _Configuration Files over Code_: Wherever _sensible_, we prefer to rely on configuration files rather - than code to specify repeated behavior prototypes over customized, dataset-specific code to enable - maximal reproducibility. 
The core strength of MEDS is that it is a shared, standardized format for EHR - data, and this repository is designed to leverage that strength to the fullest extent possible by - designing pipelines that can be, wherever possible, run identically save for configuration file inputs - across disparate datasets. Configuration files also can be easier to communicate to local data experts, - who may not have Python expertise, providing another benefit. This design philosophy is not absolute, - however, and local code can and _should_ be used where appropriate -- see the `MIMIC-IV_Example` and - `eICU_Example` directories for examples of how and where per-dataset code can be leveraged in concert - with the configurable aspects of the standardized MEDS extraction pipeline. + than code to specify repeated behavior prototypes over customized, dataset-specific code to enable + maximal reproducibility. The core strength of MEDS is that it is a shared, standardized format for EHR + data, and this repository is designed to leverage that strength to the fullest extent possible by + designing pipelines that can be, wherever possible, run identically save for configuration file inputs + across disparate datasets. Configuration files also can be easier to communicate to local data experts, + who may not have Python expertise, providing another benefit. This design philosophy is not absolute, + however, and local code can and _should_ be used where appropriate -- see the `MIMIC-IV_Example` and + `eICU_Example` directories for examples of how and where per-dataset code can be leveraged in concert + with the configurable aspects of the standardized MEDS extraction pipeline. ## Intended Usage @@ -95,17 +95,17 @@ the parallelism and avoid duplicative work. This permits significant flexibility run. - The user can run the entire pipeline in serial, through a single shell script simply by calling each - stage's script in sequence. + stage's script in sequence. - The user can leverage arbitrary scheduling systems (e.g., Slurm, LSF, Kubernetes, etc.) to run each stage - in parallel on a cluster, either by manually constructing the appropriate worker scripts to run each stage's - script and simply launching as many worker jobs as is desired or by using Hydra launchers such as the - `submitit` launcher to automate the creation of appropriate scheduler worker jobs. Note this will typically - required a distributed file system to work correctly, as these scripts use manually created file locks to - avoid duplicative work. + in parallel on a cluster, either by manually constructing the appropriate worker scripts to run each stage's + script and simply launching as many worker jobs as is desired or by using Hydra launchers such as the + `submitit` launcher to automate the creation of appropriate scheduler worker jobs. Note this will typically + required a distributed file system to work correctly, as these scripts use manually created file locks to + avoid duplicative work. - The user can run each stage in parallel on a single machine by launching multiple copies of the same - script in different terminal sessions or all at once via the Hydra `joblib` launcher. This can result in a - significant speedup depending on the machine configuration as sharding ensures that parallelism can be used - with minimal file read contention. + script in different terminal sessions or all at once via the Hydra `joblib` launcher. 
This can result in a + significant speedup depending on the machine configuration as sharding ensures that parallelism can be used + with minimal file read contention. Two of these methods of parallelism, in particular local-machine parallelism and slurm-based cluster parallelism, are supported explicitly by this package through the use of the `joblib` and `submitit` Hydra @@ -154,15 +154,15 @@ To use this repository as an importable library, the user should follow these st 1. Install the repository as a package. 2. Design your own transform function in your own codebase and leverage `MEDS_transform` utilities such as - `MEDS_transform.mapreduce.mapper.map_over` to easily apply your transform over a sharded MEDS dataset. + `MEDS_transform.mapreduce.mapper.map_over` to easily apply your transform over a sharded MEDS dataset. 3. Leverage the `MEDS_transforms` configuration schema to enable easy configuration of your pipeline, by - importing the MEDS transforms configs via your hydra search path and using them as a base for your own - configuration files, enabling you to intermix your new stage configuration with the existing MEDS - transform stages. + importing the MEDS transforms configs via your hydra search path and using them as a base for your own + configuration files, enabling you to intermix your new stage configuration with the existing MEDS + transform stages. 4. Note that, if your transformations are sufficiently general, you can also submit a PR to add new - transformations to this repository, enabling others to leverage your work as well. - See [this example](https://github.com/mmcdermott/MEDS_transforms/pull/48) for an (in progress) example of - how to do this. + transformations to this repository, enabling others to leverage your work as well. + See [this example](https://github.com/mmcdermott/MEDS_transforms/pull/48) for an (in progress) example of + how to do this. ### As a template @@ -170,18 +170,18 @@ To use this repository as a template, the user should follow these steps: 1. Fork the repository to a new repository for their dedicated pipeline. 2. Design the set of "stages" (e.g., distinct operations that must be completed) that will be required for - their needs. As a best practice, each stage should be realized as a single or set of simple functions - that can be applied on a per-shard basis to the data. Reduction stages (where data needs to be aggregated - across the entire pipeline) should be kept as simple as possible to avoid bottlenecks, but are supported - through this pipeline design; see the (in progress) `scripts/preprocessing/collect_code_metadata.py` - script for an example. + their needs. As a best practice, each stage should be realized as a single or set of simple functions + that can be applied on a per-shard basis to the data. Reduction stages (where data needs to be aggregated + across the entire pipeline) should be kept as simple as possible to avoid bottlenecks, but are supported + through this pipeline design; see the (in progress) `scripts/preprocessing/collect_code_metadata.py` + script for an example. 3. Mimic the structure of the `configs/preprocessing.yaml` configuration file to assemble a configuration - file for the necessary stages of your pipeline. Identify in advance what dataset-specific information the - user will need to specify to run your pipeline (e.g., will they need links between dataset codes and - external ontologies? Will they need to specify select key-event concepts to identify in the data? etc.). 
- Proper pipeline design should enable running the pipeline across multiple datasets with minimal - dataset-specific information required, and such that that information can be specified in as easy a - manner as possible. Examples of how to do this are forthcoming. + file for the necessary stages of your pipeline. Identify in advance what dataset-specific information the + user will need to specify to run your pipeline (e.g., will they need links between dataset codes and + external ontologies? Will they need to specify select key-event concepts to identify in the data? etc.). + Proper pipeline design should enable running the pipeline across multiple datasets with minimal + dataset-specific information required, and such that that information can be specified in as easy a + manner as possible. Examples of how to do this are forthcoming. ## MEDS ETL / Extraction Pipeline Details @@ -190,42 +190,42 @@ To use this repository as a template, the user should follow these steps: Assumptions: 1. Your data is organized in a set of parquet files on disk such that each row of each file corresponds to - one or more measurements per patient and has all necessary information in that row to extract said - measurement, organized in a simple, columnar format. Each of these parquet files stores the patient's ID in - a column called `patient_id` in the same type. + one or more measurements per patient and has all necessary information in that row to extract said + measurement, organized in a simple, columnar format. Each of these parquet files stores the patient's ID in + a column called `patient_id` in the same type. 2. You have a pre-defined or can externally define the requisite MEDS base `code_metadata` file that - describes the codes in your data as necessary. This file is not used in the provided pre-processing - pipeline in this package, but is necessary for other uses of the MEDS data. + describes the codes in your data as necessary. This file is not used in the provided pre-processing + pipeline in this package, but is necessary for other uses of the MEDS data. Computational Resource Requirements: 1. This pipeline is designed for achieving throughput through parallelism in a controllable and - resource-efficient manner. By changing the input shard size and by launching more or fewer copies of the - job steps, you can control the resources used as desired. + resource-efficient manner. By changing the input shard size and by launching more or fewer copies of the + job steps, you can control the resources used as desired. 2. This pipeline preferentially uses disk over memory and compute through aggressive caching. You should - have sufficient disk space to store multiple copies of your raw dataset comfortably. + have sufficient disk space to store multiple copies of your raw dataset comfortably. 3. This pipeline can be run on a single machine or across many worker nodes on a cluster provided the worker - nodes have access to a distributed file system. The internal "locking" mechanism used to limit race - conditions among multiple workers in this pipeline is not guaranteed to be robust to all distributed - systems, though in practice this is unlikely to cause issues. + nodes have access to a distributed file system. The internal "locking" mechanism used to limit race + conditions among multiple workers in this pipeline is not guaranteed to be robust to all distributed + systems, though in practice this is unlikely to cause issues. 
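To make assumption 1 above concrete, the expected raw input is simply a set of columnar files whose rows each carry one or more measurements for a patient. Below is a minimal, purely hypothetical Polars sketch of one such file; only the `patient_id` column name is prescribed, and every other column name, value, and path is an illustrative assumption:

```python
from pathlib import Path

import polars as pl

# One hypothetical raw input shard: each row holds everything needed to extract a
# measurement for one patient. Only the `patient_id` column name is assumed by the
# ETL; the remaining columns are source-specific and purely illustrative.
raw_shard = pl.DataFrame(
    {
        "patient_id": [1, 1, 2],
        "charttime": ["01/01/2023, 08:00:00", "01/01/2023, 09:30:00", "02/14/2023, 12:00:00"],
        "lab_name": ["HR", "HR", "Creatinine"],
        "lab_value": [88.0, 91.0, 1.2],
    }
)

out_dir = Path("raw_data/labs")  # illustrative location for a raw, pre-MEDS table
out_dir.mkdir(parents=True, exist_ok=True)
raw_shard.write_parquet(out_dir / "shard_0.parquet")
```

The ETL steps described next consume files of exactly this shape.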
The provided ETL consists of the following steps, which can be performed as needed by the user with whatever degree of parallelism is desired per step. 1. It re-shards the input data into a set of smaller, event-level shards to facilitate parallel processing. - This can be skipped if your input data is already suitably sharded at either a per-patient or per-event - level. + This can be skipped if your input data is already suitably sharded at either a per-patient or per-event + level. 2. It extracts the subject IDs from the sharded data and computes the set of ML splits and (per split) the - patient shards. These are stored in a JSON file in the output cohort directory. + patient shards. These are stored in a JSON file in the output cohort directory. 3. It converts the input, event level shards into the MEDS flat format and joins and shards these data into - patient-level shards for MEDS use and stores them in a nested format in the output cohort directory, - again in the flat format. This step can be broken down into two sub-steps: - - First, each input shard is converted to the MEDS flat format and split into sub patient-level shards. - - Second, the appropriate sub patient-level shards are joined and re-organized into the final - patient-level shards. This method ensures that we minimize the amount of read contention on the input - shards during the join process and can maximize parallel throughput, as (theoretically, with sufficient - workers) all input shards can be sub-sharded in parallel and then all output shards can be joined in - parallel. + patient-level shards for MEDS use and stores them in a nested format in the output cohort directory, + again in the flat format. This step can be broken down into two sub-steps: + - First, each input shard is converted to the MEDS flat format and split into sub patient-level shards. + - Second, the appropriate sub patient-level shards are joined and re-organized into the final + patient-level shards. This method ensures that we minimize the amount of read contention on the input + shards during the join process and can maximize parallel throughput, as (theoretically, with sufficient + workers) all input shards can be sub-sharded in parallel and then all output shards can be joined in + parallel. The ETL scripts all use [Hydra](https://hydra.cc/) for configuration management, leveraging the shared `configs/extraction.yaml` file for configuration. The user can override any of these settings in the normal @@ -278,20 +278,20 @@ script is a functional test that is also run with `pytest` to verify correctness #### Core Scripts: 1. `scripts/extraction/shard_events.py` shards the input data into smaller, event-level shards by splitting - raw files into chunks of a configurable number of rows. Files are split sequentially, with no regard for - data content or patient boundaries. The resulting files are stored in the `subsharded_events` - subdirectory of the output directory. + raw files into chunks of a configurable number of rows. Files are split sequentially, with no regard for + data content or patient boundaries. The resulting files are stored in the `subsharded_events` + subdirectory of the output directory. 2. `scripts/extraction/split_and_shard_patients.py` splits the patient population into ML splits and shards - these splits into patient-level shards. The result of this process is only a simple `JSON` file - containing the patient IDs belonging to individual splits and shards. This file is stored in the - `output_directory/splits.json` file. 
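    As a purely illustrative sketch, that JSON file simply maps shard names to lists of patient IDs;
    the shard key naming and the IDs below are assumptions for illustration, not guaranteed output:

    ```python
    import json
    from pathlib import Path

    # Hypothetical contents of `output_directory/splits.json`; in the real pipeline this
    # file is produced by split_and_shard_patients.py, and key names/IDs may differ.
    splits = {
        "train/0": [239684, 1195293],
        "train/1": [68729, 814703],
        "tuning/0": [754281],
        "held_out/0": [1500733],
    }

    out_dir = Path("output_directory")
    out_dir.mkdir(exist_ok=True)
    (out_dir / "splits.json").write_text(json.dumps(splits, indent=2))
    ```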
+ these splits into patient-level shards. The result of this process is only a simple `JSON` file + containing the patient IDs belonging to individual splits and shards. This file is stored in the + `output_directory/splits.json` file. 3. `scripts/extraction/convert_to_sharded_events.py` converts the input, event-level shards into the MEDS - event format and splits them into patient-level sub-shards. So, the resulting files are sharded into - patient-level, then event-level groups and are not merged into full patient-level shards or appropriately - sorted for downstream use. + event format and splits them into patient-level sub-shards. So, the resulting files are sharded into + patient-level, then event-level groups and are not merged into full patient-level shards or appropriately + sorted for downstream use. 4. `scripts/extraction/merge_to_MEDS_cohort.py` merges the patient-level, event-level shards into full - patient-level shards and sorts them appropriately for downstream use. The resulting files are stored in - the `output_directory/final_cohort` directory. + patient-level shards and sorts them appropriately for downstream use. The resulting files are stored in + the `output_directory/final_cohort` directory. ## MEDS Pre-processing Transformations @@ -302,46 +302,46 @@ broken down into the following steps: 1. Filtering the dataset by criteria that do not require cross-patient analyses, e.g., - - Filtering patients by the number of events or unique times they have. - - Removing numeric values that fall outside of pre-specified, per-code ranges (e.g., for outlier - removal). + - Filtering patients by the number of events or unique times they have. + - Removing numeric values that fall outside of pre-specified, per-code ranges (e.g., for outlier + removal). 2. Adding any extra events to the records that are necessary for downstream modeling, e.g., - - Adding time-derived measurements, e.g., - - The time since the last event of a certain type. - - The patient's age as of each unique timepoint. - - The time-of-day of each event. - - Adding a "dummy" event to the dataset for each patient that occurs at the end of the observation - period. + - Adding time-derived measurements, e.g., + - The time since the last event of a certain type. + - The patient's age as of each unique timepoint. + - The time-of-day of each event. + - Adding a "dummy" event to the dataset for each patient that occurs at the end of the observation + period. 3. Iteratively (a) grouping the dataset by `code` and associated code modifier columns and collecting - statistics on the numeric and categorical values for each code then (b) filtering the dataset down to - remove outliers or other undesired codes or values, e.g., + statistics on the numeric and categorical values for each code then (b) filtering the dataset down to + remove outliers or other undesired codes or values, e.g., - - Computing the mean and standard deviation of the numeric values for each code. - - Computing the number of times each code occurs in the dataset. - - Computing appropriate numeric bins for each code for value discretization. + - Computing the mean and standard deviation of the numeric values for each code. + - Computing the number of times each code occurs in the dataset. + - Computing appropriate numeric bins for each code for value discretization. 4. Transforming the code space to appropriately include or exclude any additional measurement columns that - should be included during code grouping and modeling operations. 
The goal of this step is to ensure that - the only columns that need be processed going into the pre-processing, tokenization, and tensorization - stage are expressible in the `code` and `numeric_values` columns of the dataset, which helps - standardize further downstream use. + should be included during code grouping and modeling operations. The goal of this step is to ensure that + the only columns that need be processed going into the pre-processing, tokenization, and tensorization + stage are expressible in the `code` and `numeric_values` columns of the dataset, which helps + standardize further downstream use. - - Standardizing the unit of measure of observed codes or adding the unit of measure to the code such that - group-by operations over the code take the unit into account. - - Adding categorical normal/abnormal flags to laboratory test result codes. + - Standardizing the unit of measure of observed codes or adding the unit of measure to the code such that + group-by operations over the code take the unit into account. + - Adding categorical normal/abnormal flags to laboratory test result codes. 5. Normalizing the data to convert codes to indices and numeric values to the desired form (either - categorical indices or normalized numeric values). + categorical indices or normalized numeric values). 6. Tokenizing the data in time to create a pre-tensorized dataset with clear delineations between patients, - patient sequence elements, and measurements per sequence element (note that various of these delineations - may be fully flat/trivial for unnested formats). + patient sequence elements, and measurements per sequence element (note that various of these delineations + may be fully flat/trivial for unnested formats). 7. Tensorizing the data to permit efficient retrieval from disk of patient data for deep-learning modeling - via PyTorch. + via PyTorch. Much like how the entire MEDS ETL pipeline is controlled by a single configuration file, the pre-processing pipeline is also controlled by a single configuration file, stored in `configs/preprocessing.yaml`. Scripts @@ -441,9 +441,9 @@ To use either of these, you need to install additional optional dependencies: ## TODOs: 1. We need to have a vehicle to cleanly separate dataset-specific variables from the general configuration - files. Similar to task configuration files, but for models. + files. Similar to task configuration files, but for models. 2. Figure out how to ensure that each pre-processing step reads from the right prior files. Likely need some - kind of a "prior stage name" config variable. + kind of a "prior stage name" config variable. ## Notes: diff --git a/docs/pipeline_configuration.md b/docs/pipeline_configuration.md index 7c7e7dc..f41ad19 100644 --- a/docs/pipeline_configuration.md +++ b/docs/pipeline_configuration.md @@ -14,69 +14,69 @@ Suppose you have a pipeline with an input directory of `$INPUT_DIR` and a cohort `$COHORT_DIR`. Let us further suppose we impose a series of the following stages: 1. `stage_1`: A metadata map-reduce stage (e.g., counting the occurrence rates of the codes in the - data). + data). 2. `stage_2`: A metadata-only processing stage (e.g., filtering the code dataframe to only codes - that occur more than 10 times). + that occur more than 10 times). 3. `stage_3`: A data processing stage (e.g., filtering the data to only rows with a code that is in the - current running metadata file, which, due to `stage_2`, are those codes that occur more than 10 times). 
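    As a rough, self-contained sketch of what such a data-filtering stage does (all frames, column
    names, and values here are hypothetical stand-ins, not the package's actual implementation):

    ```python
    import polars as pl

    # Stand-in for the reduced metadata produced by `stage_2` (codes occurring more than
    # 10 times); other metadata columns are omitted for brevity.
    kept_codes = pl.DataFrame({"code": ["HR", "ICD9//250.00"]})

    # Stand-in for one patient-level data shard.
    shard = pl.DataFrame(
        {
            "patient_id": [1, 1, 2],
            "time": ["2023-01-01 08:00:00", "2023-01-01 09:30:00", "2023-02-14 12:00:00"],
            "code": ["HR", "RARE_LAB", "ICD9//250.00"],
            "numeric_value": [88.0, 3.2, None],
        }
    )

    # stage_3: keep only measurements whose code survived the stage_2 metadata filter.
    filtered = shard.filter(pl.col("code").is_in(kept_codes.get_column("code")))
    ```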
+ current running metadata file, which, due to `stage_2`, are those codes that occur more than 10 times). 4. `stage_4`: A metadata map-reduce stage (e.g., computing the means and variances for the numerical values - in the data). + in the data). 5. `stage_5`: A data processing stage (e.g., occluding all measurement values that occur more than 3 - standard deviations from the mean). + standard deviations from the mean). 6. `stage_6`: A metadata map-reduce stage (e.g., computing the means and variances for the numerical values - in the data). + in the data). 7. `stage_7`: A data processing stage (e.g., normalizing the data to have a mean of 0 and a standard - deviation of 1). + deviation of 1). Each of these stages will read and write their output datasets in the following manner. 1. `stage_1`: - - As there is no preceding data stage, this stage will read the data in from `$INPUT_DIR/data` - (the `data` suffix is the default data directory for MEDS datasets). - - This stage will, in its mapping stage, write the partial extracted metadata files to the - `$COHORT_DIR/stage_1/$SHARD_NAME.parquet` directory. - - This stage will read in the prior joint metadata file from the `$INPUT_DIR/metadata/codes.parquet` - directory to join with the new metadata. - - This stage will join all its metadata shards, join any prior columns from the old metadata, and - write the final, joined metadata file to the `$COHORT_DIR/stage_1/codes.parquet` directory. + - As there is no preceding data stage, this stage will read the data in from `$INPUT_DIR/data` + (the `data` suffix is the default data directory for MEDS datasets). + - This stage will, in its mapping stage, write the partial extracted metadata files to the + `$COHORT_DIR/stage_1/$SHARD_NAME.parquet` directory. + - This stage will read in the prior joint metadata file from the `$INPUT_DIR/metadata/codes.parquet` + directory to join with the new metadata. + - This stage will join all its metadata shards, join any prior columns from the old metadata, and + write the final, joined metadata file to the `$COHORT_DIR/stage_1/codes.parquet` directory. 2. `stage_2`: - - This stage will read in the metadata from the `$COHORT_DIR/stage_1/codes.parquet` directory. - - This stage will write the filtered metadata to the `$COHORT_DIR/stage_2/codes.parquet` directory. + - This stage will read in the metadata from the `$COHORT_DIR/stage_1/codes.parquet` directory. + - This stage will write the filtered metadata to the `$COHORT_DIR/stage_2/codes.parquet` directory. 3. `stage_3`: - - This stage will read in the data from the `$INPUT_DIR/data` directory as there has still been no - prior data processing stage. Individual shards will be read from the - `$INPUT_DIR/data/$SHARD_NAME.parquet` files. - - This stage will read in the metadata from the `$COHORT_DIR/stage_2/codes.parquet` directory. - - This stage will write the filtered shards to the `$COHORT_DIR/stage_3/$SHARD_NAME.parquet` files. + - This stage will read in the data from the `$INPUT_DIR/data` directory as there has still been no + prior data processing stage. Individual shards will be read from the + `$INPUT_DIR/data/$SHARD_NAME.parquet` files. + - This stage will read in the metadata from the `$COHORT_DIR/stage_2/codes.parquet` directory. + - This stage will write the filtered shards to the `$COHORT_DIR/stage_3/$SHARD_NAME.parquet` files. 4. `stage_4`: - - This stage will read in the data from the `$COHORT_DIR/stage_3` directory as that is the prior data - processing stage. 
- - This stage will write the partial extracted metadata files to the - `$COHORT_DIR/stage_4/$SHARD_NAME.parquet` file. - - This stage will read in the prior metadata from the `$COHORT_DIR/stage_2/codes.parquet` directory and - join it with the new metadata. - - This stage will join all its metadata shards, join any prior columns from the old metadata, and - write the final, joined metadata file to the `$COHORT_DIR/stage_4/codes.parquet` file. + - This stage will read in the data from the `$COHORT_DIR/stage_3` directory as that is the prior data + processing stage. + - This stage will write the partial extracted metadata files to the + `$COHORT_DIR/stage_4/$SHARD_NAME.parquet` file. + - This stage will read in the prior metadata from the `$COHORT_DIR/stage_2/codes.parquet` directory and + join it with the new metadata. + - This stage will join all its metadata shards, join any prior columns from the old metadata, and + write the final, joined metadata file to the `$COHORT_DIR/stage_4/codes.parquet` file. 5. `stage_5`: - - This stage will read in the data from the `$COHORT_DIR/stage_3` directory. - - This stage will read in the metadata from the `$COHORT_DIR/stage_4/codes.parquet` file. - - This stage will write the filtered shards to the `$COHORT_DIR/stage_5/$SHARD_NAME.parquet` files. + - This stage will read in the data from the `$COHORT_DIR/stage_3` directory. + - This stage will read in the metadata from the `$COHORT_DIR/stage_4/codes.parquet` file. + - This stage will write the filtered shards to the `$COHORT_DIR/stage_5/$SHARD_NAME.parquet` files. 6. `stage_6`: - - This stage will read in the data from the `$COHORT_DIR/stage_5` directory. - - This stage will write the partial extracted metadata files to the - `$COHORT_DIR/stage_6/$SHARD_NAME.parquet` file. - - This stage will read in the prior metadata from the `$COHORT_DIR/stage_4/codes.parquet` file and - join it with the new metadata. - - This stage will join all its metadata shards, join any prior columns from the old metadata, and - write the final, joined metadata file to both the `$COHORT_DIR/stage_6/codes.parquet` file and to the - `$COHORT_DIR/metadata/codes.parquet` file _given that this is the last metadata stage in the pipeline._ - Note that this reduced file is the only metadata file written in the global cohort metadata directory; - the partial map files are only written to the stage directories. + - This stage will read in the data from the `$COHORT_DIR/stage_5` directory. + - This stage will write the partial extracted metadata files to the + `$COHORT_DIR/stage_6/$SHARD_NAME.parquet` file. + - This stage will read in the prior metadata from the `$COHORT_DIR/stage_4/codes.parquet` file and + join it with the new metadata. + - This stage will join all its metadata shards, join any prior columns from the old metadata, and + write the final, joined metadata file to both the `$COHORT_DIR/stage_6/codes.parquet` file and to the + `$COHORT_DIR/metadata/codes.parquet` file _given that this is the last metadata stage in the pipeline._ + Note that this reduced file is the only metadata file written in the global cohort metadata directory; + the partial map files are only written to the stage directories. 7. `stage_7`: - - This stage will read in the data from the `$COHORT_DIR/stage_5` directory. - - This stage will read in the metadata from the `$COHORT_DIR/stage_6/codes.parquet` file. 
- - This stage will write the normalized shards to the `$COHORT_DIR/data/$SHARD_NAME.parquet` files, using - the _global cohort data directory given that this is the last data processing stage in the pipeline._ - Note that, unlike for metadata, where only the reduce output is written to the global cohort metadata - file, all data shards are written to the global cohort data directory for final data processing stages - (which do not have reduce stages). + - This stage will read in the data from the `$COHORT_DIR/stage_5` directory. + - This stage will read in the metadata from the `$COHORT_DIR/stage_6/codes.parquet` file. + - This stage will write the normalized shards to the `$COHORT_DIR/data/$SHARD_NAME.parquet` files, using + the _global cohort data directory given that this is the last data processing stage in the pipeline._ + Note that, unlike for metadata, where only the reduce output is written to the global cohort metadata + file, all data shards are written to the global cohort data directory for final data processing stages + (which do not have reduce stages). diff --git a/docs/preprocessing_operation_prototypes.md b/docs/preprocessing_operation_prototypes.md index 289a4e4..7974e28 100644 --- a/docs/preprocessing_operation_prototypes.md +++ b/docs/preprocessing_operation_prototypes.md @@ -42,8 +42,8 @@ preparation for mapping that transformation out across the patient data by code. ##### Operation Steps 1. Add new information or transform existing columns in an existing `metadata/codes.parquet` file. Note that - `code` or `code_modifier` columns should _not_ be modified in this step as that will break the linkage - with the patient data. + `code` or `code_modifier` columns should _not_ be modified in this step as that will break the linkage + with the patient data. ##### Parameters @@ -71,7 +71,7 @@ simplicity). 1. Per-shard, filter the pateint data to satisfy desired set of patient or other data critieria. 2. Per-shard, group by code and collect some aggregate statistics. Optionally also compute statistics across - all codes. + all codes. 3. Reduce the per-shard aggregate files into a unified `metadata/codes.parquet` file. 4. Optionally merge with static per-code metadata from prior steps. @@ -79,8 +79,8 @@ simplicity). 1. What (if any) patient data filters should be applied prior to aggregation. 2. What aggregation functions should be applied to each code. Each aggregation function must specify both a - _mapper_ function that computes aggregate data on a per-shard basis and a _reducer_ function that - combines different shards together into a single, unified metadata file. + _mapper_ function that computes aggregate data on a per-shard basis and a _reducer_ function that + combines different shards together into a single, unified metadata file. 3. Whether or not aggregation functions should be computed over all raw data (the "null" code case). ##### Status @@ -95,7 +95,7 @@ Patient Filters: **None** Functions: 1. Various aggregation functions; see `src/MEDS_transforms/aggregate_code_metadata.py` for a list of supported - functions. + functions. ##### Planned Future Operations @@ -108,9 +108,9 @@ include: 1. Filtering patients wholesale based on aggregate, patient-level criteria (e.g., number of events, etc.) 2. Filtering the data to only include patient data that matches some cohort specification (meaning removing - data that is not within pre-identified ranges of time on a per-patient basis). + data that is not within pre-identified ranges of time on a per-patient basis). 3. 
Filtering individual measurements from the data based on some criteria (e.g., removing measurements that - have codes that are not included in the overall vocabulary, etc.). + have codes that are not included in the overall vocabulary, etc.). #### Filtering Patients @@ -152,7 +152,7 @@ via a `metadata/codes.parquet` file. 1. Per-shard, join the data, if necessary, to the provided, global `metadata/codes.parquet` file. 2. Apply row-based criteria to each measurement to determine if it should be retained or removed. 3. Return the filtered dataset, in the same format as the original, but with only the measurements to be - retained. + retained. ##### Parameters @@ -200,15 +200,15 @@ are added and this function is not reversible. 1. Per-shard, join the data, if necessary, to the provided, global `metadata/codes.parquet` file. 2. Apply row-based criteria to each measurement to determine if individual features should be occluded or - retained in full granularity. + retained in full granularity. 3. Set occluded data to the occlusion target (typically `"UNK"`, `None`, or `np.NaN`) and add an indicator - column indicating occlusion status. + column indicating occlusion status. ##### Parameters 1. What criteria should be used to occlude features. - - Relatedly, what occlusion value should be used for occluded features. - - Relatedly, what the name of the occlusion column should be (can be set by default for features). + - Relatedly, what occlusion value should be used for occluded features. + - Relatedly, what the name of the occlusion column should be (can be set by default for features). 2. What, if any, columns in the `metadata/codes.parquet` file should be joined in to the data. ##### Status @@ -219,7 +219,7 @@ This operation is only supported through the single `filter_outliers_fntr` funct ##### Currently supported operations 1. Occluding numerical values if they take a value more distant from the code's mean by a specified number - of standard deviations. + of standard deviations. ### Transforming Measurements within Events @@ -227,10 +227,10 @@ These aren't implemented yet, but are planned: 1. Re-order measurements within the event ordering. 2. Split measurements into multiple measurements in a particular order and via a particular functional form. - E.g., - - Performing ontology expansion - - Splitting a multi-faceted measurement (e.g., blood pressure recorded as `"120/80"`) into multiple - measurements (e.g., a systolic and diastolic blood pressure measurement with values `120` and `80`). + E.g., + - Performing ontology expansion + - Splitting a multi-faceted measurement (e.g., blood pressure recorded as `"120/80"`) into multiple + measurements (e.g., a systolic and diastolic blood pressure measurement with values `120` and `80`). ## Requesting New Prototypes diff --git a/eICU_Example/README.md b/eICU_Example/README.md index 148903f..13c0e6b 100644 --- a/eICU_Example/README.md +++ b/eICU_Example/README.md @@ -10,21 +10,21 @@ up from this one). **Status**: This is a work in progress. The code is not yet functional. Remaining work includes: - [ ] Implementing the pre-MEDS processing step. - - [ ] Identifying the pre-MEDS steps for eICU + - [ ] Identifying the pre-MEDS steps for eICU - [ ] Testing the pre-MEDS processing step on live eICU-CRD. - - [ ] Test that it runs at all. - - [ ] Test that the output is as expected. + - [ ] Test that it runs at all. + - [ ] Test that the output is as expected. - [ ] Check the installation instructions on a fresh client. 
- [ ] Testing the `configs/event_configs.yaml` configuration on eICU-CRD - [ ] Testing the MEDS extraction ETL runs on eICU-CRD (this should be expected to work, but needs - live testing). - - [ ] Sub-sharding - - [ ] Patient split gathering - - [ ] Event extraction - - [ ] Merging + live testing). + - [ ] Sub-sharding + - [ ] Patient split gathering + - [ ] Event extraction + - [ ] Merging - [ ] Validating the output MEDS cohort - - [ ] Basic validation - - [ ] Detailed validation + - [ ] Basic validation + - [ ] Detailed validation ## Step 0: Installation @@ -50,12 +50,12 @@ there should be a `hosp` and `icu` subdirectory of `$EICU_RAW_DIR`. This is a step in a few parts: 1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In - particular, we need to join: - - TODO + particular, we need to join: + - TODO 2. Convert the patient's static data to a more parseable form. This entails: - - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and - `anchor_offset` fields. - - Merge the patient's `dod` with the `deathtime` from the `admissions` table. + - Get the patient's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and + `anchor_offset` fields. + - Merge the patient's `dod` with the `deathtime` from the `admissions` table. After these steps, modified files or symlinks to the original files will be written in a new directory which will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this @@ -89,7 +89,7 @@ subdirectories of the same root directory). This is a step in 4 parts: 1. Sub-shard the raw files. Run this command as many times simultaneously as you would like to have workers - performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. + performing this sub-sharding step. See below for how to automate this parallelism using hydra launchers. ```bash ./scripts/extraction/shard_events.py \ @@ -100,7 +100,7 @@ This is a step in 4 parts: In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -2. Extract and form the patient splits and sub-shards. +1. Extract and form the patient splits and sub-shards. ```bash ./scripts/extraction/split_and_shard_patients.py \ @@ -111,7 +111,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -3. Extract patient sub-shards and convert to MEDS events. +1. Extract patient sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -126,7 +126,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -4. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per patient sub-shard. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -166,7 +166,7 @@ to finish before moving on to the next stage. Let `$N_PARALLEL_WORKERS` be the n In practice, on a machine with 150 GB of RAM and 10 cores, this step takes approximately 20 minutes in total. -2. Extract and form the patient splits and sub-shards. +1. Extract and form the patient splits and sub-shards. 
```bash ./scripts/extraction/split_and_shard_patients.py \ @@ -177,7 +177,7 @@ In practice, on a machine with 150 GB of RAM and 10 cores, this step takes appro In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. -3. Extract patient sub-shards and convert to MEDS events. +1. Extract patient sub-shards and convert to MEDS events. ```bash ./scripts/extraction/convert_to_sharded_events.py \ @@ -192,7 +192,7 @@ multiple times (though this will, of course, consume more resources). If your fi commands can also be launched as separate slurm jobs, for example. For eICU, this level of parallelization and performance is not necessary; however, for larger datasets, it can be. -4. Merge the MEDS events into a single file per patient sub-shard. +1. Merge the MEDS events into a single file per patient sub-shard. ```bash ./scripts/extraction/merge_to_MEDS_cohort.py \ @@ -206,7 +206,7 @@ and performance is not necessary; however, for larger datasets, it can be. Currently, some tables are ignored, including: 1. `admissiondrug`: The [documentation](https://eicu-crd.mit.edu/eicutables/admissiondrug/) notes that this is - extremely infrequently used, so we skip it. + extremely infrequently used, so we skip it. 2. Lots of questions remain about how to appropriately handle timestamps of the data -- e.g., things like HCPCS diff --git a/src/MEDS_transforms/extract/README.md b/src/MEDS_transforms/extract/README.md index 73724ec..b1e9b12 100644 --- a/src/MEDS_transforms/extract/README.md +++ b/src/MEDS_transforms/extract/README.md @@ -5,11 +5,11 @@ dataset is: 1. Arranged in a series of files on disk of an allowed format (e.g., `.csv`, `.csv.gz`, `.parquet`)... 2. Such that each file stores a dataframe containing data about patients such that each row of any given - table corresponds to zero or more observations about a patient at a given time... + table corresponds to zero or more observations about a patient at a given time... 3. And you can configure how to extract those observations in the time, code, and numeric value - format of MEDS in the event conversion `yaml` file format specified below, then... - this tool can automatically extract your raw data into a MEDS dataset for you in an efficient, reproducible, - and communicable way. + format of MEDS in the event conversion `yaml` file format specified below, then... + this tool can automatically extract your raw data into a MEDS dataset for you in an efficient, reproducible, + and communicable way. TODO: figure @@ -29,19 +29,19 @@ the three assumptions above. This step is called "Pre-MEDS" and is not standardi provide some guidance on this below. 1. **Configuration** Second, once your data obeys the assumptions above, you need to specify your extraction - configuration options. This step is broken down into two parts: - \- _Event Configuration_ First, and most importantly, you must specify the event conversion configuration - file. This file specifies how to convert your raw data into MEDS events. - \- _Pipeline Configuration_ Second, you must specify any non-event-conversion pipeline configuration - variables, either through the command line or through a configuration file. + configuration options. This step is broken down into two parts: + \- _Event Configuration_ First, and most importantly, you must specify the event conversion configuration + file. This file specifies how to convert your raw data into MEDS events. 
+ \- _Pipeline Configuration_ Second, you must specify any non-event-conversion pipeline configuration + variables, either through the command line or through a configuration file. 2. **MEDS Extract CLI** Third, you must run the MEDS Extract tool with the specified configuration options. - This will extract your raw data into a MEDS dataset. You can run each stage of the MEDS Extract tool - individually, if you want greater control over the parallelism or management of each stage, or you can - run the full pipeline at once via TODO. + This will extract your raw data into a MEDS dataset. You can run each stage of the MEDS Extract tool + individually, if you want greater control over the parallelism or management of each stage, or you can + run the full pipeline at once via TODO. 3. **Data Cleaning** Finally, and optionally, you can also use MEDS Transform to configure additional data - cleaning steps that can be applied to the MEDS dataset after it has been extracted. This is not required - for MEDS compliance, but can be useful for downstream users of the dataset. We will provide greater - details on this below. + cleaning steps that can be applied to the MEDS dataset after it has been extracted. This is not required + for MEDS compliance, but can be useful for downstream users of the dataset. We will provide greater + details on this below. In the next few sections of the documentation, we will provide greater details on [steps 1](#step-1-configuring-meds-extract) and [2](#step-2-running-meds-extract) of this process, as these are the @@ -77,20 +77,20 @@ If no such file exists for any valid suffix, an error will be raised. Otherwise, each row of the file will be converted into a MEDS event according to the logic specified, as follows: 1. The code of the output MEDS observation will be constructed based on the - `relative_table_file_stem.event_name.code` field. This field can be a string literal, a reference to an - input column (denoted by the `col(...)` syntax), or a list of same. If it is a list, the output code will - be a `"//"` separated string of each field in the list. Each field in the list (again either a string - literal or a input column reference) is interpreted either as the specified string literal or as the - value present in the input column. If an input column is missing in the file, an error will be raised. If - a row has a null value for a specified input column, that field will be converted to the string `"UNK"` - in the output code. + `relative_table_file_stem.event_name.code` field. This field can be a string literal, a reference to an + input column (denoted by the `col(...)` syntax), or a list of same. If it is a list, the output code will + be a `"//"` separated string of each field in the list. Each field in the list (again either a string + literal or a input column reference) is interpreted either as the specified string literal or as the + value present in the input column. If an input column is missing in the file, an error will be raised. If + a row has a null value for a specified input column, that field will be converted to the string `"UNK"` + in the output code. 2. The time of the output MEDS observation will either be `null` (corresponding to static events) or - will be read from the column specified via the input. 
Time columns must either be in a datetime or date - format in the input data, or a string format that can be converted to a time via the optional `time_format` - key, which is either a string literal format or a list of formats to try in priority order. + will be read from the column specified via the input. Time columns must either be in a datetime or date + format in the input data, or a string format that can be converted to a time via the optional `time_format` + key, which is either a string literal format or a list of formats to try in priority order. 3. All subsequent keys and values in the event conversion block will be extracted as MEDS output column - names by directly copying from the input data columns given. There is no need to use a `col(...)` syntax - here, as string literals _cannot_ be used for these columns. + names by directly copying from the input data columns given. There is no need to use a `col(...)` syntax + here, as string literals _cannot_ be used for these columns. There are several more nuanced aspects to the configuration file that have not yet been discussed. First, the configuration file also specifies how to identify the patient ID from the raw data. This can be done either by @@ -217,8 +217,8 @@ include: 1. Extracting numeric values from free-text values in the dataset. 2. Splitting compound measurements into their constituent parts (e.g., splitting a "blood pressure" - measurement that is recorded in the raw data as "120/80" into separate "systolic" and "diastolic" blood - pressure measurements). + measurement that is recorded in the raw data as "120/80" into separate "systolic" and "diastolic" blood + pressure measurements). 3. Removing known technical errors in the raw data based on local data expertise. ## FAQ @@ -245,14 +245,14 @@ extraction pipeline. Note that this tool is _not_: 1. A specialized tool for a particular raw data source or a particular source common data model (e.g., OMOP, - i2b2, etc.). It is a general-purpose tool that can be used to extract general raw data sources into a - MEDS dataset. There may be more specialized tools available for dedicated CDMs or public data sources. - See TODO for a detailed list. + i2b2, etc.). It is a general-purpose tool that can be used to extract general raw data sources into a + MEDS dataset. There may be more specialized tools available for dedicated CDMs or public data sources. + See TODO for a detailed list. 2. A universal tool that can be used to extract _any_ raw data source into a MEDS dataset. It is a tool that - can be used to extract _many_ raw data sources into a MEDS datasets, but no tool can be universally - applicable to _all_ raw data sources. If you think your raw data source does not sufficiently conform to - the assumptions of the MEDS Extract tool (see below), you may need to write a custom extraction tool for - your raw data. Feel free to reach out if you have any questions or concerns about this. + can be used to extract _many_ raw data sources into a MEDS datasets, but no tool can be universally + applicable to _all_ raw data sources. If you think your raw data source does not sufficiently conform to + the assumptions of the MEDS Extract tool (see below), you may need to write a custom extraction tool for + your raw data. Feel free to reach out if you have any questions or concerns about this. 
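
As a concrete illustration of the code-construction rules described earlier in this README (string literals and `col(...)` references joined by `"//"`, with null input values mapped to `"UNK"`), here is a minimal, hypothetical polars sketch. It is not the package's actual implementation, and the example frame and column names are invented:

```python
import polars as pl


def build_code_expr(code_spec: str | list[str]) -> pl.Expr:
    """Sketch of the described rule: join literal and col(...) parts with "//", mapping nulls to "UNK"."""
    parts = [code_spec] if isinstance(code_spec, str) else list(code_spec)
    exprs = []
    for part in parts:
        if part.startswith("col(") and part.endswith(")"):
            # A column reference: read the value from the input column, filling nulls with "UNK".
            exprs.append(pl.col(part[4:-1]).cast(pl.Utf8).fill_null("UNK"))
        else:
            # A string literal: used verbatim in the output code.
            exprs.append(pl.lit(part, dtype=pl.Utf8))
    return pl.concat_str(exprs, separator="//").alias("code")


df = pl.DataFrame({"lab_name": ["HR", None], "lab_value": [88.0, 120.0]})
print(df.with_columns(build_code_expr(["LAB", "col(lab_name)"])))
# The resulting "code" column contains "LAB//HR" and "LAB//UNK".
```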
## Future Improvements and Roadmap From 39c626e20944a43dace645593e4678a2f4a00938 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 15:00:36 -0400 Subject: [PATCH 21/53] Added revision dates and authors to docs --- mkdocs.yml | 2 ++ pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/mkdocs.yml b/mkdocs.yml index 02bd3d6..33de7a6 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -41,3 +41,5 @@ plugins: nav_file: SUMMARY.md - section-index - mkdocstrings + - git-authors + - git-revision-date-localized diff --git a/pyproject.toml b/pyproject.toml index 55c1583..3cb078c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ local_parallelism = ["hydra-joblib-launcher"] slurm_parallelism = ["hydra-submitit-launcher"] docs = [ "mkdocs==1.6.0", "mkdocs-material==9.5.31", "mkdocstrings[python,shell]==0.25.2", "mkdocs-gen-files==0.5.0", - "mkdocs-literate-nav==0.6.1", "mkdocs-section-index==0.3.9" + "mkdocs-literate-nav==0.6.1", "mkdocs-section-index==0.3.9", "mkdocs-git-authors-plugin==0.9.0", + "mkdocs-git-revision-date-localized-plugin==1.2.6" ] [project.scripts] From 44f18b77e0f3f4160c2c564f30b6fbf455c765e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Thu, 8 Aug 2024 18:45:30 -0400 Subject: [PATCH 22/53] Add warning when split_fracs_dict not empty but performing the split solely based on external_splits --- .../extract/split_and_shard_patients.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index f9c2ee3..28c8d51 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -76,6 +76,12 @@ def shard_patients[ Traceback (most recent call last): ... ValueError: Unable to adjust splits to ensure all splits have at least 1 patient. + >>> external_splits = { + ... 'train': np.array([1, 2, 3, 4, 5], dtype=int), + ... 'test': np.array([6, 7, 8, 9, 10], dtype=int), + ... } + >>> shard_patients(patients, 5, external_splits) + {'train/0': [1, 2, 3, 4, 5], 'test/0': [6, 7, 8, 9, 10]} """ if sum(split_fracs_dict.values()) != 1: @@ -125,10 +131,13 @@ def shard_patients[ splits = {**{k: v for k, v in zip(split_names, patients_per_split)}, **splits} else: - logger.info( - "The external split definition covered all patients. No need to perform an " - "additional patient split." - ) + if split_fracs_dict: + logger.warning( + "External splits were provided covering all patients, but split_fracs_dict was not empty. " + "Ignoring the split_fracs_dict." 
+ ) + else: + logger.info("External splits were provided covering all patients.") # Sharding final_shards = {} @@ -236,7 +245,7 @@ def main(cfg: DictConfig): if not external_splits_json_fp.exists(): raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}") - logger.info(f"Reading external splits from {cfg.stage_cfg.external_splits_json_fp}") + logger.info(f"Reading external splits from {str(cfg.stage_cfg.external_splits_json_fp.resolve())}") external_splits = json.loads(external_splits_json_fp.read_text()) size_strs = ", ".join(f"{k}: {len(v)}" for k, v in external_splits.items()) From 2e1f875bef9153c7248e1b6aceddea139f93745b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Thu, 8 Aug 2024 19:11:11 -0400 Subject: [PATCH 23/53] Throw ValueError when external split lengths contradict n_patient_per_shar restriction --- .../extract/split_and_shard_patients.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 28c8d51..6f8f952 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -77,11 +77,15 @@ def shard_patients[ ... ValueError: Unable to adjust splits to ensure all splits have at least 1 patient. >>> external_splits = { - ... 'train': np.array([1, 2, 3, 4, 5], dtype=int), - ... 'test': np.array([6, 7, 8, 9, 10], dtype=int), + ... 'train': np.array([1, 2, 3, 4, 5, 6], dtype=int), + ... 'test': np.array([7, 8, 9, 10], dtype=int), ... } - >>> shard_patients(patients, 5, external_splits) - {'train/0': [1, 2, 3, 4, 5], 'test/0': [6, 7, 8, 9, 10]} + >>> shard_patients(patients, 6, external_splits) + {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} + >>> shard_patients(patients, 3, external_splits) + Traceback (most recent call last): + ... + ValueError: External splits must have fewer patients than n_patients_per_shard (3): len(train)=6, ... """ if sum(split_fracs_dict.values()) != 1: @@ -97,6 +101,14 @@ def shard_patients[ f"Attempting to convert to numpy array of dtype {patients.dtype}." 
) external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) + if too_lengthy_external_splits := { + k: len(v) for k, v in external_splits.items() if len(v) > n_patients_per_shard + }: + raise ValueError( + f"External splits must have fewer patients than n_patients_per_shard " + f"({n_patients_per_shard}): " + + ", ".join(f"len({k})={v}" for k, v in too_lengthy_external_splits.items()) + ) patients = np.unique(patients) From 2fbb2934c22f380849a7406ee8fc275480f5cc04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Thu, 8 Aug 2024 19:14:07 -0400 Subject: [PATCH 24/53] Add .editorconfig --- .editorconfig | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .editorconfig diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..b3c39e5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,13 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +max_line_length = 110 +tab_width = 4 + +[{*.yaml,*.yml}] +indent_size = 2 From 8839f6710eb91dfaf2b5ca9a8c8c7fd690624c35 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 19:32:57 -0400 Subject: [PATCH 25/53] Started removing unnecessary references to shards file --- src/MEDS_transforms/extract/convert_to_sharded_events.py | 4 ++-- src/MEDS_transforms/extract/extract_code_metadata.py | 2 +- src/MEDS_transforms/extract/finalize_MEDS_metadata.py | 3 ++- src/MEDS_transforms/extract/split_and_shard_patients.py | 2 +- src/MEDS_transforms/utils.py | 9 +++------ 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index e5eb514..8d6a01e 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -685,9 +685,9 @@ def main(cfg: DictConfig): file. """ - input_dir, patient_subsharded_dir, metadata_input_dir, shards_map_fn = stage_init(cfg) + input_dir, patient_subsharded_dir, metadata_input_dir = stage_init(cfg) - shards = json.loads(shards_map_fn.read_text()) + shards = json.loads(Path(cfg.shards_map_fp).read_text()) event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) if not event_conversion_cfg_fp.exists(): diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index 5e831be..db08f7b 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -349,7 +349,7 @@ def main(cfg: DictConfig): schema. """ - stage_input_dir, partial_metadata_dir, _, _ = stage_init(cfg) + stage_input_dir, partial_metadata_dir, _ = stage_init(cfg) raw_input_dir = Path(cfg.input_dir) event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index f8486c2..366d89a 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -146,7 +146,7 @@ def main(cfg: DictConfig): logger.info("Non-zero worker found in reduce-only stage. 
Exiting") return - _, _, input_metadata_dir, shards_map_fp = stage_init(cfg) + _, _, input_metadata_dir = stage_init(cfg) output_metadata_dir = Path(cfg.stage_cfg.reducer_output_dir) output_code_metadata_fp = output_metadata_dir / "codes.parquet" @@ -193,6 +193,7 @@ def main(cfg: DictConfig): dataset_metadata_fp.write_text(json.dumps(dataset_metadata)) # Split creation + shards_map_fp = Path(cfg.shards_map_fp) logger.info("Creating patient splits from {str(shards_map_fp.resolve())}") shards_map = json.loads(shards_map_fp.read_text()) patient_splits = [] diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 3a4e1d5..abe5a11 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -186,7 +186,7 @@ def main(cfg: DictConfig): ensure that split fractions sum to 1. """ - subsharded_dir, MEDS_cohort_dir, _, _ = stage_init(cfg) + subsharded_dir, MEDS_cohort_dir, _ = stage_init(cfg) event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) if not event_conversion_cfg_fp.exists(): diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index bc7f036..3e847d4 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -59,15 +59,14 @@ def write_lazyframe(df: pl.LazyFrame, out_fp: Path) -> None: df.write_parquet(out_fp, use_pyarrow=True) -def stage_init(cfg: DictConfig): +def stage_init(cfg: DictConfig) -> tuple[Path, Path, Path]: """Initializes the stage by logging the configuration and the stage-specific paths. Args: cfg: The global configuration object, which should have a ``cfg.stage_cfg`` attribute containing the stage specific configuration. - Returns: The data input directory, stage output directory, metadata input directory, and the shards file - path. + Returns: The data input directory, stage output directory, and metadata input directory. """ hydra_loguru_init() @@ -78,7 +77,6 @@ def stage_init(cfg: DictConfig): input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) metadata_input_dir = Path(cfg.stage_cfg.metadata_input_dir) - shards_map_fp = Path(cfg.shards_map_fp) def chk(x: Path): return "✅" if x.exists() else "❌" @@ -89,7 +87,6 @@ def chk(x: Path): "input_dir": input_dir, "output_dir": output_dir, "metadata_input_dir": metadata_input_dir, - "shards_map_fp": shards_map_fp, }.items() ] @@ -99,7 +96,7 @@ def chk(x: Path): ] logger.debug("\n".join(logger_strs + paths_strs)) - return input_dir, output_dir, metadata_input_dir, shards_map_fp + return input_dir, output_dir, metadata_input_dir def get_package_name() -> str: From ba6b5b660fc1673a513d314bc3b16ddfb54ab860 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 22:57:59 -0400 Subject: [PATCH 26/53] Removed usage of shards json file throughout, outside of the direct extraction pipeline where it is guaranteed to be created. 
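
The central change in this commit, visible in the `shard_iterator` rewrite below, is that shards are
discovered from the files actually present on disk rather than read from a shards JSON map. A minimal
standalone sketch of that discovery step (simplified relative to the real function, which also handles
per-worker ordering, `train_only` filtering, and output-path pairing):

```python
from pathlib import Path


def discover_shards(data_dir: Path, suffix: str = ".parquet") -> list[str]:
    # Derive shard names such as "train/0" from the parquet files found under the stage's
    # data input directory, with no dependency on a separate .shards.json file.
    return sorted(
        str(fp.relative_to(data_dir))[: -len(suffix)] for fp in data_dir.glob(f"**/*{suffix}")
    )
```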
--- src/MEDS_transforms/configs/extract.yaml | 2 +- src/MEDS_transforms/configs/preprocess.yaml | 3 - .../split_and_shard_patients.yaml | 2 +- .../extract/merge_to_MEDS_cohort.py | 33 ++- .../extract/split_and_shard_patients.py | 12 +- src/MEDS_transforms/mapreduce/mapper.py | 48 ++-- src/MEDS_transforms/mapreduce/utils.py | 252 ++++++++++++------ src/MEDS_transforms/utils.py | 9 +- tests/test_extract.py | 6 +- 9 files changed, 246 insertions(+), 121 deletions(-) diff --git a/src/MEDS_transforms/configs/extract.yaml b/src/MEDS_transforms/configs/extract.yaml index 377c39c..b0f2d50 100644 --- a/src/MEDS_transforms/configs/extract.yaml +++ b/src/MEDS_transforms/configs/extract.yaml @@ -30,7 +30,7 @@ event_conversion_config_fp: ??? # The code modifier columns are in this pipeline only used in the aggregate_code_metadata stage. code_modifiers: null # The shards mapping is stored in the root of the final output directory. -shards_map_fp: "${cohort_dir}/splits.json" +shards_map_fp: "${cohort_dir}/metadata/.shards.json" stages: - shard_events diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/preprocess.yaml index 8a30b43..6b2dc8b 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/preprocess.yaml @@ -21,9 +21,6 @@ etl_metadata.pipeline_name: "preprocess" # tokenization. code_modifiers: ??? -# The shards map filepath is stored in the global input directory for model-specific pre-processing. -shards_map_fp: "${input_dir}/splits.json" - # Pipeline Structure stages: - filter_patients diff --git a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml b/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml index b56aeb1..c4015bd 100644 --- a/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml +++ b/src/MEDS_transforms/configs/stage_configs/split_and_shard_patients.yaml @@ -1,6 +1,6 @@ split_and_shard_patients: is_metadata: True - output_dir: ${cohort_dir} + output_dir: ${cohort_dir}/metadata n_patients_per_shard: 50000 external_splits_json_fp: null split_fracs: diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index b9e638a..df8f301 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +import json +import random from functools import partial from pathlib import Path @@ -8,7 +10,7 @@ from omegaconf import DictConfig, OmegaConf from MEDS_transforms.extract import CONFIG_YAML -from MEDS_transforms.mapreduce.mapper import map_over, shard_iterator +from MEDS_transforms.mapreduce.mapper import map_over def merge_subdirs_and_sort( @@ -234,10 +236,37 @@ def main(cfg: DictConfig): additional_sort_by=cfg.stage_cfg.get("additional_sort_by", None), ) + shard_map_fp = Path(cfg.shards_map_fp) + if not shard_map_fp.exists(): + raise FileNotFoundError(f"Shard map file not found at {str(shard_map_fp.resolve())}") + + shards = list(json.loads(shard_map_fp.read_text()).keys()) + + def shard_iterator(cfg: DictConfig) -> tuple[list[str], bool]: + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + if cfg.stage_cfg.get("train_only", None): + raise ValueError("train_only is not supported for this stage.") + + if "worker" in cfg: + random.seed(cfg.worker) + random.shuffle(shards) + + logger.info(f"Mapping computation over a maximum of {len(shards)} shards") + + out = [] + for sh 
in shards: + in_fp = input_dir / sh + out_fp = output_dir / f"{sh}.parquet" + out.append((in_fp, out_fp)) + + return out, False + map_over( cfg, read_fn=read_fn, - shard_iterator_fntr=partial(shard_iterator, in_suffix=""), + shard_iterator_fntr=shard_iterator, ) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index abe5a11..ab45a64 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -186,7 +186,7 @@ def main(cfg: DictConfig): ensure that split fractions sum to 1. """ - subsharded_dir, MEDS_cohort_dir, _ = stage_init(cfg) + subsharded_dir, _, _ = stage_init(cfg) event_conversion_cfg_fp = Path(cfg.event_conversion_config_fp) if not event_conversion_cfg_fp.exists(): @@ -249,11 +249,11 @@ def main(cfg: DictConfig): seed=cfg.seed, ) - logger.info(f"Writing sharded patients to {MEDS_cohort_dir}") - MEDS_cohort_dir.mkdir(parents=True, exist_ok=True) - out_fp = MEDS_cohort_dir / "splits.json" - out_fp.write_text(json.dumps(sharded_patients)) - logger.info(f"Done writing sharded patients to {out_fp}") + shards_map_fp = Path(cfg.shards_map_fp) + logger.info(f"Writing sharded patients to {str(shards_map_fp.resolve())}") + shards_map_fp.parent.mkdir(parents=True, exist_ok=True) + shards_map_fp.write_text(json.dumps(sharded_patients)) + logger.info("Done writing sharded patients") if __name__ == "__main__": diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 3d675b0..6b637f3 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -157,27 +157,33 @@ def map_over( start = datetime.now() - process_split = cfg.stage_cfg.get("process_split", None) + train_only = cfg.stage_cfg.get("train_only", False) split_fp = Path(cfg.input_dir) / "metadata" / "patient_split.parquet" - shards_map_fp = Path(cfg.shards_map_fp) if "shards_map_fp" in cfg else None - if process_split and split_fp.exists(): - split_patients = ( - pl.scan_parquet(split_fp) - .filter(pl.col("split") == process_split) - .select(pl.col("patient_id")) - .collect() - .to_list() - ) - read_fn = read_and_filter_fntr(split_patients, read_fn) - elif process_split and shards_map_fp and shards_map_fp.exists(): - logger.warning( - f"Split {process_split} requested, but no patient split file found at {str(split_fp)}. " - f"Assuming this is handled through shard filtering." - ) - elif process_split: - raise ValueError( - f"Split {process_split} requested, but no patient split file found at {str(split_fp)}." - ) + + shards, includes_only_train = shard_iterator_fntr(cfg) + + if train_only: + if includes_only_train: + logger.info( + f"Processing train split only via shard prefix. Not filtering with {str(split_fp.resolve())}." + ) + elif split_fp.exists(): + logger.info(f"Processing train split only by filtering read dfs via {str(split_fp.resolve())}") + train_patients = ( + pl.scan_parquet(split_fp) + .filter(pl.col("split") == "train") + .select(pl.col("patient_id")) + .collect() + .to_list() + ) + read_fn = read_and_filter_fntr(train_patients, read_fn) + else: + raise FileNotFoundError( + f"Train split requested, but shard prefixes can't be used and " + f"patient split file not found at {str(split_fp.resolve())}." 
+ ) + elif includes_only_train: + raise ValueError("All splits should be used, but shard iterator is returning only train splits?!?") def fntr_params(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnArgs: compute_fn_params = inspect.signature(compute_fn).parameters @@ -209,7 +215,7 @@ def fntr_params(compute_fn: ANY_COMPUTE_FN_T) -> ComputeFnArgs: raise ValueError("Invalid compute function") all_out_fps = [] - for in_fp, out_fp in shard_iterator_fntr(cfg): + for in_fp, out_fp in shards: logger.info(f"Processing {str(in_fp.resolve())} into {str(out_fp.resolve())}") rwlock_wrap( in_fp, diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 77cbe7d..690568e 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -223,111 +223,204 @@ def rwlock_wrap[ def shard_iterator( cfg: DictConfig, - in_suffix: str = ".parquet", out_suffix: str = ".parquet", in_prefix: str = "", - out_prefix: str = "", -): - """Provides a generator that yields shard input and output files for mapreduce operations. +) -> tuple[list[tuple[Path, Path]], bool]: + """Returns a list of the shards found in the input directory and their corresponding output directories. Args: cfg: The configuration dictionary for the overall pipeline. Should (possibly) contain the following keys (some are optional, as marked below): - - ``stage_cfg.data_input_dir`` (mandatory): The directory containing the input data. - - ``stage_cfg.output_dir`` (mandatory): The directory to write the output data. - - ``shards_map_fp`` (mandatory): The file path to the shards map JSON file. - - ``stage_cfg.process_split`` (optional): The prefix of the shards to process (e.g., - ``"train/"``). If not provided, all shards will be processed. - - ``worker`` (optional): The worker ID for the MR worker; this is also used to seed the + - `stage_cfg.data_input_dir` (mandatory): The directory containing the input data. + - `stage_cfg.output_dir` (mandatory): The directory to write the output data. + - `stage_cfg.train_only` (optional): The prefix of the shards to process (e.g., + `"train/"`). If not provided, all shards will be processed. + - `worker` (optional): The worker ID for the MR worker; this is also used to seed the randomization process. If not provided, the randomization process is unseeded. - in_suffix: The suffix of the input files. Defaults to ".parquet". This can be set to "" to process - entire directories. out_suffix: The suffix of the output files. Defaults to ".parquet". - in_prefix: The prefix of the input files. Defaults to "". This can be used to load files from a - subdirectory of the input directory by including a "/" at the end of the prefix. - out_prefix: The prefix of the output files. Defaults to "". + in_prefix: The prefix of the input files. Defaults to "". This must be a full path component. It can + end with a slash but even if it doesn't it will be interpreted as a full path component. Yields: Randomly shuffled pairs of input and output file paths for each shard. The randomization process is seeded by the worker ID in ``cfg``, if provided, otherwise it is left unseeded. Examples: - >>> from tempfile import NamedTemporaryFile - >>> shards = {"train/0": [1, 2, 3], "train/1": [4, 5, 6], "held_out": [4, 5, 6], "foo": [5]} - >>> with NamedTemporaryFile() as tmp: - ... _ = Path(tmp.name).write_text(json.dumps(shards)) + >>> from tempfile import TemporaryDirectory + >>> import polars as pl + >>> df = pl.DataFrame({ + ... 
"patient_id": [1, 2, 3, 4, 5, 6, 7, 8, 9], + ... "code": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], + ... "time": [1, 2, 3, 4, 5, 6, 1, 2, 3], + ... }) + >>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]} + >>> def write_dfs(input_dir: Path, df: pl.DataFrame=df, shards: dict=shards, sfx: str=".parquet"): + ... for shard_name, patient_ids in shards.items(): + ... df = df.filter(pl.col("patient_id").is_in(patient_ids)) + ... shard_fp = input_dir / f"{shard_name}{sfx}" + ... shard_fp.parent.mkdir(exist_ok=True, parents=True) + ... if sfx == ".parquet": df.write_parquet(shard_fp) + ... elif sfx == ".csv": df.write_csv(shard_fp) + ... else: raise ValueError(f"Unsupported suffix {sfx}") + ... return + + By default, this will load all shards in the input directory and write specify their appropriate output + directories: + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir) + ... cfg = DictConfig({ + ... "stage_cfg": {"data_input_dir": str(input_dir), "output_dir": str(output_dir)}, + ... "worker": 1, + ... }) + ... fps, includes_only_train = shard_iterator(cfg) + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), + (PosixPath('data/tuning.parquet'), PosixPath('output/tuning.parquet')), + (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), + (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet'))] + >>> includes_only_train + False + + Different workers will shuffle the shards differently: + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir) ... cfg = DictConfig({ - ... "stage_cfg": {"data_input_dir": "data/", "output_dir": "output/"}, - ... "shards_map_fp": tmp.name, "worker": 1, + ... "stage_cfg": {"data_input_dir": str(input_dir), "output_dir": str(output_dir)}, + ... "worker": 2, ... }) - ... gen = shard_iterator(cfg) - ... list(gen) # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/foo.parquet'), PosixPath('output/foo.parquet')), + ... fps, includes_only_train = shard_iterator(cfg) + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), + (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), - (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), - (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet'))] - >>> with NamedTemporaryFile() as tmp: - ... _ = Path(tmp.name).write_text(json.dumps(shards)) + (PosixPath('data/tuning.parquet'), PosixPath('output/tuning.parquet'))] + >>> includes_only_train + False + + We can also make it look within a specific input subdir of the data directory and change the output + suffix. Note that using a specific input subdir is _different_ than requesting it load only train. + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir) ... cfg = DictConfig({ - ... "stage_cfg": {"data_input_dir": "data/", "output_dir": "output/"}, - ... "shards_map_fp": tmp.name, "worker": 1, + ... 
"stage_cfg": {"data_input_dir": str(input_dir), "output_dir": str(output_dir)}, + ... "worker": 1, ... }) - ... gen = shard_iterator(cfg, in_suffix="", out_suffix=".csv", in_prefix="a/", out_prefix="b/") - ... list(gen) # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/a/foo'), PosixPath('output/b/foo.csv')), - (PosixPath('data/a/train/0'), PosixPath('output/b/train/0.csv')), - (PosixPath('data/a/held_out'), PosixPath('output/b/held_out.csv')), - (PosixPath('data/a/train/1'), PosixPath('output/b/train/1.csv'))] - >>> with NamedTemporaryFile() as tmp: - ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... fps, includes_only_train = shard_iterator(cfg, in_prefix="train", out_suffix=".csv") + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train/0.parquet'), PosixPath('output/0.csv')), + (PosixPath('data/train/1.parquet'), PosixPath('output/1.csv'))] + >>> includes_only_train + False + + We can also make it load only 'train' shards, in the case that there are shards with a valid "train/" + prefix. + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir) ... cfg = DictConfig({ ... "stage_cfg": { - ... "data_input_dir": "data/", "output_dir": "output/", "process_split": "train/" + ... "data_input_dir": str(input_dir), "output_dir": str(output_dir), + ... "train_only": True, ... }, - ... "shards_map_fp": tmp.name, ... "worker": 1, ... }) - ... gen = shard_iterator(cfg) - ... list(gen) # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), - (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet'))] - - Note that if you ask it to process a split that isn't a valid prefix of the shards, it will return all - shards and assume that is covered via a `patient_splits.parquet` file: - >>> with NamedTemporaryFile() as tmp: - ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... fps, includes_only_train = shard_iterator(cfg) + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), + (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet'))] + >>> includes_only_train + True + + The train prefix used is precisely `train/` -- other uses of train will not work: + >>> wrong_pfx_shards = {"train": [1, 2, 3], "train_1": [4, 5, 6], "train-2": [7, 8, 9]} + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir, shards=wrong_pfx_shards) ... cfg = DictConfig({ - ... "stage_cfg": {"data_input_dir": "data/", "output_dir": "output/"}, - ... "shards_map_fp": tmp.name, "process_split": "nonexisting", "worker": 1, + ... "stage_cfg": { + ... "data_input_dir": str(input_dir), "output_dir": str(output_dir), + ... "train_only": True, + ... }, + ... "worker": 1, ... }) - ... gen = shard_iterator(cfg) - ... list(gen) # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/foo.parquet'), PosixPath('output/foo.parquet')), - (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), - (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), - (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet'))] + ... 
fps, includes_only_train = shard_iterator(cfg) + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train.parquet'), PosixPath('output/train.parquet')), + (PosixPath('data/train_1.parquet'), PosixPath('output/train_1.parquet')), + (PosixPath('data/train-2.parquet'), PosixPath('output/train-2.parquet'))] + >>> includes_only_train + False + + If there are no such shards, then it loads them all and assumes the filtering will be handled via the + splits parquet file. + >>> no_pfx_shards = {"0": [1, 2, 3], "1": [4, 5, 6], "2": [7, 8, 9]} + >>> with TemporaryDirectory() as tmp: + ... root = Path(tmp) + ... input_dir = root / "data" + ... output_dir = root / "output" + ... write_dfs(input_dir, shards=no_pfx_shards) + ... cfg = DictConfig({ + ... "stage_cfg": { + ... "data_input_dir": str(input_dir), "output_dir": str(output_dir), + ... "train_only": True, + ... }, + ... "worker": 1, + ... }) + ... fps, includes_only_train = shard_iterator(cfg) + >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/1.parquet'), PosixPath('output/1.parquet')), + (PosixPath('data/0.parquet'), PosixPath('output/0.parquet')), + (PosixPath('data/2.parquet'), PosixPath('output/2.parquet'))] + >>> includes_only_train + False + + If it can't find any files, it will return an empty list: + >>> fps, includes_only_train = shard_iterator(cfg) + >>> fps + [] """ input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards_map_fp = Path(cfg.shards_map_fp) if "shards_map_fp" in cfg else None - - if shards_map_fp and shards_map_fp.is_file(): - shards = json.loads(shards_map_fp.read_text()) - - if "process_split" in cfg.stage_cfg: - if any(k.startswith(cfg.stage_cfg.process_split) for k in shards): - logger.info(f'Processing shards with prefix "{cfg.stage_cfg.process_split}"') - shards = {k: v for k, v in shards.items() if k.startswith(cfg.stage_cfg.process_split)} - shards = list(shards.keys()) - else: - shards = [] - for p in input_dir.glob(f"{in_prefix}*{in_suffix}"): - relative_path = p.relative_to(input_dir) - shard_name = str(relative_path) - shard_name = shard_name[len(in_prefix) :] if in_prefix else shard_name - shard_name = shard_name[: -len(in_suffix)] if in_suffix else shard_name - shards.append(shard_name) + + in_suffix = ".parquet" + + if in_prefix: + input_dir = input_dir / in_prefix + + shards = [] + for p in input_dir.glob(f"**/*{in_suffix}"): + relative_path = p.relative_to(input_dir) + shard_name = str(relative_path) + shard_name = shard_name[: -len(in_suffix)] + shards.append(shard_name) + + # We initialize this to False and overwrite it if we find dedicated train shards. + includes_only_train = False + + train_only = cfg.stage_cfg.get("train_only", None) + train_shards = [shard_name for shard_name in shards if shard_name.startswith("train/")] + if train_only and train_shards: + shards = train_shards + includes_only_train = True + elif train_only: + logger.info( + f"train_only={train_only} requested but no dedicated train shards found; processing all shards " + "and relying on `patient_splits.parquet` for filtering." 
+ ) if "worker" in cfg: random.seed(cfg.worker) @@ -335,10 +428,11 @@ def shard_iterator( logger.info(f"Mapping computation over a maximum of {len(shards)} shards") + out = [] for sp in shards: - in_fp = input_dir / f"{in_prefix}{sp}{in_suffix}" - out_fp = output_dir / f"{out_prefix}{sp}{out_suffix}" - + in_fp = input_dir / f"{sp}{in_suffix}" + out_fp = output_dir / f"{sp}{out_suffix}" # TODO: Could add checking logic for existence of in_fp and/or out_fp here. + out.append((in_fp, out_fp)) - yield in_fp, out_fp + return out, includes_only_train diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 3e847d4..0c4c20a 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -9,7 +9,6 @@ import hydra import polars as pl from loguru import logger -from meds import train_split from omegaconf import DictConfig, OmegaConf from MEDS_transforms import __package_name__ as package_name @@ -223,7 +222,7 @@ def populate_stage( ... "stage2": {"is_metadata": True}, ... "stage3": {"is_metadata": None, "output_dir": "/g/h"}, ... "stage4": {"data_input_dir": "/e/f"}, - ... "stage5": {"aggregations": ["foo"], "process_split": None}, + ... "stage5": {"aggregations": ["foo"], "train_only": None}, ... }, ... }) >>> args = [root_config[k] for k in ["input_dir", "cohort_dir", "stages", "stage_configs"]] @@ -232,7 +231,7 @@ def populate_stage( 'output_dir': '/c/d/stage1', 'reducer_output_dir': None} >>> populate_stage("stage2", *args) # doctest: +NORMALIZE_WHITESPACE {'is_metadata': True, 'data_input_dir': '/c/d/stage1', 'metadata_input_dir': '/a/b/metadata', - 'output_dir': '/c/d/stage2', 'reducer_output_dir': '/c/d/stage2', 'process_split': 'train'} + 'output_dir': '/c/d/stage2', 'reducer_output_dir': '/c/d/stage2', 'train_only': True} >>> populate_stage("stage3", *args) # doctest: +NORMALIZE_WHITESPACE {'is_metadata': None, 'output_dir': '/g/h', 'data_input_dir': '/c/d/stage1', 'metadata_input_dir': '/c/d/stage2', 'reducer_output_dir': None} @@ -240,7 +239,7 @@ def populate_stage( {'data_input_dir': '/e/f', 'is_metadata': False, 'metadata_input_dir': '/c/d/stage2', 'output_dir': '/c/d/stage4', 'reducer_output_dir': None} >>> populate_stage("stage5", *args) # doctest: +NORMALIZE_WHITESPACE - {'aggregations': ['foo'], 'process_split': None, 'is_metadata': True, 'data_input_dir': '/c/d/stage4', + {'aggregations': ['foo'], 'train_only': None, 'is_metadata': True, 'data_input_dir': '/c/d/stage4', 'metadata_input_dir': '/c/d/stage2', 'output_dir': '/c/d/stage5', 'reducer_output_dir': '/c/d/metadata'} >>> populate_stage("stage6", *args) # doctest: +NORMALIZE_WHITESPACE @@ -350,7 +349,7 @@ def populate_stage( } if is_metadata: - inferred_keys["process_split"] = train_split + inferred_keys["train_only"] = True out = {**stage} for key, val in inferred_keys.items(): diff --git a/tests/test_extract.py b/tests/test_extract.py index f63b9c0..b59ae02 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -441,10 +441,10 @@ def test_extraction(): all_stdouts.append(stdout) try: - splits_fp = MEDS_cohort_dir / "splits.json" - assert splits_fp.is_file(), f"Expected splits @ {str(splits_fp.resolve())} to exist." + shards_fp = MEDS_cohort_dir / "metadata" / ".shards.json" + assert shards_fp.is_file(), f"Expected splits @ {str(shards_fp.resolve())} to exist." 
- splits = json.loads(splits_fp.read_text()) + splits = json.loads(shards_fp.read_text()) expected_keys = ["train/0", "train/1", "tuning/0", "held_out/0"] expected_keys_str = ", ".join(f"'{k}'" for k in expected_keys) From d86fe1e9f71a03cf0d6a0284bcd9057fd793f60e Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 23:15:40 -0400 Subject: [PATCH 27/53] Used a hash function instead of true randomization to order the shards in the iterator for test stability in github workflows --- src/MEDS_transforms/mapreduce/utils.py | 44 ++++++++++++++++---------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 690568e..e94df8d 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -1,7 +1,7 @@ """Basic utilities for parallelizable mapreduces on sharded MEDS datasets with caching and locking.""" +import hashlib import json -import random from collections.abc import Callable from datetime import datetime from pathlib import Path @@ -236,7 +236,7 @@ def shard_iterator( - `stage_cfg.train_only` (optional): The prefix of the shards to process (e.g., `"train/"`). If not provided, all shards will be processed. - `worker` (optional): The worker ID for the MR worker; this is also used to seed the - randomization process. If not provided, the randomization process is unseeded. + randomization process. If not provided, the randomization process will be unseeded. out_suffix: The suffix of the output files. Defaults to ".parquet". in_prefix: The prefix of the input files. Defaults to "". This must be a full path component. It can end with a slash but even if it doesn't it will be interpreted as a full path component. @@ -277,10 +277,10 @@ def shard_iterator( ... }) ... fps, includes_only_train = shard_iterator(cfg) >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), + [(PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), + (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), (PosixPath('data/tuning.parquet'), PosixPath('output/tuning.parquet')), - (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), - (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet'))] + (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet'))] >>> includes_only_train False @@ -296,10 +296,10 @@ def shard_iterator( ... }) ... fps, includes_only_train = shard_iterator(cfg) >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), + [(PosixPath('data/tuning.parquet'), PosixPath('output/tuning.parquet')), + (PosixPath('data/held_out.parquet'), PosixPath('output/held_out.parquet')), (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), - (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), - (PosixPath('data/tuning.parquet'), PosixPath('output/tuning.parquet'))] + (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet'))] >>> includes_only_train False @@ -337,8 +337,8 @@ def shard_iterator( ... }) ... 
fps, includes_only_train = shard_iterator(cfg) >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet')), - (PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet'))] + [(PosixPath('data/train/1.parquet'), PosixPath('output/train/1.parquet')), + (PosixPath('data/train/0.parquet'), PosixPath('output/train/0.parquet'))] >>> includes_only_train True @@ -358,9 +358,9 @@ def shard_iterator( ... }) ... fps, includes_only_train = shard_iterator(cfg) >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/train.parquet'), PosixPath('output/train.parquet')), - (PosixPath('data/train_1.parquet'), PosixPath('output/train_1.parquet')), - (PosixPath('data/train-2.parquet'), PosixPath('output/train-2.parquet'))] + [(PosixPath('data/train_1.parquet'), PosixPath('output/train_1.parquet')), + (PosixPath('data/train-2.parquet'), PosixPath('output/train-2.parquet')), + (PosixPath('data/train.parquet'), PosixPath('output/train.parquet'))] >>> includes_only_train False @@ -381,8 +381,8 @@ def shard_iterator( ... }) ... fps, includes_only_train = shard_iterator(cfg) >>> [(i.relative_to(root), o.relative_to(root)) for i, o in fps] # doctest: +NORMALIZE_WHITESPACE - [(PosixPath('data/1.parquet'), PosixPath('output/1.parquet')), - (PosixPath('data/0.parquet'), PosixPath('output/0.parquet')), + [(PosixPath('data/0.parquet'), PosixPath('output/0.parquet')), + (PosixPath('data/1.parquet'), PosixPath('output/1.parquet')), (PosixPath('data/2.parquet'), PosixPath('output/2.parquet'))] >>> includes_only_train False @@ -423,8 +423,18 @@ def shard_iterator( ) if "worker" in cfg: - random.seed(cfg.worker) - random.shuffle(shards) + add_str = str(cfg.worker) + else: + add_str = str(datetime.now()) + + shard_keys = [] + for shard in shards: + shard_hash = hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest() + if shard_hash in shard_keys: + raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!") + shard_keys.append(int(shard_hash, 16)) + + shards = [shard for _, shard in sorted(zip(shard_keys, shards))] logger.info(f"Mapping computation over a maximum of {len(shards)} shards") From 86d1acdcd6adf57e3a885d0ca33d64a22e0a0808 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Thu, 8 Aug 2024 23:36:34 -0400 Subject: [PATCH 28/53] Made it so that docs use submodule README files as the source documentation for the submodule page. 
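
Concretely, for a submodule that ships a README -- `MEDS_transforms/extract` is used here purely as an
example -- the `md_file_lines` assembled by the updated script below would contain a snippet include
followed by the mkdocstrings directive. This assumes the pymdownx snippets extension (which provides the
`--8<--` include syntax) is enabled in the mkdocs configuration:

```python
# Hypothetical value of md_file_lines for src/MEDS_transforms/extract/__init__.py, given that
# src/MEDS_transforms/extract/README.md exists on disk:
md_file_lines = [
    '--8<-- "src/MEDS_transforms/extract/README.md"',  # pull in the README's prose first
    "::: MEDS_transforms.extract",  # then render the module's API reference via mkdocstrings
]
print("\n".join(md_file_lines))
```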
--- docs/gen_ref_pages.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index d38acd7..c4d35c8 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -16,18 +16,26 @@ parts = tuple(module_path.parts) + md_file_lines = [] + if parts[-1] == "__init__": parts = parts[:-1] doc_path = doc_path.with_name("index.md") full_doc_path = full_doc_path.with_name("index.md") + + readme_path = Path("/".join(parts + ("README.md",))) + if (src / readme_path).exists(): + md_file_lines.append(f'--8<-- "src/{str(readme_path)}"') elif parts[-1] == "__main__": continue nav[parts] = doc_path.as_posix() + ident = ".".join(parts) + md_file_lines.append(f"::: {ident}") + with mkdocs_gen_files.open(full_doc_path, "w") as fd: - ident = ".".join(parts) - fd.write(f"::: {ident}") + fd.write("\n".join(md_file_lines)) mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) From 546c7636fd97676357f5d31652eb4b809e310e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Renc?= Date: Fri, 9 Aug 2024 13:23:38 -0400 Subject: [PATCH 29/53] Allow resharding of external splits, move rng declaration out of the if statement --- .../extract/split_and_shard_patients.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 6f8f952..7b4f40b 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -83,9 +83,7 @@ def shard_patients[ >>> shard_patients(patients, 6, external_splits) {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} >>> shard_patients(patients, 3, external_splits) - Traceback (most recent call last): - ... - ValueError: External splits must have fewer patients than n_patients_per_shard (3): len(train)=6, ... + {'train/0': [5, 1, 3], 'train/1': [2, 6, 4], 'test/0': [10, 7], 'test/1': [8, 9]} """ if sum(split_fracs_dict.values()) != 1: @@ -101,14 +99,6 @@ def shard_patients[ f"Attempting to convert to numpy array of dtype {patients.dtype}." ) external_splits[k] = np.array(external_splits[k], dtype=patients.dtype) - if too_lengthy_external_splits := { - k: len(v) for k, v in external_splits.items() if len(v) > n_patients_per_shard - }: - raise ValueError( - f"External splits must have fewer patients than n_patients_per_shard " - f"({n_patients_per_shard}): " - + ", ".join(f"len({k})={v}" for k, v in too_lengthy_external_splits.items()) - ) patients = np.unique(patients) @@ -119,8 +109,8 @@ def shard_patients[ splits = external_splits + rng = np.random.default_rng(seed) if n_patients := len(patient_ids_to_split): - rng = np.random.default_rng(seed) split_names_idx = rng.permutation(len(split_fracs_dict)) split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] split_fracs = np.array([split_fracs_dict[k] for k in split_names]) From f2d2d0e3462f980ae4079d5c03df0f81a4b24df5 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 9 Aug 2024 14:50:48 -0400 Subject: [PATCH 30/53] Fixed lint errors. 
--- .pre-commit-config.yaml | 6 +----- eICU_Example/configs/table_preprocessors.yaml | 3 +-- src/MEDS_transforms/extract/convert_to_sharded_events.py | 7 +------ tests/test_add_time_derived_measurements.py | 5 +---- tests/test_filter_measurements.py | 5 +---- tests/test_occlude_outliers.py | 5 +---- tests/test_reorder_measurements.py | 5 +---- 7 files changed, 7 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b188f48..98bc99d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -126,8 +126,4 @@ repos: - id: nbqa-isort args: ["--profile=black"] - id: nbqa-flake8 - args: - [ - "--extend-ignore=E203,E402,E501,F401,F841", - "--exclude=logs/*,data/*", - ] + args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"] diff --git a/eICU_Example/configs/table_preprocessors.yaml b/eICU_Example/configs/table_preprocessors.yaml index 3faf4aa..a3ad2c3 100644 --- a/eICU_Example/configs/table_preprocessors.yaml +++ b/eICU_Example/configs/table_preprocessors.yaml @@ -2,8 +2,7 @@ admissiondx: offset_col: "admitdxenteredoffset" pseudotime_col: "admitDxEnteredTimestamp" output_data_cols: ["admitdxname", "admitdxid"] - warning_items: - ["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"] + warning_items: ["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"] allergy: offset_col: "allergyenteredoffset" diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index 8d6a01e..ee4e9d7 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -17,12 +17,7 @@ from MEDS_transforms.extract import CONFIG_YAML from MEDS_transforms.extract.shard_events import META_KEYS from MEDS_transforms.mapreduce.mapper import rwlock_wrap -from MEDS_transforms.utils import ( - is_col_field, - parse_col_field, - stage_init, - write_lazyframe, -) +from MEDS_transforms.utils import is_col_field, parse_col_field, stage_init, write_lazyframe def in_format(fmt: str, ts_name: str) -> pl.Expr: diff --git a/tests/test_add_time_derived_measurements.py b/tests/test_add_time_derived_measurements.py index aad8428..87138c8 100644 --- a/tests/test_add_time_derived_measurements.py +++ b/tests/test_add_time_derived_measurements.py @@ -5,10 +5,7 @@ """ -from .transform_tester_base import ( - ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from .transform_tester_base import ADD_TIME_DERIVED_MEASUREMENTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs AGE_CALCULATION_STR = """ diff --git a/tests/test_filter_measurements.py b/tests/test_filter_measurements.py index 8fc901e..2243320 100644 --- a/tests/test_filter_measurements.py +++ b/tests/test_filter_measurements.py @@ -5,10 +5,7 @@ """ -from .transform_tester_base import ( - FILTER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from .transform_tester_base import FILTER_MEASUREMENTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs # This is the code metadata diff --git a/tests/test_occlude_outliers.py b/tests/test_occlude_outliers.py index 7eaeb26..2060bb7 100644 --- a/tests/test_occlude_outliers.py +++ b/tests/test_occlude_outliers.py @@ -7,10 +7,7 @@ import polars as pl -from .transform_tester_base import ( - OCCLUDE_OUTLIERS_SCRIPT, - single_stage_transform_tester, -) +from .transform_tester_base import OCCLUDE_OUTLIERS_SCRIPT, 
single_stage_transform_tester from .utils import MEDS_PL_SCHEMA, parse_meds_csvs # This is the code metadata diff --git a/tests/test_reorder_measurements.py b/tests/test_reorder_measurements.py index 2f3dd0b..305b85b 100644 --- a/tests/test_reorder_measurements.py +++ b/tests/test_reorder_measurements.py @@ -5,10 +5,7 @@ """ -from .transform_tester_base import ( - REORDER_MEASUREMENTS_SCRIPT, - single_stage_transform_tester, -) +from .transform_tester_base import REORDER_MEASUREMENTS_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs ORDERED_CODE_PATTERNS = [ From 36add00b875be4b9c967ea70001035892620a717 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 6 Aug 2024 11:53:02 -0400 Subject: [PATCH 31/53] Moved parser around. --- src/MEDS_transforms/__init__.py | 20 +++++++++++++++++++ src/MEDS_transforms/extract/README.md | 2 +- src/MEDS_transforms/extract/__init__.py | 9 ++------- .../extract/extract_code_metadata.py | 2 +- src/MEDS_transforms/{extract => }/parser.py | 0 5 files changed, 24 insertions(+), 9 deletions(-) rename src/MEDS_transforms/{extract => }/parser.py (100%) diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 89dc91c..272f852 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -1,6 +1,8 @@ from importlib.metadata import PackageNotFoundError, version from importlib.resources import files +import polars as pl + __package_name__ = "MEDS_transforms" try: __version__ = version(__package_name__) @@ -9,3 +11,21 @@ PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") + +MANDATORY_TYPES = { + "patient_id": pl.Int64, + "time": pl.Datetime("us"), + "code": pl.String, + "numeric_value": pl.Float32, + "categoric_value": pl.String, + "text_value": pl.String, +} + +DEPRECATED_NAMES = { + "numerical_value": "numeric_value", + "categorical_value": "categoric_value", + "category_value": "categoric_value", + "textual_value": "text_value", + "timestamp": "time", + "subject_id": "patient_id", +} diff --git a/src/MEDS_transforms/extract/README.md b/src/MEDS_transforms/extract/README.md index b1e9b12..47f7e7d 100644 --- a/src/MEDS_transforms/extract/README.md +++ b/src/MEDS_transforms/extract/README.md @@ -101,7 +101,7 @@ be stored in a `patient_id` column. If the patient ID column is not found, an er Second, you can also specify how to link the codes constructed for each event block to code-specific metadata in these blocks. This is done by specifying a `_metadata` block in the event block. The format of this block -is detailed in the `parser.py` file in this directory; see there for more details. You can also see +is detailed in the `../parser.py` file in this directory; see there for more details. You can also see configuration options for this block in the `tests/test_extract.py` file and in the `MIMIC-IV_Example/configs/event_config.yaml` file. diff --git a/src/MEDS_transforms/extract/__init__.py b/src/MEDS_transforms/extract/__init__.py index 60fddc2..29dfa77 100644 --- a/src/MEDS_transforms/extract/__init__.py +++ b/src/MEDS_transforms/extract/__init__.py @@ -1,6 +1,6 @@ import polars as pl -from MEDS_transforms import EXTRACT_CONFIG_YAML +from MEDS_transforms import EXTRACT_CONFIG_YAML, MANDATORY_TYPES # We set this equality explicitly here so linting does not remove an apparently "unused" import if we just # rename with "as" during the import. 
@@ -13,9 +13,4 @@ "parent_codes": pl.List(pl.String), } -MEDS_DATA_MANDATORY_TYPES = { - "patient_id": pl.Int64, - "time": pl.Datetime("us"), - "code": pl.String, - "numeric_value": pl.Float32, -} +MEDS_DATA_MANDATORY_TYPES = MANDATORY_TYPES diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index db08f7b..1b8b394 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -15,9 +15,9 @@ from MEDS_transforms.extract import CONFIG_YAML, MEDS_METADATA_MANDATORY_TYPES from MEDS_transforms.extract.convert_to_sharded_events import get_code_expr -from MEDS_transforms.extract.parser import cfg_to_expr from MEDS_transforms.extract.utils import get_supported_fp from MEDS_transforms.mapreduce.mapper import rwlock_wrap +from MEDS_transforms.parser import cfg_to_expr from MEDS_transforms.utils import stage_init, write_lazyframe diff --git a/src/MEDS_transforms/extract/parser.py b/src/MEDS_transforms/parser.py similarity index 100% rename from src/MEDS_transforms/extract/parser.py rename to src/MEDS_transforms/parser.py From 4263592aa971a50ed4b6244151e253a89490ba72 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 6 Aug 2024 12:03:04 -0400 Subject: [PATCH 32/53] Updated mandatory types to be separate from mandatory MEDS columns --- src/MEDS_transforms/__init__.py | 2 ++ src/MEDS_transforms/extract/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 272f852..99fb66f 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -12,6 +12,8 @@ PREPROCESS_CONFIG_YAML = files(__package_name__).joinpath("configs/preprocess.yaml") EXTRACT_CONFIG_YAML = files(__package_name__).joinpath("configs/extract.yaml") +MANDATORY_COLUMNS = ["patient_id", "time", "code", "numeric_value"] + MANDATORY_TYPES = { "patient_id": pl.Int64, "time": pl.Datetime("us"), diff --git a/src/MEDS_transforms/extract/__init__.py b/src/MEDS_transforms/extract/__init__.py index 29dfa77..adb15a9 100644 --- a/src/MEDS_transforms/extract/__init__.py +++ b/src/MEDS_transforms/extract/__init__.py @@ -1,6 +1,6 @@ import polars as pl -from MEDS_transforms import EXTRACT_CONFIG_YAML, MANDATORY_TYPES +from MEDS_transforms import EXTRACT_CONFIG_YAML, MANDATORY_COLUMNS, MANDATORY_TYPES # We set this equality explicitly here so linting does not remove an apparently "unused" import if we just # rename with "as" during the import. @@ -13,4 +13,4 @@ "parent_codes": pl.List(pl.String), } -MEDS_DATA_MANDATORY_TYPES = MANDATORY_TYPES +MEDS_DATA_MANDATORY_TYPES = {c: MANDATORY_TYPES[c] for c in MANDATORY_COLUMNS} From 8fc4f48b7570517e0c74946b8fc025dd3f5da513 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 9 Aug 2024 21:38:49 -0400 Subject: [PATCH 33/53] Fixed bug with matcher import. 
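`parser.py` moved from `MEDS_transforms/extract/` to the package root earlier in this series,
but `mapreduce/mapper.py` still imported the matcher helpers from the old location, so the
module could no longer be imported. The fix just retargets the relative import; the working
form is:

    from ..parser import is_matcher, matcher_to_expr

(or, equivalently, the absolute `from MEDS_transforms.parser import ...` used elsewhere).
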
--- src/MEDS_transforms/mapreduce/mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index 61ef75f..f6b20c3 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -13,7 +13,7 @@ from loguru import logger from omegaconf import DictConfig, ListConfig -from ..extract.parser import is_matcher, matcher_to_expr +from ..parser import is_matcher, matcher_to_expr from ..utils import stage_init, write_lazyframe from .utils import rwlock_wrap, shard_iterator From dd43d5a0a7109b30d03586e5e1b0e8ec46543b47 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Fri, 9 Aug 2024 21:46:03 -0400 Subject: [PATCH 34/53] Fixed typo --- src/MEDS_transforms/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/MEDS_transforms/__init__.py b/src/MEDS_transforms/__init__.py index 99fb66f..e0aaaf3 100644 --- a/src/MEDS_transforms/__init__.py +++ b/src/MEDS_transforms/__init__.py @@ -19,13 +19,13 @@ "time": pl.Datetime("us"), "code": pl.String, "numeric_value": pl.Float32, - "categoric_value": pl.String, + "categorical_value": pl.String, "text_value": pl.String, } DEPRECATED_NAMES = { "numerical_value": "numeric_value", - "categorical_value": "categoric_value", + "categoric_value": "categoric_value", "category_value": "categoric_value", "textual_value": "text_value", "timestamp": "time", From fd6e77f0629f7e55aab2301b7184e6c477f92495 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 14:33:59 -0400 Subject: [PATCH 35/53] Separated merge shards shard iterator out for ease of import. --- .../extract/merge_to_MEDS_cohort.py | 32 +--- src/MEDS_transforms/mapreduce/utils.py | 138 ++++++++++++++++-- 2 files changed, 127 insertions(+), 43 deletions(-) diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index df8f301..49a45e1 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -import json -import random from functools import partial from pathlib import Path @@ -11,6 +9,7 @@ from MEDS_transforms.extract import CONFIG_YAML from MEDS_transforms.mapreduce.mapper import map_over +from MEDS_transforms.mapreduce.utils import shard_iterator_by_shard_map def merge_subdirs_and_sort( @@ -236,37 +235,10 @@ def main(cfg: DictConfig): additional_sort_by=cfg.stage_cfg.get("additional_sort_by", None), ) - shard_map_fp = Path(cfg.shards_map_fp) - if not shard_map_fp.exists(): - raise FileNotFoundError(f"Shard map file not found at {str(shard_map_fp.resolve())}") - - shards = list(json.loads(shard_map_fp.read_text()).keys()) - - def shard_iterator(cfg: DictConfig) -> tuple[list[str], bool]: - input_dir = Path(cfg.stage_cfg.data_input_dir) - output_dir = Path(cfg.stage_cfg.output_dir) - - if cfg.stage_cfg.get("train_only", None): - raise ValueError("train_only is not supported for this stage.") - - if "worker" in cfg: - random.seed(cfg.worker) - random.shuffle(shards) - - logger.info(f"Mapping computation over a maximum of {len(shards)} shards") - - out = [] - for sh in shards: - in_fp = input_dir / sh - out_fp = output_dir / f"{sh}.parquet" - out.append((in_fp, out_fp)) - - return out, False - map_over( cfg, read_fn=read_fn, - shard_iterator_fntr=shard_iterator, + shard_iterator_fntr=shard_iterator_by_shard_map, ) diff --git a/src/MEDS_transforms/mapreduce/utils.py 
b/src/MEDS_transforms/mapreduce/utils.py index e94df8d..299f856 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -221,6 +221,44 @@ def rwlock_wrap[ raise e +def shuffle_shards(shards: list[str], cfg: DictConfig) -> list[str]: + """Shuffle the shards in a deterministic, pseudo-random way based on the worker ID in the configuration. + + Args: + shards: The list of shards to shuffle. + cfg: The configuration dictionary for the overall pipeline. Should (possibly) contain the following + keys (some are optional, as marked below): + - `worker` (optional): The worker ID for the MR worker; this is also used to seed the + randomization process. If not provided, the randomization process will be unseeded. + + Returns: + The shuffled list of shards. + + Examples: + >>> cfg = DictConfig({"worker": 1}) + >>> shards = ["train/0", "train/1", "tuning", "held_out"] + >>> shuffle_shards(shards, cfg) + ['train/1', 'held_out', 'tuning', 'train/0'] + >>> cfg = DictConfig({"worker": 2}) + >>> shuffle_shards(shards, cfg) + ['tuning', 'held_out', 'train/1', 'train/0'] + """ + + if "worker" in cfg: + add_str = str(cfg["worker"]) + else: + add_str = str(datetime.now()) + + shard_keys = [] + for shard in shards: + shard_hash = hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest() + if shard_hash in shard_keys: + raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!") + shard_keys.append(int(shard_hash, 16)) + + return [shard for _, shard in sorted(zip(shard_keys, shards))] + + def shard_iterator( cfg: DictConfig, out_suffix: str = ".parquet", @@ -422,19 +460,7 @@ def shard_iterator( "and relying on `patient_splits.parquet` for filtering." ) - if "worker" in cfg: - add_str = str(cfg.worker) - else: - add_str = str(datetime.now()) - - shard_keys = [] - for shard in shards: - shard_hash = hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest() - if shard_hash in shard_keys: - raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!") - shard_keys.append(int(shard_hash, 16)) - - shards = [shard for _, shard in sorted(zip(shard_keys, shards))] + shards = shuffle_shards(shards, cfg) logger.info(f"Mapping computation over a maximum of {len(shards)} shards") @@ -446,3 +472,89 @@ def shard_iterator( out.append((in_fp, out_fp)) return out, includes_only_train + + +def shard_iterator_by_shard_map(cfg: DictConfig) -> tuple[list[str], bool]: + """Returns an iterator over shard paths and output paths based on a shard map file, not files on disk. + + Args: + cfg: The configuration dictionary for the overall pipeline. Should contain the following keys: + - `shards_map_fp` (mandatory): The file path to the shards map file. + - `stage_cfg.data_input_dir` (mandatory): The directory containing the input data. + - `stage_cfg.output_dir` (mandatory): The directory to write the output data. + - `worker` (optional): The worker ID for the MR worker; this is also used to seed the + + Returns: + A list of pairs of input and output file paths for each shard, as well as a boolean indicating + whether the shards are only train shards. + + Raises: + ValueError: If the `shards_map_fp` key is not present in the configuration. + FileNotFoundError: If the shard map file is not found at the path specified in the configuration. + ValueError: If the `train_only` key is present in the configuration. 
+ + Examples: + >>> from tempfile import NamedTemporaryFile, TemporaryDirectory + >>> import json + >>> shard_iterator_by_shard_map(DictConfig({})) + Traceback (most recent call last): + ... + ValueError: shards_map_fp must be present in the configuration for a map-based shard iterator. + >>> with NamedTemporaryFile() as tmp: + ... cfg = DictConfig({"shards_map_fp": tmp.name, "stage_cfg": {"train_only": True}}) + ... shard_iterator_by_shard_map(cfg) + Traceback (most recent call last): + ... + ValueError: train_only is not supported for this stage. + >>> with TemporaryDirectory() as tmp: + ... tmp = Path(tmp) + ... shards_map_fp = tmp / "shards_map.json" + ... cfg = DictConfig({"shards_map_fp": shards_map_fp, "stage_cfg": {"train_only": False}}) + ... shard_iterator_by_shard_map(cfg) + Traceback (most recent call last): + ... + FileNotFoundError: Shard map file not found at ...shards_map.json + >>> shards = {"train/0": [1, 2, 3, 4], "train/1": [5, 6, 7], "tuning": [8], "held_out": [9]} + >>> with NamedTemporaryFile() as tmp: + ... _ = Path(tmp.name).write_text(json.dumps(shards)) + ... cfg = DictConfig({ + ... "shards_map_fp": tmp.name, + ... "worker": 1, + ... "stage_cfg": {"data_input_dir": "data", "output_dir": "output"}, + ... }) + ... fps, includes_only_train = shard_iterator_by_shard_map(cfg) + >>> fps # doctest: +NORMALIZE_WHITESPACE + [(PosixPath('data/train/1'), PosixPath('output/train/1.parquet')), + (PosixPath('data/held_out'), PosixPath('output/held_out.parquet')), + (PosixPath('data/tuning'), PosixPath('output/tuning.parquet')), + (PosixPath('data/train/0'), PosixPath('output/train/0.parquet'))] + >>> includes_only_train + False + """ + + if "shards_map_fp" not in cfg: + raise ValueError("shards_map_fp must be present in the configuration for a map-based shard iterator.") + + if cfg.stage_cfg.get("train_only", None): + raise ValueError("train_only is not supported for this stage.") + + shard_map_fp = Path(cfg.shards_map_fp) + if not shard_map_fp.exists(): + raise FileNotFoundError(f"Shard map file not found at {str(shard_map_fp.resolve())}") + + shards = list(json.loads(shard_map_fp.read_text()).keys()) + + input_dir = Path(cfg.stage_cfg.data_input_dir) + output_dir = Path(cfg.stage_cfg.output_dir) + + shards = shuffle_shards(shards, cfg) + + logger.info(f"Mapping computation over a maximum of {len(shards)} shards") + + out = [] + for sh in shards: + in_fp = input_dir / sh + out_fp = output_dir / f"{sh}.parquet" + out.append((in_fp, out_fp)) + + return out, False From dd22eeb3d8075e4a08d60fb97c96298094cbc4af Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 14:59:05 -0400 Subject: [PATCH 36/53] Starting to add integration test for eventual functionality; not yet implemented fully. 
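This test is meant to exercise the planned `reshard_to_split` stage end to end: the input
cohort arrives in arbitrary shards named "0", "1", and "2", and the stage should regroup
patients according to `patient_splits.parquet` and then sub-shard each split to at most
`n_patients_per_shard` patients. With the six test patients and the shard size of 2 that the
test settles on later in this series, the expected grouping (which the WANT_* frames below
encode) is:

    # split sizes -> expected output shards
    #   train:    4 patients -> train/0 (239684, 1195293), train/1 (68729, 814703)
    #   tuning:   1 patient  -> tuning/0 (754281)
    #   held_out: 1 patient  -> held_out/0 (1500733)

Within each output shard, events stay sorted by patient and then time.
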
--- pyproject.toml | 1 + tests/test_reshard_to_split.py | 197 +++++++++++++++++++++++++++++++++ tests/transform_tester_base.py | 2 + 3 files changed, 200 insertions(+) create mode 100644 tests/test_reshard_to_split.py diff --git a/pyproject.toml b/pyproject.toml index 3cb078c..d4f6450 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ MEDS_extract-finalize_MEDS_data = "MEDS_transforms.extract.finalize_MEDS_data:ma ## General MEDS_transform-aggregate_code_metadata = "MEDS_transforms.aggregate_code_metadata:main" MEDS_transform-fit_vocabulary_indices = "MEDS_transforms.fit_vocabulary_indices:main" +MEDS_transform-reshard_to_split = "MEDS_transforms.reshard_to_split:main" ## Filters MEDS_transform-filter_measurements = "MEDS_transforms.filters.filter_measurements:main" MEDS_transform-filter_patients = "MEDS_transforms.filters.filter_patients:main" diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py new file mode 100644 index 0000000..5df21ea --- /dev/null +++ b/tests/test_reshard_to_split.py @@ -0,0 +1,197 @@ +"""Tests the reshard to split script. + +Set the bash env variable `DO_USE_LOCAL_SCRIPTS=1` to use the local py files, rather than the installed +scripts. +""" + + +from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester +from .utils import parse_meds_csvs + +IN_SHARD_0 = """ +patient_id,time,code,numeric_value +68729,,EYE_COLOR//HAZEL, +68729,,HEIGHT,160.3953106166676 +68729,"03/09/1978, 00:00:00",DOB, +68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, +68729,"05/26/2010, 02:30:56",HR,86.0 +68729,"05/26/2010, 02:30:56",TEMP,97.8 +68729,"05/26/2010, 04:51:52",DISCHARGE, +1195293,,EYE_COLOR//BLUE, +1195293,,HEIGHT,164.6868838269085 +1195293,"06/20/1978, 00:00:00",DOB, +1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, +1195293,"06/20/2010, 19:23:52",HR,109.0 +1195293,"06/20/2010, 19:23:52",TEMP,100.0 +1195293,"06/20/2010, 19:25:32",HR,114.1 +1195293,"06/20/2010, 19:25:32",TEMP,100.0 +1195293,"06/20/2010, 19:45:19",HR,119.8 +1195293,"06/20/2010, 19:45:19",TEMP,99.9 +1195293,"06/20/2010, 20:12:31",HR,112.5 +1195293,"06/20/2010, 20:12:31",TEMP,99.8 +1195293,"06/20/2010, 20:24:44",HR,107.7 +1195293,"06/20/2010, 20:24:44",TEMP,100.0 +1195293,"06/20/2010, 20:41:33",HR,107.5 +1195293,"06/20/2010, 20:41:33",TEMP,100.4 +1195293,"06/20/2010, 20:50:04",DISCHARGE, +""" + +IN_SHARD_1 = """ +patient_id,time,code,numeric_value +754281,,EYE_COLOR//BROWN, +754281,,HEIGHT,166.22261567137025 +754281,"12/19/1988, 00:00:00",DOB, +754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, +754281,"01/03/2010, 06:27:59",HR,142.0 +754281,"01/03/2010, 06:27:59",TEMP,99.8 +754281,"01/03/2010, 08:22:13",DISCHARGE, +814703,,EYE_COLOR//HAZEL, +814703,,HEIGHT,156.48559093209357 +814703,"03/28/1976, 00:00:00",DOB, +814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, +814703,"02/05/2010, 05:55:39",HR,170.2 +814703,"02/05/2010, 05:55:39",TEMP,100.1 +814703,"02/05/2010, 07:02:30",DISCHARGE, +""" + +IN_SHARD_2 = """ +patient_id,time,code,numeric_value +239684,,EYE_COLOR//BROWN, +239684,,HEIGHT,175.271115221764 +239684,"12/28/1980, 00:00:00",DOB, +239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, +239684,"05/11/2010, 17:41:51",HR,102.6 +239684,"05/11/2010, 17:41:51",TEMP,96.0 +239684,"05/11/2010, 17:48:48",HR,105.1 +239684,"05/11/2010, 17:48:48",TEMP,96.2 +239684,"05/11/2010, 18:25:35",HR,113.4 +239684,"05/11/2010, 18:25:35",TEMP,95.8 +239684,"05/11/2010, 18:57:18",HR,112.6 +239684,"05/11/2010, 18:57:18",TEMP,95.5 +239684,"05/11/2010, 
19:27:19",DISCHARGE, +1500733,,EYE_COLOR//BROWN, +1500733,,HEIGHT,158.60131573580904 +1500733,"07/20/1986, 00:00:00",DOB, +1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, +1500733,"06/03/2010, 14:54:38",HR,91.4 +1500733,"06/03/2010, 14:54:38",TEMP,100.0 +1500733,"06/03/2010, 15:39:49",HR,84.4 +1500733,"06/03/2010, 15:39:49",TEMP,100.3 +1500733,"06/03/2010, 16:20:49",HR,90.1 +1500733,"06/03/2010, 16:20:49",TEMP,100.1 +1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" + +SPLITS = { + "train/0": [239684, 1195293], + "train/1": [68729, 814703], + "tuning/0": [754281], + "held_out/0": [1500733], +} + +WANT_TRAIN_0 = """ +patient_id,time,code,numeric_value +239684,,EYE_COLOR//BROWN, +239684,,HEIGHT,175.271115221764 +239684,"12/28/1980, 00:00:00",DOB, +239684,"05/11/2010, 17:41:51",ADMISSION//CARDIAC, +239684,"05/11/2010, 17:41:51",HR,102.6 +239684,"05/11/2010, 17:41:51",TEMP,96.0 +239684,"05/11/2010, 17:48:48",HR,105.1 +239684,"05/11/2010, 17:48:48",TEMP,96.2 +239684,"05/11/2010, 18:25:35",HR,113.4 +239684,"05/11/2010, 18:25:35",TEMP,95.8 +239684,"05/11/2010, 18:57:18",HR,112.6 +239684,"05/11/2010, 18:57:18",TEMP,95.5 +239684,"05/11/2010, 19:27:19",DISCHARGE, +1195293,,EYE_COLOR//BLUE, +1195293,,HEIGHT,164.6868838269085 +1195293,"06/20/1978, 00:00:00",DOB, +1195293,"06/20/2010, 19:23:52",ADMISSION//CARDIAC, +1195293,"06/20/2010, 19:23:52",HR,109.0 +1195293,"06/20/2010, 19:23:52",TEMP,100.0 +1195293,"06/20/2010, 19:25:32",HR,114.1 +1195293,"06/20/2010, 19:25:32",TEMP,100.0 +1195293,"06/20/2010, 19:45:19",HR,119.8 +1195293,"06/20/2010, 19:45:19",TEMP,99.9 +1195293,"06/20/2010, 20:12:31",HR,112.5 +1195293,"06/20/2010, 20:12:31",TEMP,99.8 +1195293,"06/20/2010, 20:24:44",HR,107.7 +1195293,"06/20/2010, 20:24:44",TEMP,100.0 +1195293,"06/20/2010, 20:41:33",HR,107.5 +1195293,"06/20/2010, 20:41:33",TEMP,100.4 +1195293,"06/20/2010, 20:50:04",DISCHARGE, +""" + +WANT_TRAIN_1 = """ +patient_id,time,code,numeric_value +68729,,EYE_COLOR//HAZEL, +68729,,HEIGHT,160.3953106166676 +68729,"03/09/1978, 00:00:00",DOB, +68729,"05/26/2010, 02:30:56",ADMISSION//PULMONARY, +68729,"05/26/2010, 02:30:56",HR,86.0 +68729,"05/26/2010, 02:30:56",TEMP,97.8 +68729,"05/26/2010, 04:51:52",DISCHARGE, +814703,,EYE_COLOR//HAZEL, +814703,,HEIGHT,156.48559093209357 +814703,"03/28/1976, 00:00:00",DOB, +814703,"02/05/2010, 05:55:39",ADMISSION//ORTHOPEDIC, +814703,"02/05/2010, 05:55:39",HR,170.2 +814703,"02/05/2010, 05:55:39",TEMP,100.1 +814703,"02/05/2010, 07:02:30",DISCHARGE, +""" + +WANT_TUNING_0 = """ +patient_id,time,code,numeric_value +754281,,EYE_COLOR//BROWN, +754281,,HEIGHT,166.22261567137025 +754281,"12/19/1988, 00:00:00",DOB, +754281,"01/03/2010, 06:27:59",ADMISSION//PULMONARY, +754281,"01/03/2010, 06:27:59",HR,142.0 +754281,"01/03/2010, 06:27:59",TEMP,99.8 +754281,"01/03/2010, 08:22:13",DISCHARGE, +""" + +WANT_HELD_OUT_0 = """ +patient_id,time,code,numeric_value +1500733,,EYE_COLOR//BROWN, +1500733,,HEIGHT,158.60131573580904 +1500733,"07/20/1986, 00:00:00",DOB, +1500733,"06/03/2010, 14:54:38",ADMISSION//ORTHOPEDIC, +1500733,"06/03/2010, 14:54:38",HR,91.4 +1500733,"06/03/2010, 14:54:38",TEMP,100.0 +1500733,"06/03/2010, 15:39:49",HR,84.4 +1500733,"06/03/2010, 15:39:49",TEMP,100.3 +1500733,"06/03/2010, 16:20:49",HR,90.1 +1500733,"06/03/2010, 16:20:49",TEMP,100.1 +1500733,"06/03/2010, 16:44:26",DISCHARGE, +""" + +WANT_SHARDS = parse_meds_csvs( + { + "train/0": WANT_TRAIN_0, + "train/1": WANT_TRAIN_1, + "tuning/0": WANT_TUNING_0, + "held_out/0": WANT_HELD_OUT_0, + } +) + +IN_SHARDS = parse_meds_csvs( + { + "train/0": IN_SHARD_0, 
+ "train/1": IN_SHARD_1, + "tuning/0": IN_SHARD_0, + "held_out/0": IN_SHARD_0, + } +) + + +def test_reshard_to_split(): + single_stage_transform_tester( + transform_script=RESHARD_TO_SPLIT_SCRIPT, + stage_name="reshard_to_split", + transform_stage_kwargs={}, + want_outputs=WANT_SHARDS, + input_shards=IN_SHARDS, + input_splits=SPLITS, + ) diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 80e4f4b..3de8a8c 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -26,6 +26,7 @@ if os.environ.get("DO_USE_LOCAL_SCRIPTS", "0") == "1": # Root Source FIT_VOCABULARY_INDICES_SCRIPT = code_root / "fit_vocabulary_indices.py" + RESHARD_TO_SPLIT_SCRIPT = code_root / "reshard_to_split.py" # Filters FILTER_MEASUREMENTS_SCRIPT = filters_root / "filter_measurements.py" @@ -41,6 +42,7 @@ else: # Root Source FIT_VOCABULARY_INDICES_SCRIPT = "MEDS_transform-fit_vocabulary_indices" + RESHARD_TO_SPLIT_SCRIPT = "MEDS_transform-reshard_to_split" # Filters FILTER_MEASUREMENTS_SCRIPT = "MEDS_transform-filter_measurements" From 1b2d0230a45d6e8077988ba284a1005636482e25 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 17:08:10 -0400 Subject: [PATCH 37/53] Improved compliance by removing creation of shards.json file and adding splits parquet file. --- .../transforms/tokenization.py | 19 +++++------- tests/transform_tester_base.py | 31 ++++++++++++++++--- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index 0ef5704..5e5e0e6 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -9,8 +9,6 @@ columns of concern here thus are `patient_id`, `time`, `code`, `numeric_value`. 
""" -import json -import random from pathlib import Path import hydra @@ -19,7 +17,7 @@ from omegaconf import DictConfig, OmegaConf from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.mapreduce.mapper import rwlock_wrap +from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator from MEDS_transforms.utils import hydra_loguru_init, write_lazyframe SECONDS_PER_MINUTE = 60.0 @@ -230,18 +228,17 @@ def main(cfg: DictConfig): f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" ) - input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) + shards_single_output, include_only_train = shard_iterator(cfg) - shards = json.loads((Path(cfg.input_dir) / "splits.json").read_text()) + if include_only_train: + raise ValueError("Not supported for this stage.") - patient_splits = list(shards.keys()) - random.shuffle(patient_splits) + for in_fp, out_fp in shards_single_output: + sharded_path = out_fp.relative_to(output_dir) - for sp in patient_splits: - in_fp = input_dir / f"{sp}.parquet" - schema_out_fp = output_dir / "schemas" / f"{sp}.parquet" - event_seq_out_fp = output_dir / "event_seqs" / f"{sp}.parquet" + schema_out_fp = output_dir / "schemas" / sharded_path + event_seq_out_fp = output_dir / "event_seqs" / sharded_path logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") diff --git a/tests/transform_tester_base.py b/tests/transform_tester_base.py index 80e4f4b..945a2f4 100644 --- a/tests/transform_tester_base.py +++ b/tests/transform_tester_base.py @@ -7,6 +7,7 @@ import json import os import tempfile +from collections import defaultdict from io import StringIO from pathlib import Path @@ -56,13 +57,19 @@ # Test MEDS data (inputs) -SPLITS = { +SHARDS = { "train/0": [239684, 1195293], "train/1": [68729, 814703], "tuning/0": [754281], "held_out/0": [1500733], } +SPLITS = { + "train": [239684, 1195293, 68729, 814703], + "tuning": [754281], + "held_out": [1500733], +} + MEDS_TRAIN_0 = """ patient_id,time,code,numeric_value 239684,,EYE_COLOR//BROWN, @@ -292,6 +299,8 @@ def single_stage_transform_tester( do_pass_stage_name: bool = False, file_suffix: str = ".parquet", do_use_config_yaml: bool = False, + input_shards_map: dict[str, list[int]] | None = None, + input_splits_map: dict[str, list[int]] | None = None, ): with tempfile.TemporaryDirectory() as d: MEDS_dir = Path(d) / "MEDS_cohort" @@ -306,9 +315,23 @@ def single_stage_transform_tester( MEDS_metadata_dir.mkdir(parents=True) cohort_dir.mkdir(parents=True) - # Write the splits - splits_fp = MEDS_dir / "splits.json" - splits_fp.write_text(json.dumps(SPLITS)) + # Write the shards map + if input_shards_map is None: + input_shards_map = SHARDS + + shards_fp = MEDS_metadata_dir / ".shards.json" + shards_fp.write_text(json.dumps(input_shards_map)) + + # Write the splits parquet file + if input_splits_map is None: + input_splits_map = SPLITS + input_splits_as_df = defaultdict(list) + for split_name, patient_ids in input_splits_map.items(): + input_splits_as_df["patient_id"].extend(patient_ids) + input_splits_as_df["split"].extend([split_name] * len(patient_ids)) + input_splits_df = pl.DataFrame(input_splits_as_df) + input_splits_fp = MEDS_metadata_dir / "patient_splits.parquet" + input_splits_df.write_parquet(input_splits_fp, use_pyarrow=True) if input_shards is None: input_shards = MEDS_SHARDS From 6ec7a8fe327b111b83f1d75ee48123851099151b Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 17:14:45 -0400 Subject: [PATCH 38/53] 
Updated test. --- tests/test_reshard_to_split.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 5df21ea..670d438 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -8,6 +8,12 @@ from .transform_tester_base import RESHARD_TO_SPLIT_SCRIPT, single_stage_transform_tester from .utils import parse_meds_csvs +IN_SHARDS_MAP = { + "0": [68729, 1195293], + "1": [754281, 814703], + "2": [239684, 1500733], +} + IN_SHARD_0 = """ patient_id,time,code,numeric_value 68729,,EYE_COLOR//HAZEL, @@ -83,10 +89,9 @@ """ SPLITS = { - "train/0": [239684, 1195293], - "train/1": [68729, 814703], - "tuning/0": [754281], - "held_out/0": [1500733], + "train": [239684, 1195293, 68729, 814703], + "tuning": [754281], + "held_out": [1500733], } WANT_TRAIN_0 = """ @@ -193,5 +198,6 @@ def test_reshard_to_split(): transform_stage_kwargs={}, want_outputs=WANT_SHARDS, input_shards=IN_SHARDS, - input_splits=SPLITS, + input_shards_map=IN_SHARDS_MAP, + input_splits_map=SPLITS, ) From dd9ba08f7fe9e633d19fbfbbded8df1fd5302f89 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 17:24:21 -0400 Subject: [PATCH 39/53] started file; very much not ready yet. --- src/MEDS_transforms/configs/preprocess.yaml | 1 + .../stage_configs/reshard_to_split.yaml | 1 + src/MEDS_transforms/reshard_to_split.py | 97 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml create mode 100644 src/MEDS_transforms/reshard_to_split.py diff --git a/src/MEDS_transforms/configs/preprocess.yaml b/src/MEDS_transforms/configs/preprocess.yaml index 6b2dc8b..ea509cd 100644 --- a/src/MEDS_transforms/configs/preprocess.yaml +++ b/src/MEDS_transforms/configs/preprocess.yaml @@ -1,6 +1,7 @@ defaults: - pipeline - stage_configs: + - reshard_to_split - filter_patients - add_time_derived_measurements - count_code_occurrences diff --git a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml new file mode 100644 index 0000000..c54e254 --- /dev/null +++ b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml @@ -0,0 +1 @@ +n_patients_per_shard: 50000 diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py new file mode 100644 index 0000000..05fefa8 --- /dev/null +++ b/src/MEDS_transforms/reshard_to_split.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +"""Utilities for re-sharding a MEDS cohort to subsharded splits.""" + +import time +from collections import defaultdict +from collections.abc import Callable, Sequence +from datetime import datetime +from enum import StrEnum +from pathlib import Path +from typing import NamedTuple + +import hydra +import polars as pl +import polars.selectors as cs +from loguru import logger +from omegaconf import DictConfig, ListConfig, OmegaConf + +from MEDS_transforms import PREPROCESS_CONFIG_YAML +from MEDS_transforms.extract.split_and_shard_patients import shard_patients +from MEDS_transforms.mapreduce.mapper import map_over +from MEDS_transforms.utils import write_lazyframe + + +@hydra.main( + version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem +) +def main(cfg: DictConfig): + """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" + + hydra_loguru_init() + + logger.info( + f"Running with 
config:\n{OmegaConf.to_yaml(cfg)}\n" + f"Stage: {cfg.stage}\n\n" + f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" + ) + + splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" + splits_df = pl.read_parquet(splits_file, use_pyarrow=True) + splits_map = defaultdict(list) + for (pt_id, sp) in splits_df.iterrows(): + splits_map[sp].append(pt_id) + + + new_sharded_splits = shard_patients( + patients = splits_df["patient_id"].to_numpy() + n_patients_per_shard = cfg.stage_cfg.n_patients_per_shard, + external_splits = splits_map, + split_fracs_dict = {}, + + external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None, + split_fracs_dict: dict[str, float] = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, + seed: int = 1, + + + + output_dir = Path(cfg.stage_cfg.output_dir) + shards_single_output, include_only_train = shard_iterator(cfg) + + if include_only_train: + raise ValueError("Not supported for this stage.") + + for in_fp, out_fp in shards_single_output: + sharded_path = out_fp.relative_to(output_dir) + + schema_out_fp = output_dir / "schemas" / sharded_path + event_seq_out_fp = output_dir / "event_seqs" / sharded_path + + logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + schema_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_statics_and_schema, + do_return=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Tokenizing {str(in_fp.resolve())} into event_seqs at {str(event_seq_out_fp.resolve())}") + + rwlock_wrap( + in_fp, + event_seq_out_fp, + pl.scan_parquet, + write_lazyframe, + extract_seq_of_patient_events, + do_return=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info(f"Done with {cfg.stage}") + + +if __name__ == "__main__": + main() From 4cc65b4bdcd0deabe708e153ba8afe2303775134 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 17:26:08 -0400 Subject: [PATCH 40/53] Improved some documentation and error handling for splitting patients. --- .../extract/split_and_shard_patients.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/MEDS_transforms/extract/split_and_shard_patients.py b/src/MEDS_transforms/extract/split_and_shard_patients.py index 753c01f..61f1726 100755 --- a/src/MEDS_transforms/extract/split_and_shard_patients.py +++ b/src/MEDS_transforms/extract/split_and_shard_patients.py @@ -19,7 +19,7 @@ def shard_patients[ patients: np.ndarray, n_patients_per_shard: int = 50000, external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None, - split_fracs_dict: dict[str, float] = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, + split_fracs_dict: dict[str, float] | None = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, seed: int = 1, ) -> dict[str, list[SUBJ_ID_T]]: """Shard a list of patients, nested within train/tuning/held-out splits. @@ -41,7 +41,8 @@ def shard_patients[ tasks or test cases (e.g., prospective tests); training patients should often still be included in the IID splits to maximize the amount of data that can be used for training. split_fracs_dict: A dictionary mapping the split name to the fraction of patients to include in that - split. Defaults to 80% train, 10% tuning, 10% held-out. + split. Defaults to 80% train, 10% tuning, 10% held-out. This can be None or empty only when + external splits fully specify the population. seed: The random seed to use for shuffling the patients before seeding and sharding. This is useful for ensuring reproducibility. 
@@ -80,15 +81,12 @@ def shard_patients[ ... 'train': np.array([1, 2, 3, 4, 5, 6], dtype=int), ... 'test': np.array([7, 8, 9, 10], dtype=int), ... } - >>> shard_patients(patients, 6, external_splits) + >>> shard_patients(patients, 6, external_splits, split_fracs_dict=None) {'train/0': [1, 2, 3, 4, 5, 6], 'test/0': [7, 8, 9, 10]} >>> shard_patients(patients, 3, external_splits) {'train/0': [5, 1, 3], 'train/1': [2, 6, 4], 'test/0': [10, 7], 'test/1': [8, 9]} """ - if sum(split_fracs_dict.values()) != 1: - raise ValueError("The sum of the split fractions must be equal to 1.") - if external_splits is None: external_splits = {} else: @@ -111,6 +109,8 @@ def shard_patients[ rng = np.random.default_rng(seed) if n_patients := len(patient_ids_to_split): + if sum(split_fracs_dict.values()) != 1: + raise ValueError("The sum of the split fractions must be equal to 1.") split_names_idx = rng.permutation(len(split_fracs_dict)) split_names = np.array(list(split_fracs_dict.keys()))[split_names_idx] split_fracs = np.array([split_fracs_dict[k] for k in split_names]) From d70128f5fe0a1b5a9cf26b942f07b8a0f331dcfe Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 18:00:23 -0400 Subject: [PATCH 41/53] Implemented preliminary version of reshard to split. --- src/MEDS_transforms/reshard_to_split.py | 155 ++++++++++++++++-------- 1 file changed, 105 insertions(+), 50 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 05fefa8..34fde10 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -1,24 +1,20 @@ #!/usr/bin/env python """Utilities for re-sharding a MEDS cohort to subsharded splits.""" -import time +import json from collections import defaultdict -from collections.abc import Callable, Sequence -from datetime import datetime -from enum import StrEnum from pathlib import Path -from typing import NamedTuple import hydra import polars as pl -import polars.selectors as cs from loguru import logger -from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.extract.split_and_shard_patients import shard_patients -from MEDS_transforms.mapreduce.mapper import map_over -from MEDS_transforms.utils import write_lazyframe +from MEDS_transforms.extract.split_and_shard_patients import shard_iterator_by_shard_map, shard_patients +from MEDS_transforms.mapreduce.mapper import identity_fn, read_and_filter_fntr +from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator +from MEDS_transforms.utils import stage_init, write_lazyframe @hydra.main( @@ -27,68 +23,127 @@ def main(cfg: DictConfig): """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" - hydra_loguru_init() - - logger.info( - f"Running with config:\n{OmegaConf.to_yaml(cfg)}\n" - f"Stage: {cfg.stage}\n\n" - f"Stage config:\n{OmegaConf.to_yaml(cfg.stage_cfg)}" - ) + stage_init(cfg) splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" splits_df = pl.read_parquet(splits_file, use_pyarrow=True) splits_map = defaultdict(list) - for (pt_id, sp) in splits_df.iterrows(): + for pt_id, sp in splits_df.iterrows(): splits_map[sp].append(pt_id) - new_sharded_splits = shard_patients( - patients = splits_df["patient_id"].to_numpy() - n_patients_per_shard = cfg.stage_cfg.n_patients_per_shard, - external_splits = splits_map, - split_fracs_dict = {}, + patients=splits_df["patient_id"].to_numpy(), + 
n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + external_splits=splits_map, + split_fracs_dict=None, + seed=cfg.get("seed", 1), + ) - external_splits: dict[str, Sequence[SUBJ_ID_T]] | None = None, - split_fracs_dict: dict[str, float] = {"train": 0.8, "tuning": 0.1, "held_out": 0.1}, - seed: int = 1, + output_dir = Path(cfg.stage_cfg.output_dir) + + # Write shards to file + if "shards_map_fp" in cfg: + shards_fp = Path(cfg.shards_map_fp) + else: + shards_fp = output_dir / ".shards.json" + if shards_fp.is_file(): + if cfg.do_overwrite: + logger.warning(f"Overwriting {str(shards_fp.resolve())}") + shards_fp.unlink() + else: + raise FileExistsError(f"{str(shards_fp.resolve())} already exists.") + shards_fp.write_text(json.dumps(new_sharded_splits)) + data_input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) - shards_single_output, include_only_train = shard_iterator(cfg) + orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") if include_only_train: - raise ValueError("Not supported for this stage.") - - for in_fp, out_fp in shards_single_output: - sharded_path = out_fp.relative_to(output_dir) + raise ValueError("This stage does not support include_only_train=True") - schema_out_fp = output_dir / "schemas" / sharded_path - event_seq_out_fp = output_dir / "event_seqs" / sharded_path + orig_shards_iter = [(in_fp, out_fp.relative_to(output_dir)) for in_fp, out_fp in orig_shards_iter] - logger.info(f"Tokenizing {str(in_fp.resolve())} into schemas at {str(schema_out_fp.resolve())}") + # Here, we modify the config in a hacky way to make the right shard mapping + cfg.shards_map_fp = str(shards_fp.resolve()) + cfg.stage_cfg.data_input_dir = str(output_dir.resolve()) - rwlock_wrap( - in_fp, - schema_out_fp, - pl.scan_parquet, + new_shards_iter, include_only_train = shard_iterator_by_shard_map(cfg) + if include_only_train: + raise ValueError("This stage does not support include_only_train=True") + + cfg.stage_cfg.data_input_dir = str(data_input_dir.resolve()) + new_shards_iter = [ + (in_fp.relative_to(data_input_dir).with_suffix(""), out_fp.with_suffix("")) + for in_fp, out_fp in new_shards_iter + ] + + # Step 1: Sub-sharding stage + logger.info("Starting sub-sharding") + subshard_fps = defaultdict(list) + + for in_fp, orig_shard_name in orig_shards_iter: + for subshard_name, out_dir in new_shards_iter: + out_fp = out_dir / f"{orig_shard_name}.parquet" + subshard_fps[subshard_name].append(out_fp) + patients = new_sharded_splits[subshard_name] + + if not patients: + raise ValueError(f"No patients found for {subshard_name}!") + + logger.info(f"Sub-sharding {str(in_fp.resolve())} into {str(out_fp.resolve())}") + + rwlock_wrap( + in_fp, + out_fp, + read_and_filter_fntr(pl.col("patient_id").is_in(patients), pl.scan_parquet), + write_lazyframe, + identity_fn, + do_return=False, + do_overwrite=cfg.do_overwrite, + ) + + logger.info("Merging sub-shards") + for subshard_name, subshard_dir in new_shards_iter: + out_fp = subshard_dir.with_suffix(".parquet") + if out_fp.is_file(): + if cfg.do_overwrite: + logger.warning(f"Overwriting {str(out_fp.resolve())}") + out_fp.unlink() + else: + raise FileExistsError(f"{str(out_fp.resolve())} already exists.") + + logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") + + if not subshard_fps: + raise ValueError(f"No subshards found for {subshard_name}!") + + in_dir = subshard_dir + in_fps = subshard_fps[subshard_name] + + def read_fn(fp: Path) -> pl.LazyFrame: + return 
pl.concat([pl.scan_parquet(fp, glob=False) for fp in in_fps], how="diagonal_relaxed").sort( + by=["patient_id", "time"], maintain_order=True, multithreaded=False + ) + + logger.info(f"Merging files to {str(out_fp.resolve())}") + result_computed, _ = rwlock_wrap( + in_dir, + out_fp, + read_fn, write_lazyframe, - extract_statics_and_schema, + identity_fn, do_return=False, do_overwrite=cfg.do_overwrite, ) - logger.info(f"Tokenizing {str(in_fp.resolve())} into event_seqs at {str(event_seq_out_fp.resolve())}") - - rwlock_wrap( - in_fp, - event_seq_out_fp, - pl.scan_parquet, - write_lazyframe, - extract_seq_of_patient_events, - do_return=False, - do_overwrite=cfg.do_overwrite, - ) + if result_computed: + logger.info(f"Cleaning up subsharded files in {str(subshard_dir.resolve())}/*.") + for fp in in_fps: + if fp.exists(): + fp.unlink() + subshard_dir.rmdir() logger.info(f"Done with {cfg.stage}") From 0db3a23f8a532dda5058ff9d3d0ac31d6c3a3862 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 18:13:52 -0400 Subject: [PATCH 42/53] fixed some more minor errors; test fails for content reasons now. --- .../stage_configs/reshard_to_split.yaml | 3 +- src/MEDS_transforms/mapreduce/utils.py | 4 +-- src/MEDS_transforms/reshard_to_split.py | 34 +++++++++---------- tests/test_reshard_to_split.py | 2 +- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml index c54e254..16dc505 100644 --- a/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml +++ b/src/MEDS_transforms/configs/stage_configs/reshard_to_split.yaml @@ -1 +1,2 @@ -n_patients_per_shard: 50000 +reshard_to_split: + n_patients_per_shard: 50000 diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 299f856..87ce835 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -179,7 +179,7 @@ def rwlock_wrap[ if do_return: return True, read_fn(out_fp) else: - return True + return True, None cache_directory = out_fp.parent / f".{out_fp.stem}_cache" cache_directory.mkdir(exist_ok=True, parents=True) @@ -213,7 +213,7 @@ def rwlock_wrap[ if do_return: return True, df else: - return True + return True, None except Exception as e: logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}") diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 34fde10..9ddfbe7 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -11,9 +11,9 @@ from omegaconf import DictConfig from MEDS_transforms import PREPROCESS_CONFIG_YAML -from MEDS_transforms.extract.split_and_shard_patients import shard_iterator_by_shard_map, shard_patients +from MEDS_transforms.extract.split_and_shard_patients import shard_patients from MEDS_transforms.mapreduce.mapper import identity_fn, read_and_filter_fntr -from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator +from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator, shuffle_shards from MEDS_transforms.utils import stage_init, write_lazyframe @@ -28,7 +28,7 @@ def main(cfg: DictConfig): splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" splits_df = pl.read_parquet(splits_file, use_pyarrow=True) splits_map = defaultdict(list) - for pt_id, sp in splits_df.iterrows(): + for pt_id, sp in splits_df.iter_rows(): 
splits_map[sp].append(pt_id) new_sharded_splits = shard_patients( @@ -56,7 +56,6 @@ def main(cfg: DictConfig): shards_fp.write_text(json.dumps(new_sharded_splits)) - data_input_dir = Path(cfg.stage_cfg.data_input_dir) output_dir = Path(cfg.stage_cfg.output_dir) orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") @@ -65,19 +64,8 @@ def main(cfg: DictConfig): orig_shards_iter = [(in_fp, out_fp.relative_to(output_dir)) for in_fp, out_fp in orig_shards_iter] - # Here, we modify the config in a hacky way to make the right shard mapping - cfg.shards_map_fp = str(shards_fp.resolve()) - cfg.stage_cfg.data_input_dir = str(output_dir.resolve()) - - new_shards_iter, include_only_train = shard_iterator_by_shard_map(cfg) - if include_only_train: - raise ValueError("This stage does not support include_only_train=True") - - cfg.stage_cfg.data_input_dir = str(data_input_dir.resolve()) - new_shards_iter = [ - (in_fp.relative_to(data_input_dir).with_suffix(""), out_fp.with_suffix("")) - for in_fp, out_fp in new_shards_iter - ] + new_shards = shuffle_shards(list(new_sharded_splits.keys()), cfg) + new_shards_iter = [(shard_name, output_dir / shard_name) for shard_name in new_shards] # Step 1: Sub-sharding stage logger.info("Starting sub-sharding") @@ -143,7 +131,17 @@ def read_fn(fp: Path) -> pl.LazyFrame: for fp in in_fps: if fp.exists(): fp.unlink() - subshard_dir.rmdir() + try: + for root, dirs, files in subshard_dir.walk(top_down=False): + walked_dir = root.relative_to(subshard_dir) + if files: + raise FileExistsError(f"Files found in {walked_dir} after cleanup!: {files}") + for d in dirs: + (root / d).rmdir() + subshard_dir.rmdir() + except OSError as e: + contents_str = "\n".join([str(f) for f in subshard_dir.iterdir()]) + raise ValueError(f"Could not remove {str(subshard_dir)}. Contents:\n{contents_str}") from e logger.info(f"Done with {cfg.stage}") diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index 670d438..cdf399e 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -195,7 +195,7 @@ def test_reshard_to_split(): single_stage_transform_tester( transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", - transform_stage_kwargs={}, + transform_stage_kwargs={"n_patients_per_shard": 2}, want_outputs=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, From 829cede5369723b27992d9a219a1e7c5ed206628 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 18:16:36 -0400 Subject: [PATCH 43/53] code was actually fine; it was a test error. 
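The failure was in the test fixture, not the stage: `IN_SHARDS` was keyed by split-style
names ("train/0", ...) and reused `IN_SHARD_0` for three of the four entries, while the files
this stage actually reads are the pre-split shards declared in `IN_SHARDS_MAP`. The fixture
should key inputs and expected outputs differently, along these lines:

    # inputs, keyed like the on-disk shards the tester writes ("0", "1", "2"):
    IN_SHARDS = {"0": IN_SHARD_0, "1": IN_SHARD_1, "2": IN_SHARD_2}
    # expected outputs stay keyed by split/sub-shard: "train/0", "train/1", "tuning/0", "held_out/0"
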
--- src/MEDS_transforms/reshard_to_split.py | 4 +++- tests/test_reshard_to_split.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 9ddfbe7..72e9766 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -80,7 +80,9 @@ def main(cfg: DictConfig): if not patients: raise ValueError(f"No patients found for {subshard_name}!") - logger.info(f"Sub-sharding {str(in_fp.resolve())} into {str(out_fp.resolve())}") + logger.info( + f"Sub-sharding {str(in_fp.resolve())} to {len(patients)} patients in {str(out_fp.resolve())}" + ) rwlock_wrap( in_fp, diff --git a/tests/test_reshard_to_split.py b/tests/test_reshard_to_split.py index cdf399e..3eeecfe 100644 --- a/tests/test_reshard_to_split.py +++ b/tests/test_reshard_to_split.py @@ -183,10 +183,9 @@ IN_SHARDS = parse_meds_csvs( { - "train/0": IN_SHARD_0, - "train/1": IN_SHARD_1, - "tuning/0": IN_SHARD_0, - "held_out/0": IN_SHARD_0, + "0": IN_SHARD_0, + "1": IN_SHARD_1, + "2": IN_SHARD_2, } ) From e0d1ecbc7dd6850d9a04453fa507d2521d270478 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 19:30:56 -0400 Subject: [PATCH 44/53] Corrected a small parallelism issue in reshard. --- src/MEDS_transforms/reshard_to_split.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 72e9766..ed3b96a 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -2,6 +2,7 @@ """Utilities for re-sharding a MEDS cohort to subsharded splits.""" import json +import time from collections import defaultdict from pathlib import Path @@ -96,21 +97,22 @@ def main(cfg: DictConfig): logger.info("Merging sub-shards") for subshard_name, subshard_dir in new_shards_iter: + in_dir = subshard_dir + in_fps = subshard_fps[subshard_name] out_fp = subshard_dir.with_suffix(".parquet") - if out_fp.is_file(): - if cfg.do_overwrite: - logger.warning(f"Overwriting {str(out_fp.resolve())}") - out_fp.unlink() - else: - raise FileExistsError(f"{str(out_fp.resolve())} already exists.") logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") if not subshard_fps: raise ValueError(f"No subshards found for {subshard_name}!") - in_dir = subshard_dir - in_fps = subshard_fps[subshard_name] + if out_fp.is_file(): + logger.info(f"Output file {str(out_fp.resolve())} already exists. Skipping.") + continue + + while not (all(fp.is_file() for fp in in_fps) or out_fp.is_file()): + logger.info("Waiting to begin merging for all sub-shard files to be written...") + time.sleep(cfg.polling_time) def read_fn(fp: Path) -> pl.LazyFrame: return pl.concat([pl.scan_parquet(fp, glob=False) for fp in in_fps], how="diagonal_relaxed").sort( From ac1587ee863ba1ac415982acad2e9ae59089711a Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 22:39:14 -0400 Subject: [PATCH 45/53] Made reshard not error out if the new split is identical to the shards on disk. 
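When several workers run this stage at once, each recomputes the shard map from the same
`patient_splits.parquet` with the same (default) seed, so they all derive an identical
`.shards.json`; treating an existing file as a hard error therefore made every worker after
the first fail before doing any useful work. A rough sketch of the intended logic (the diff
below keeps the `do_overwrite` branch as-is and only relaxes the "already exists" case):

    if shards_fp.is_file() and not cfg.do_overwrite:
        if json.loads(shards_fp.read_text()) != new_sharded_splits:
            raise FileExistsError(f"{shards_fp} already exists and the shard map differs.")
        # identical map: fall through and (re)write the same content

Only a genuinely different pre-existing map is still treated as an error.
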
--- src/MEDS_transforms/reshard_to_split.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index ed3b96a..29bc176 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -53,7 +53,9 @@ def main(cfg: DictConfig): logger.warning(f"Overwriting {str(shards_fp.resolve())}") shards_fp.unlink() else: - raise FileExistsError(f"{str(shards_fp.resolve())} already exists.") + old_shards_map = json.loads(shards_fp.read_text()) + if (old_shards_map != new_sharded_splits): + raise FileExistsError(f"{str(shards_fp.resolve())} already exists and shard map differs.") shards_fp.write_text(json.dumps(new_sharded_splits)) From 2d1c4cf5db17a19064b96b6bccc1788bd871d76f Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sat, 10 Aug 2024 22:39:44 -0400 Subject: [PATCH 46/53] Fixed formatting. --- src/MEDS_transforms/reshard_to_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 29bc176..a730b64 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -54,7 +54,7 @@ def main(cfg: DictConfig): shards_fp.unlink() else: old_shards_map = json.loads(shards_fp.read_text()) - if (old_shards_map != new_sharded_splits): + if old_shards_map != new_sharded_splits: raise FileExistsError(f"{str(shards_fp.resolve())} already exists and shard map differs.") shards_fp.write_text(json.dumps(new_sharded_splits)) From 7b02df8a80db2ca7a178336baffe36ba7b38fa3c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 07:58:20 -0400 Subject: [PATCH 47/53] Move more to locked computation. --- src/MEDS_transforms/reshard_to_split.py | 48 ++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index a730b64..c8ba7c3 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -101,38 +101,27 @@ def main(cfg: DictConfig): for subshard_name, subshard_dir in new_shards_iter: in_dir = subshard_dir in_fps = subshard_fps[subshard_name] - out_fp = subshard_dir.with_suffix(".parquet") - - logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") - - if not subshard_fps: + if not in_fps: raise ValueError(f"No subshards found for {subshard_name}!") - if out_fp.is_file(): - logger.info(f"Output file {str(out_fp.resolve())} already exists. 
Skipping.") - continue + out_fp = subshard_dir.with_suffix(".parquet") + + def read_fn(in_dir: Path) -> pl.LazyFrame: + while not (all(fp.is_file() for fp in in_fps) or out_fp.is_file()): + logger.info("Waiting to begin merging for all sub-shard files to be written...") + time.sleep(cfg.polling_time) - while not (all(fp.is_file() for fp in in_fps) or out_fp.is_file()): - logger.info("Waiting to begin merging for all sub-shard files to be written...") - time.sleep(cfg.polling_time) + return [pl.scan_parquet(fp, glob=False) for fp in in_fps] - def read_fn(fp: Path) -> pl.LazyFrame: - return pl.concat([pl.scan_parquet(fp, glob=False) for fp in in_fps], how="diagonal_relaxed").sort( + def compute_fn(dfs: list[pl.LazyFrame]) -> pl.LazyFrame: + logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") + return pl.concat(dfs, how="diagonal_relaxed").sort( by=["patient_id", "time"], maintain_order=True, multithreaded=False ) - logger.info(f"Merging files to {str(out_fp.resolve())}") - result_computed, _ = rwlock_wrap( - in_dir, - out_fp, - read_fn, - write_lazyframe, - identity_fn, - do_return=False, - do_overwrite=cfg.do_overwrite, - ) + def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: + write_lazyframe(df, out_fp) - if result_computed: logger.info(f"Cleaning up subsharded files in {str(subshard_dir.resolve())}/*.") for fp in in_fps: if fp.exists(): @@ -149,6 +138,17 @@ def read_fn(fp: Path) -> pl.LazyFrame: contents_str = "\n".join([str(f) for f in subshard_dir.iterdir()]) raise ValueError(f"Could not remove {str(subshard_dir)}. Contents:\n{contents_str}") from e + logger.info(f"Merging files to {str(out_fp.resolve())}") + result_computed, _ = rwlock_wrap( + in_dir, + out_fp, + read_fn, + write_fn, + compute_fn, + do_return=False, + do_overwrite=cfg.do_overwrite, + ) + logger.info(f"Done with {cfg.stage}") From b1378bc2b0355ae7d1897933cd9c5763cd6576aa Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 08:24:23 -0400 Subject: [PATCH 48/53] changes in progress to improve robustness --- src/MEDS_transforms/reshard_to_split.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index c8ba7c3..4e9cc40 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -111,11 +111,22 @@ def read_fn(in_dir: Path) -> pl.LazyFrame: logger.info("Waiting to begin merging for all sub-shard files to be written...") time.sleep(cfg.polling_time) - return [pl.scan_parquet(fp, glob=False) for fp in in_fps] - - def compute_fn(dfs: list[pl.LazyFrame]) -> pl.LazyFrame: + logger.info(f"Merging {str(in_dir.resolve())}/**/*.parquet:") + df = None + for fp in in_fps: + if not fp.is_file(): + raise FileNotFoundError(f"File {str(fp.resolve())} not found.") + logger.info(f" - {str(fp.resolve())}") + if df is None: + df = pl.scan_parquet(fp, glob=False) + else: + df = df.merge_sorted(pl.scan_parquet(fp, glob=False), key="patient_id") + return df + + def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") - return pl.concat(dfs, how="diagonal_relaxed").sort( + #pl.concat(dfs, how="vertical").lazy() + return df.sort( by=["patient_id", "time"], maintain_order=True, multithreaded=False ) @@ -150,6 +161,7 @@ def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: ) logger.info(f"Done with {cfg.stage}") + return 0 if __name__ == "__main__": From 
3a957940bc49534c4b28550eaed3135b53cffde3 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 08:24:36 -0400 Subject: [PATCH 49/53] changes in progress to improve robustness --- src/MEDS_transforms/reshard_to_split.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 4e9cc40..a8d2dfd 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -125,10 +125,8 @@ def read_fn(in_dir: Path) -> pl.LazyFrame: def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") - #pl.concat(dfs, how="vertical").lazy() - return df.sort( - by=["patient_id", "time"], maintain_order=True, multithreaded=False - ) + # pl.concat(dfs, how="vertical").lazy() + return df.sort(by=["patient_id", "time"], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: write_lazyframe(df, out_fp) From d9ed79ae1ed48e14a963407ce5020d82f497621c Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 09:04:42 -0400 Subject: [PATCH 50/53] Made file checker wait for valid parquet files --- .../aggregate_code_metadata.py | 3 +- src/MEDS_transforms/mapreduce/utils.py | 34 +++++++++++++++++++ src/MEDS_transforms/reshard_to_split.py | 11 +++--- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 91d62a1..9290926 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -16,6 +16,7 @@ from MEDS_transforms import PREPROCESS_CONFIG_YAML from MEDS_transforms.mapreduce.mapper import map_over +from MEDS_transforms.mapreduce.utils import is_complete_parquet_file from MEDS_transforms.utils import write_lazyframe @@ -667,7 +668,7 @@ def run_map_reduce(cfg: DictConfig): logger.info("Starting reduction process") - while not all(fp.is_file() for fp in all_out_fps): + while not all(is_complete_parquet_file(fp) for fp in all_out_fps): logger.info("Waiting to begin reduction for all files to be written...") time.sleep(cfg.polling_time) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 87ce835..0f6f1a6 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -6,12 +6,46 @@ from datetime import datetime from pathlib import Path +import pyarrow.parquet as pq from loguru import logger from omegaconf import DictConfig LOCK_TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f" +def is_complete_parquet_file(fp: Path) -> bool: + """Check if a parquet file is complete. + + Args: + fp: The file path to the parquet file. + + Returns: + True if the parquet file is complete, False otherwise. + + Examples: + >>> import tempfile, polars as pl + >>> df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> with tempfile.NamedTemporaryFile() as tmp: + ... df.write_parquet(tmp) + ... is_complete_parquet_file(tmp) + True + >>> with tempfile.NamedTemporaryFile() as tmp: + ... df.write_csv(tmp) + ... is_complete_parquet_file(tmp) + False + >>> with tempfile.TemporaryDirectory() as tmp: + ... tmp = Path(tmp) + ... 
is_complete_parquet_file(tmp / "nonexistent.parquet") + False + """ + + try: + _ = pq.ParquetFile(fp) + return True + except Exception: + return False + + def get_earliest_lock(cache_directory: Path) -> datetime | None: """Returns the earliest start time of any lock file present in a cache directory, or None if none exist. diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index a8d2dfd..81a182c 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -14,7 +14,12 @@ from MEDS_transforms import PREPROCESS_CONFIG_YAML from MEDS_transforms.extract.split_and_shard_patients import shard_patients from MEDS_transforms.mapreduce.mapper import identity_fn, read_and_filter_fntr -from MEDS_transforms.mapreduce.utils import rwlock_wrap, shard_iterator, shuffle_shards +from MEDS_transforms.mapreduce.utils import ( + is_complete_parquet_file, + rwlock_wrap, + shard_iterator, + shuffle_shards, +) from MEDS_transforms.utils import stage_init, write_lazyframe @@ -107,15 +112,13 @@ def main(cfg: DictConfig): out_fp = subshard_dir.with_suffix(".parquet") def read_fn(in_dir: Path) -> pl.LazyFrame: - while not (all(fp.is_file() for fp in in_fps) or out_fp.is_file()): + while not (all(is_complete_parquet_file(fp) for fp in in_fps) or out_fp.is_file()): logger.info("Waiting to begin merging for all sub-shard files to be written...") time.sleep(cfg.polling_time) logger.info(f"Merging {str(in_dir.resolve())}/**/*.parquet:") df = None for fp in in_fps: - if not fp.is_file(): - raise FileNotFoundError(f"File {str(fp.resolve())} not found.") logger.info(f" - {str(fp.resolve())}") if df is None: df = pl.scan_parquet(fp, glob=False) From d1f59539cd1c9189ef9b9b306a13f86bf5405deb Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 09:07:22 -0400 Subject: [PATCH 51/53] Removed unneeded line. --- src/MEDS_transforms/reshard_to_split.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 81a182c..296423a 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -128,7 +128,6 @@ def read_fn(in_dir: Path) -> pl.LazyFrame: def compute_fn(df: list[pl.DataFrame]) -> pl.LazyFrame: logger.info(f"Merging {subshard_dir}/**/*.parquet into {str(out_fp.resolve())}") - # pl.concat(dfs, how="vertical").lazy() return df.sort(by=["patient_id", "time"], maintain_order=True, multithreaded=False) def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: From e0c0f46b5bd8379682332e70187766ead5484907 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 09:29:16 -0400 Subject: [PATCH 52/53] Made rwlock more robust and eliminated unused return mode. 
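With the do_return mode gone, rwlock_wrap returns a single bool indicating
whether this worker actually performed the computation; callers that need the
result should read it back from out_fp. The new out_fp_checker hook lets the
caller decide what counts as an already-complete output (by default, parquet
outputs are checked with is_complete_parquet_file and everything else with
is_file). A minimal usage sketch, mirroring the CSV doctest in the updated
docstring; the file names, the "c" column, and the non_empty_file checker are
hypothetical stand-ins, not part of this patch:

    from pathlib import Path

    import polars as pl

    from MEDS_transforms.mapreduce.utils import rwlock_wrap

    def non_empty_file(fp: Path) -> bool:
        # Hypothetical completeness check for a CSV output.
        return fp.is_file() and fp.stat().st_size > 0

    ran = rwlock_wrap(
        Path("input.csv"),
        Path("output.csv"),
        pl.read_csv,
        pl.DataFrame.write_csv,
        lambda df: df.filter(pl.col("c") > 4),
        do_overwrite=False,
        out_fp_checker=non_empty_file,
    )
    if not ran:
        # Another worker already produced (or is currently producing) the output.
        pass
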
--- src/MEDS_transforms/mapreduce/mapper.py | 1 - src/MEDS_transforms/mapreduce/utils.py | 62 +++++++++---------- src/MEDS_transforms/reshard_to_split.py | 10 +-- .../transforms/tokenization.py | 2 - 4 files changed, 31 insertions(+), 44 deletions(-) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index f6b20c3..7649082 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -621,7 +621,6 @@ def map_over( read_fn, write_fn, compute_fn, - do_return=False, do_overwrite=cfg.do_overwrite, ) all_out_fps.append(out_fp) diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 0f6f1a6..4aa4c53 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -117,6 +117,13 @@ def register_lock(cache_directory: Path) -> tuple[datetime, Path]: return lock_time, lock_fp +def default_file_checker(fp: Path) -> bool: + """Check if a file exists and is complete.""" + if fp.suffix == ".parquet": + return is_complete_parquet_file(fp) + return fp.is_file() + + def rwlock_wrap[ DF_T ]( @@ -126,8 +133,8 @@ def rwlock_wrap[ write_fn: Callable[[DF_T, Path], None], compute_fn: Callable[[DF_T], DF_T], do_overwrite: bool = False, - do_return: bool = False, -) -> tuple[bool, DF_T | None]: + out_fp_checker: Callable[[Path], bool] = default_file_checker, +) -> bool: """Wrap a series of file-in file-out map transformations on a dataframe with caching and locking. Args: @@ -146,11 +153,9 @@ def rwlock_wrap[ type DF_T and return a dataframe of (generic) type DF_T. do_overwrite: If True, the output file will be overwritten if it already exists. This is `False` by default. - do_return: If True, the final dataframe will be returned. This is `False` by default. Returns: - The dataframe resulting from the transformations applied in sequence to the dataframe stored in - `in_fp`. + True if the computation was run, False otherwise. Examples: >>> import polars as pl @@ -166,7 +171,7 @@ def rwlock_wrap[ >>> read_fn = pl.read_csv >>> write_fn = pl.DataFrame.write_csv >>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4) - >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn, do_return=False) + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) >>> assert result_computed >>> print(out_fp.read_text()) a,b,c @@ -179,49 +184,33 @@ def rwlock_wrap[ Traceback (most recent call last): ... polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"] - >>> cache_directory = root / f".output_cache" + >>> cache_directory = root / f".output.csv_cache" >>> lock_dir = cache_directory / "locks" >>> assert not list(lock_dir.iterdir()) >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame: ... print(f"Lock dir empty? {not (list(lock_dir.iterdir()))}") ... return df - >>> result_computed, out_df = rwlock_wrap( - ... in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn, do_return=True - ... ) + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn) Lock dir empty? 
False >>> assert result_computed - >>> out_df - shape: (3, 3) - ┌─────┬─────┬─────┐ - │ a ┆ b ┆ c │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪═════╪═════╡ - │ 1 ┆ 2 ┆ 3 │ - │ 3 ┆ 4 ┆ -1 │ - │ 3 ┆ 5 ┆ 6 │ - └─────┴─────┴─────┘ >>> directory.cleanup() """ - if out_fp.is_file(): + if out_fp_checker(out_fp): if do_overwrite: logger.info(f"Deleting existing {out_fp} as do_overwrite={do_overwrite}.") out_fp.unlink() else: - logger.info(f"{out_fp} exists; reading directly and returning.") - if do_return: - return True, read_fn(out_fp) - else: - return True, None + logger.info(f"{out_fp} exists; returning.") + return False - cache_directory = out_fp.parent / f".{out_fp.stem}_cache" + cache_directory = out_fp.parent / f".{out_fp.parts[-1]}_cache" cache_directory.mkdir(exist_ok=True, parents=True) earliest_lock_time = get_earliest_lock(cache_directory) if earliest_lock_time is not None: logger.info(f"{out_fp} is in progress as of {earliest_lock_time}. Returning.") - return False, None if do_return else False + return False st_time, lock_fp = register_lock(cache_directory) @@ -230,12 +219,20 @@ def rwlock_wrap[ if earliest_lock_time < st_time: logger.info(f"Earlier lock found at {earliest_lock_time}. Deleting current lock and returning.") lock_fp.unlink() - return False, None if do_return else False + return False logger.info(f"Reading input dataframe from {in_fp}") df = read_fn(in_fp) logger.info("Read dataset") + earliest_lock_time = get_earliest_lock(cache_directory) + if earliest_lock_time < st_time: + logger.info( + f"Earlier lock found post read at {earliest_lock_time}. Deleting current lock and returning." + ) + lock_fp.unlink() + return False + try: df = compute_fn(df) logger.info(f"Writing final output to {out_fp}") @@ -244,10 +241,7 @@ def rwlock_wrap[ logger.info(f"Leaving cache directory {cache_directory}, but clearing lock at {lock_fp}") lock_fp.unlink() - if do_return: - return True, df - else: - return True, None + return True except Exception as e: logger.warning(f"Clearing lock due to Exception {e} at {lock_fp} after {datetime.now() - st_time}") diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index 296423a..bebb557 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -98,11 +98,9 @@ def main(cfg: DictConfig): read_and_filter_fntr(pl.col("patient_id").is_in(patients), pl.scan_parquet), write_lazyframe, identity_fn, - do_return=False, do_overwrite=cfg.do_overwrite, ) - logger.info("Merging sub-shards") for subshard_name, subshard_dir in new_shards_iter: in_dir = subshard_dir in_fps = subshard_fps[subshard_name] @@ -112,7 +110,7 @@ def main(cfg: DictConfig): out_fp = subshard_dir.with_suffix(".parquet") def read_fn(in_dir: Path) -> pl.LazyFrame: - while not (all(is_complete_parquet_file(fp) for fp in in_fps) or out_fp.is_file()): + while not all(is_complete_parquet_file(fp) for fp in in_fps): logger.info("Waiting to begin merging for all sub-shard files to be written...") time.sleep(cfg.polling_time) @@ -149,19 +147,17 @@ def write_fn(df: pl.LazyFrame, out_fp: Path) -> None: contents_str = "\n".join([str(f) for f in subshard_dir.iterdir()]) raise ValueError(f"Could not remove {str(subshard_dir)}. 
Contents:\n{contents_str}") from e - logger.info(f"Merging files to {str(out_fp.resolve())}") - result_computed, _ = rwlock_wrap( + logger.info(f"Merging sub-shards for {subshard_name} to {str(out_fp.resolve())}") + rwlock_wrap( in_dir, out_fp, read_fn, write_fn, compute_fn, - do_return=False, do_overwrite=cfg.do_overwrite, ) logger.info(f"Done with {cfg.stage}") - return 0 if __name__ == "__main__": diff --git a/src/MEDS_transforms/transforms/tokenization.py b/src/MEDS_transforms/transforms/tokenization.py index 5e5e0e6..c4f3d26 100644 --- a/src/MEDS_transforms/transforms/tokenization.py +++ b/src/MEDS_transforms/transforms/tokenization.py @@ -248,7 +248,6 @@ def main(cfg: DictConfig): pl.scan_parquet, write_lazyframe, extract_statics_and_schema, - do_return=False, do_overwrite=cfg.do_overwrite, ) @@ -260,7 +259,6 @@ def main(cfg: DictConfig): pl.scan_parquet, write_lazyframe, extract_seq_of_patient_events, - do_return=False, do_overwrite=cfg.do_overwrite, ) From 3fc909e7e280f9f5590e5fbcf79c5d1020e90979 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Sun, 11 Aug 2024 10:04:22 -0400 Subject: [PATCH 53/53] Made resharding use rwlock wrap --- src/MEDS_transforms/reshard_to_split.py | 73 ++++++++++++++++--------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index bebb557..9a18a40 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -4,6 +4,7 @@ import json import time from collections import defaultdict +from functools import partial from pathlib import Path import hydra @@ -23,49 +24,67 @@ from MEDS_transforms.utils import stage_init, write_lazyframe -@hydra.main( - version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem -) -def main(cfg: DictConfig): - """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" +def valid_json_file(fp: Path) -> bool: + """Check if a file is a valid JSON file.""" + if not fp.is_file(): + return False + try: + json.loads(fp.read_text()) + return True + except json.JSONDecodeError: + return False - stage_init(cfg) - splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" - splits_df = pl.read_parquet(splits_file, use_pyarrow=True) +def make_new_shards_fn(df: pl.DataFrame, cfg: DictConfig, stage_cfg: DictConfig) -> dict[str, list[str]]: splits_map = defaultdict(list) - for pt_id, sp in splits_df.iter_rows(): + for pt_id, sp in df.iter_rows(): splits_map[sp].append(pt_id) - new_sharded_splits = shard_patients( - patients=splits_df["patient_id"].to_numpy(), - n_patients_per_shard=cfg.stage_cfg.n_patients_per_shard, + return shard_patients( + patients=df["patient_id"].to_numpy(), + n_patients_per_shard=stage_cfg.n_patients_per_shard, external_splits=splits_map, split_fracs_dict=None, seed=cfg.get("seed", 1), ) - output_dir = Path(cfg.stage_cfg.output_dir) - # Write shards to file - if "shards_map_fp" in cfg: - shards_fp = Path(cfg.shards_map_fp) - else: - shards_fp = output_dir / ".shards.json" +def write_json(d: dict, fp: Path) -> None: + fp.write_text(json.dumps(d)) - if shards_fp.is_file(): - if cfg.do_overwrite: - logger.warning(f"Overwriting {str(shards_fp.resolve())}") - shards_fp.unlink() - else: - old_shards_map = json.loads(shards_fp.read_text()) - if old_shards_map != new_sharded_splits: - raise FileExistsError(f"{str(shards_fp.resolve())} already exists and shard map differs.") - 
shards_fp.write_text(json.dumps(new_sharded_splits)) +@hydra.main( + version_base=None, config_path=str(PREPROCESS_CONFIG_YAML.parent), config_name=PREPROCESS_CONFIG_YAML.stem +) +def main(cfg: DictConfig): + """Re-shard a MEDS cohort to in a manner that subdivides patient splits.""" + + stage_init(cfg) output_dir = Path(cfg.stage_cfg.output_dir) + splits_file = Path(cfg.input_dir) / "metadata" / "patient_splits.parquet" + shards_fp = output_dir / ".shards.json" + + rwlock_wrap( + splits_file, + shards_fp, + partial(pl.read_parquet, use_pyarrow=True), + write_json, + partial(make_new_shards_fn, cfg=cfg, stage_cfg=cfg.stage_cfg), + do_overwrite=cfg.do_overwrite, + out_fp_checker=valid_json_file, + ) + + max_iters = cfg.get("max_iters", 10) + iters = 0 + while not valid_json_file(shards_fp) and iters < max_iters: + logger.info(f"Waiting to begin until shards map is written. Iteration {iters}/{max_iters}...") + time.sleep(cfg.polling_time) + iters += 1 + + new_sharded_splits = json.loads(shards_fp.read_text()) + orig_shards_iter, include_only_train = shard_iterator(cfg, out_suffix="") if include_only_train: raise ValueError("This stage does not support include_only_train=True")