Skip to content

Commit

Permalink
Merge branch 'dev' into 69_add_reducer_logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Aug 8, 2024
2 parents 83267dd + 36a172d commit fb0c42d
Showing 1 changed file with 123 additions and 14 deletions.
137 changes: 123 additions & 14 deletions src/MEDS_transforms/aggregate_code_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class METADATA_FN(StrEnum):
the code
"values/min": Collects the minimum non-null, non-nan numeric_value value for the code & modifiers
"values/max": Collects the maximum non-null, non-nan numeric_value value for the code & modifiers
"values/quantiles": Collects the specified quantiles over all observed numeric values for the code &
modifiers group. The quantiles are specified in the output as a polars struct field, with the
quantile as the key and the value as the quantile value. The desired quantiles are specified in
the configuration file using the dictionary syntax for the aggregation.
"""

CODE_N_PATIENTS = "code/n_patients"
Expand All @@ -71,6 +75,7 @@ class METADATA_FN(StrEnum):
VALUES_SUM_SQD = "values/sum_sqd"
VALUES_MIN = "values/min"
VALUES_MAX = "values/max"
VALUES_QUANTILES = "values/quantiles"


class MapReducePair(NamedTuple):
Expand All @@ -97,22 +102,72 @@ class MapReducePair(NamedTuple):
reducer: Callable[[pl.Expr | Sequence[pl.Expr] | cs._selector_proxy_], pl.Expr]


def quantile_reducer(cols: cs._selector_proxy_, quantiles: list[float]) -> pl.Expr:
    """Build an expression computing the given quantiles over all values pooled from `cols`.

    Every column matched by the selector is expected to hold a list of numbers; the lists are
    concatenated (nulls treated as empty lists), flattened, and the requested quantiles are taken
    over the pooled values.

    Args:
        cols: A polars selector matching the list-typed column(s) whose numerical contents should
            be pooled together before computing quantiles.
        quantiles: The quantile levels (floats between 0 and 1) to compute.

    Returns:
        A polars expression producing a struct column named after `METADATA_FN.VALUES_QUANTILES`,
        with one field per requested quantile, keyed ``values/quantile/<q>``.

    Examples:
        >>> df = pl.DataFrame({
        ...     "key": [1, 2],
        ...     "vals/shard1": [[1, 2, float('nan')], [None, 3]],
        ...     "vals/shard2": [[3.0, 4], [30]],
        ... }, strict=False)
        >>> expr = quantile_reducer(cs.starts_with("vals/"), [0.01, 0.5, 0.75])
        >>> df.select(expr)
        shape: (1, 1)
        ┌──────────────────┐
        │ values/quantiles │
        │ ---              │
        │ struct[3]        │
        ╞══════════════════╡
        │ {1.0,3.0,30.0}   │
        └──────────────────┘
        >>> df.select("key", expr.over("key"))
        shape: (2, 2)
        ┌─────┬──────────────────┐
        │ key ┆ values/quantiles │
        │ --- ┆ ---              │
        │ i64 ┆ struct[3]        │
        ╞═════╪══════════════════╡
        │ 1   ┆ {1.0,3.0,4.0}    │
        │ 2   ┆ {3.0,30.0,30.0}  │
        └─────┴──────────────────┘
    """

    # Null list-cells become empty lists so they contribute nothing once exploded.
    pooled = pl.concat_list(cols.fill_null([])).explode()

    # One struct field per requested quantile, named "values/quantile/<q>".
    struct_fields = {}
    for q in quantiles:
        field_name = f"values/quantile/{q}"
        struct_fields[field_name] = pooled.quantile(q).alias(field_name)

    return pl.struct(**struct_fields).alias(METADATA_FN.VALUES_QUANTILES)


# Shared building-block expressions for the aggregations below.
VAL = pl.col("numeric_value")
# A value "counts" only when it is neither null nor NaN.
VAL_PRESENT: pl.Expr = VAL.is_not_null() & VAL.is_not_nan()
# A float is treated as integral when rounding is a no-op.
IS_INT: pl.Expr = VAL.round() == VAL
PRESENT_VALS = VAL.filter(VAL_PRESENT)

# Maps each supported aggregation to its per-shard mapper expression and its cross-shard reducer.
# NOTE(review): this span previously contained stale pre-image lines from a merged diff
# (duplicate dict keys for VALUES_N_OCCURRENCES, VALUES_SUM, VALUES_SUM_SQD, VALUES_MIN,
# VALUES_MAX); only the post-merge entries are kept here.
CODE_METADATA_AGGREGATIONS: dict[METADATA_FN, MapReducePair] = {
    METADATA_FN.CODE_N_PATIENTS: MapReducePair(pl.col("patient_id").n_unique(), pl.sum_horizontal),
    METADATA_FN.CODE_N_OCCURRENCES: MapReducePair(pl.len(), pl.sum_horizontal),
    METADATA_FN.VALUES_N_PATIENTS: MapReducePair(
        pl.col("patient_id").filter(VAL_PRESENT).n_unique(), pl.sum_horizontal
    ),
    METADATA_FN.VALUES_N_OCCURRENCES: MapReducePair(PRESENT_VALS.len(), pl.sum_horizontal),
    METADATA_FN.VALUES_N_INTS: MapReducePair(VAL.filter(VAL_PRESENT & IS_INT).len(), pl.sum_horizontal),
    METADATA_FN.VALUES_SUM: MapReducePair(PRESENT_VALS.sum(), pl.sum_horizontal),
    METADATA_FN.VALUES_SUM_SQD: MapReducePair((PRESENT_VALS**2).sum(), pl.sum_horizontal),
    METADATA_FN.VALUES_MIN: MapReducePair(PRESENT_VALS.min(), pl.min_horizontal),
    METADATA_FN.VALUES_MAX: MapReducePair(PRESENT_VALS.max(), pl.max_horizontal),
    # Quantiles: the mapper collects the raw present values per shard; quantile_reducer pools
    # them across shards and takes the configured quantiles.
    METADATA_FN.VALUES_QUANTILES: MapReducePair(PRESENT_VALS, quantile_reducer),
}


Expand Down Expand Up @@ -148,8 +203,8 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[
...
ValueError: Metadata aggregation function INVALID not found in METADATA_FN enumeration. Values are:
code/n_patients, code/n_occurrences, values/n_patients, values/n_occurrences, values/n_ints,
values/sum, values/sum_sqd, values/min, values/max
>>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]})
values/sum, values/sum_sqd, values/min, values/max, values/quantiles
>>> valid_cfg = DictConfig({"aggregations": ["code/n_patients", {"name": "values/n_ints"}]})
>>> validate_args_and_get_code_cols(valid_cfg, 33)
Traceback (most recent call last):
...
Expand All @@ -171,6 +226,8 @@ def validate_args_and_get_code_cols(stage_cfg: DictConfig, code_modifiers: list[

aggregations = stage_cfg.aggregations
for agg in aggregations:
if isinstance(agg, (dict, DictConfig)):
agg = agg.get("name", None)
if agg not in METADATA_FN:
raise ValueError(
f"Metadata aggregation function {agg} not found in METADATA_FN enumeration. Values are: "
Expand Down Expand Up @@ -348,12 +405,30 @@ def mapper_fntr(
│ C ┆ 1 ┆ 81.25 ┆ 5.0 ┆ 7.5 │
│ D ┆ null ┆ 0.0 ┆ null ┆ null │
└──────┴───────────┴────────────────┴────────────┴────────────┘
>>> stage_cfg = DictConfig({"aggregations": ["values/quantiles"]})
>>> mapper = mapper_fntr(stage_cfg, code_modifiers)
>>> mapper(df.lazy()).collect().select("code", "modifier1", pl.col("values/quantiles"))
shape: (5, 3)
┌──────┬───────────┬──────────────────┐
│ code ┆ modifier1 ┆ values/quantiles │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ list[f64] │
╞══════╪═══════════╪══════════════════╡
│ A ┆ 1 ┆ [1.1, 1.1] │
│ A ┆ 2 ┆ [6.0] │
│ B ┆ 2 ┆ [2.0, 4.0] │
│ C ┆ 1 ┆ [5.0, 7.5] │
│ D ┆ null ┆ [] │
└──────┴───────────┴──────────────────┘
"""

