Skip to content

Commit

Permalink
align summary_method -> aggregation_method
Browse files Browse the repository at this point in the history
  • Loading branch information
diehlbw committed Jul 12, 2024
1 parent 263d5de commit 8214a56
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 25 deletions.
32 changes: 18 additions & 14 deletions src/seismometer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def fairness_audit(metric_list: Optional[list[str]] = None, fairness_threshold=1
sg.entity_keys,
score=sg.output,
ref_event=sg.predict_time,
summary_method=sg.event_aggregation_method(sg.target),
aggregation_method=sg.event_aggregation_method(sg.target),
)[[sg.target, sg.output] + sensitive_groups]

display_fairness_audit(
Expand Down Expand Up @@ -442,7 +442,7 @@ def _plot_leadtime_enc(
entity_keys,
score=score,
ref_event=target_zero,
summary_method="first",
aggregation_method="first",
)[[target_zero, ref_time, cohort_col]]

# filter by group size
Expand Down Expand Up @@ -508,7 +508,7 @@ def _plot_cohort_evaluation(
subgroups: list[str],
censor_threshold: int = 10,
per_context_id: bool = False,
summary_method: str = "max",
aggregation_method: str = "max",
ref_time: str = None,
) -> HTML:
"""
Expand All @@ -533,20 +533,22 @@ def _plot_cohort_evaluation(
censor_threshold : int
minimum rows to allow in a plot, by default 10
per_context_id : bool, optional
if true, summarize scores for each context, by default False
summary_method : str, optional
method to summarize multiple scores into a single value before calculation of performance, by default "max"
if true, aggregate scores for each context, by default False
aggregation_method : str, optional
method to reduce multiple scores into a single value before calculation of performance, by default "max"
ignored if per_context_id is False
ref_time : str, optional
reference time column used for summarization when per_context_id is True and summary_method is time-based
reference time column used for aggregation when per_context_id is True and aggregation_method is time-based
Returns
-------
HTML
Plot of cohort evaluation metrics
"""
data = (
pdh.event_score(dataframe, entity_keys, score=output, ref_event=ref_time, summary_method=summary_method)
pdh.event_score(
dataframe, entity_keys, score=output, ref_event=ref_time, aggregation_method=aggregation_method
)
if per_context_id
else dataframe
)
Expand Down Expand Up @@ -598,7 +600,7 @@ def _model_evaluation(
output: str,
thresholds: Optional[list[float]],
per_context_id: bool = False,
summary_method: str = "max",
aggregation_method: str = "max",
ref_time: str = None,
) -> HTML:
"""
Expand All @@ -620,19 +622,21 @@ def _model_evaluation(
model thresholds
per_context_id : bool, optional
report only the max score for a given entity context, by default False
summary_method : str, optional
method to summarize multiple scores into a single value before calculation of performance, by default "max"
aggregation_method : str, optional
method to reduce multiple scores into a single value before calculation of performance, by default "max"
ignored if per_context_id is False
ref_time : str, optional
reference time column used for summarization when per_context_id is True and summary_method is time-based
reference time column used for aggregation when per_context_id is True and aggregation_method is time-based
Returns
-------
HTML
Plot of model evaluation metrics
"""
data = (
pdh.event_score(dataframe, entity_keys, score=output, ref_event=ref_time, summary_method=summary_method)
pdh.event_score(
dataframe, entity_keys, score=output, ref_event=ref_time, aggregation_method=aggregation_method
)
if per_context_id
else dataframe
)
Expand Down Expand Up @@ -912,7 +916,7 @@ def show_info(plot_help: bool = False):

def _style_cohort_summaries(df: pd.DataFrame, attribute: str) -> Styler:
"""
Adds required styling to a cohort summary dataframe.
Adds required styling to a cohort dataframe.
Parameters
----------
Expand Down
14 changes: 7 additions & 7 deletions src/seismometer/data/pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def event_score(
pks: list[str],
score: str,
ref_event: Optional[str] = None,
summary_method: str = "max",
aggregation_method: str = "max",
) -> pd.DataFrame:
"""
Reduces a dataframe of all predictions to a single row of significance; such as the max or most recent value for
Expand All @@ -180,40 +180,40 @@ def event_score(
The column name containing the score value.
ref_event : Optional[str], optional
The column name containing the time to consider, by default None.
summary_method : str, optional
aggregation_method : str, optional
A string describing the method to select a value, by default 'max'.
Returns
-------
pd.DataFrame
The reduced dataframe with one row per combination of pks.
"""
logger.debug(f"Combining scores using {summary_method} for {score} on {ref_event}")
logger.debug(f"Combining scores using {aggregation_method} for {score} on {ref_event}")
# groupby.agg works on columns individually - this wants entire row where a condition is met
# start with first/last/max/min

ref_score = _resolve_score_col(merged_frame, score)
if summary_method == "max":
if aggregation_method == "max":
ref_col = ref_score

def apply_fn(gf):
return gf.idxmax()

elif summary_method == "min":
elif aggregation_method == "min":
ref_col = ref_score

def apply_fn(gf):
return gf.idxmin()

# merged frame has time columns only for events in appropriate time window,
# implicitly reduces to positive label (need method to re-add negative samples)
elif summary_method == "last":
elif aggregation_method == "last":

def apply_fn(gf):
return gf.idxmax()

ref_col = _resolve_time_col(merged_frame, ref_event)
elif summary_method == "first":
elif aggregation_method == "first":

def apply_fn(gf):
return gf.idxmin()
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ def event_data(res):

class Test_Event_Score:
@pytest.mark.parametrize("ref_event", ["Target", "PredictTime", "Reference_5_15_Time"])
@pytest.mark.parametrize("summary_method", ["min", "max", "first", "last"])
@pytest.mark.parametrize("aggregation_method", ["min", "max", "first", "last"])
@pytest.mark.parametrize("id_, csn", [pytest.param(1, 0, id="monotonic-increasing")])
def test_bad_event(self, id_, csn, summary_method, ref_event, event_data):
def test_bad_event(self, id_, csn, aggregation_method, ref_event, event_data):
input_frame, expected_frame = event_data
expected_score = expected_frame.loc[
(expected_frame["Id"] == id_)
& (expected_frame["CSN"] == csn)
& (expected_frame["ref_event"] == ref_event),
summary_method,
aggregation_method,
]

actual = undertest.event_score(input_frame, ["Id", "CSN"], "ModelScore", ref_event, summary_method)
actual = undertest.event_score(input_frame, ["Id", "CSN"], "ModelScore", ref_event, aggregation_method)
assert actual["ModelScore"].tolist() == expected_score.tolist()

0 comments on commit 8214a56

Please sign in to comment.