code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifiers)
aggregations = stage_cfg.aggregations

agg_operations = {agg: CODE_METADATA_AGGREGATIONS[agg].mapper for agg in aggregations}
agg_operations = {}
for agg in aggregations:
agg_name = agg if isinstance(agg, str) else agg["name"]
agg_operations[agg_name] = CODE_METADATA_AGGREGATIONS[agg_name].mapper

def by_code_mapper(df: pl.LazyFrame) -> pl.LazyFrame:
return df.group_by(code_key_columns).agg(**agg_operations).sort(code_key_columns)
Expand Down Expand Up @@ -409,6 +484,7 @@ def reducer_fntr(
... "values/sum_sqd": [21.3, 2.42, 36.0, 84.0, 81.25],
... "values/min": [-1, 0, -1, 2, 2],
... "values/max": [8.0, 1.1, 6.0, 8.0, 7.5],
... "values/quantiles": [[1.1, 1.1], [6.0], [6.0], [5.0, 7.5], []],
... })
>>> df_2 = pl.DataFrame({
... "code": ["A", "A", "B", "C"],
Expand All @@ -422,6 +498,7 @@ def reducer_fntr(
... "values/sum_sqd": [0., 103.2, 84.0, 81.25],
... "values/min": [None, -1., 0.2, -2.],
... "values/max": [None, 6.2, 1.0, 1.5],
... "values/quantiles": [[1.3, -1.1, 2.0], [6.0, 1.2], [3.0, 2.5], [11.1, 12.]],
... })
>>> df_3 = pl.DataFrame({
... "code": ["D"],
Expand All @@ -435,6 +512,7 @@ def reducer_fntr(
... "values/sum_sqd": [4],
... "values/min": [0],
... "values/max": [2],
... "values/quantiles": [[]],
... })
>>> code_modifiers = ["modifier1"]
>>> stage_cfg = DictConfig({"aggregations": ["code/n_patients", "values/n_ints"]})
Expand Down Expand Up @@ -519,25 +597,56 @@ def reducer_fntr(
Traceback (most recent call last):
...
KeyError: 'Column values/min not found in DataFrame 0 for reduction.'
>>> stage_cfg = DictConfig({
... "aggregations": [{"name": "values/quantiles", "quantiles": [0.25, 0.5, 0.75]}],
... })
>>> reducer = reducer_fntr(stage_cfg, code_modifiers)
>>> reducer(df_1, df_2, df_3).unnest("values/quantiles")
shape: (7, 5)
┌──────┬───────────┬──────────────────────┬─────────────────────┬──────────────────────┐
│ code ┆ modifier1 ┆ values/quantile/0.25 ┆ values/quantile/0.5 ┆ values/quantile/0.75 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 │
╞══════╪═══════════╪══════════════════════╪═════════════════════╪══════════════════════╡
│ null ┆ null ┆ 1.1 ┆ 1.1 ┆ 1.1 │
│ A ┆ 1 ┆ 1.3 ┆ 2.0 ┆ 2.0 │
│ A ┆ 2 ┆ 6.0 ┆ 6.0 ┆ 6.0 │
│ B ┆ 1 ┆ 3.0 ┆ 5.0 ┆ 5.0 │
│ C ┆ null ┆ 11.1 ┆ 12.0 ┆ 12.0 │
│ C ┆ 2 ┆ null ┆ null ┆ null │
│ D ┆ 1 ┆ null ┆ null ┆ null │
└──────┴───────────┴──────────────────────┴─────────────────────┴──────────────────────┘
"""

code_key_columns = validate_args_and_get_code_cols(stage_cfg, code_modifiers)
aggregations = stage_cfg.aggregations

agg_operations = {
agg: CODE_METADATA_AGGREGATIONS[agg].reducer(cs.matches(f"{agg}/shard_\\d+")) for agg in aggregations
}
agg_operations = {}
for agg in aggregations:
if isinstance(agg, (dict, DictConfig)):
agg_name = agg["name"]
agg_kwargs = {k: v for k, v in agg.items() if k != "name"}
else:
agg_name = agg
agg_kwargs = {}
agg_operations[agg_name] = (
CODE_METADATA_AGGREGATIONS[agg_name]
.reducer(cs.matches(f"{agg_name}/shard_\\d+"), **agg_kwargs)
.over(*code_key_columns)
)

def reducer(*dfs: Sequence[pl.LazyFrame]) -> pl.LazyFrame:
renamed_dfs = []
for i, df in enumerate(dfs):
agg_selectors = []
for agg in aggregations:
if isinstance(agg, (dict, DictConfig)):
agg = agg["name"]
if agg not in df.columns:
raise KeyError(f"Column {agg} not found in DataFrame {i} for reduction.")
agg_selectors.append(pl.col(agg).alias(f"{agg}/shard_{i}"))

renamed_dfs.append(
df.select(*code_key_columns, *[pl.col(agg).alias(f"{agg}/shard_{i}") for agg in aggregations])
)
renamed_dfs.append(df.select(*code_key_columns, *agg_selectors))

df = renamed_dfs[0]
for rdf in renamed_dfs[1:]:
Expand Down

0 comments on commit fb0c42d

Please sign in to comment.