Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean row aggregation to HELM summarize #2997

Merged
merged 25 commits into from
Sep 25, 2024
Merged
13 changes: 12 additions & 1 deletion src/helm/benchmark/presentation/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import ast
import dataclasses
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from enum import IntEnum
from typing import List, Optional, Dict, Union
import dacite
from inspect import cleandoc
import mako.template
Expand Down Expand Up @@ -108,6 +109,14 @@ def substitute(self, environment: Dict[str, str]) -> "MetricNameMatcher":
)


class AggregationStrategy(IntEnum):
    """How to aggregate a metric group's columns into summary columns.

    Encoded as a two-bit flag so values compose:
    bit 0 (value 1) = include mean win rate (MWR) column,
    bit 1 (value 2) = include mean column.

    Note: the previous ``@dataclass(frozen=True)`` decorator was removed —
    dataclass machinery does not apply to Enum classes (enum members are
    already immutable singletons), so it was misleading and a no-op at best.
    """

    USE_NONE = 0  # no aggregate column
    USE_MWR = 1  # mean win rate only
    USE_MEAN = 2  # mean only
    USE_BOTH = 3  # both mean win rate and mean (USE_MWR | USE_MEAN)


@dataclass(frozen=True)
class MetricGroup(Field):
"""
Expand All @@ -119,6 +128,8 @@ class MetricGroup(Field):
hide_win_rates: Optional[bool] = None
"""If set to true, do not compute win rates."""

aggregation_strategy: Optional[Union[AggregationStrategy, int]] = 1
farzaank marked this conversation as resolved.
Show resolved Hide resolved


BY_METRIC = "by_metric"
BY_GROUP = "by_group"
Expand Down
54 changes: 53 additions & 1 deletion src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,32 @@ def compute_aggregate_row_win_rates(table: Table, aggregation: str = "mean") ->
return aggregate_win_rates


def compute_aggregate_row_means(table: "Table") -> List[Optional[float]]:
    """Compute the mean of each row's numeric cell values across all columns.

    Returns one entry per row: the mean of the row's non-null numeric cell
    values, or None if the row contains no numeric values at all.

    Fixes over the previous version:
    - Cells holding a legitimate zero are now included in the mean (the old
      truthiness check ``if cell.value:`` silently dropped 0 and 0.0,
      biasing the average upward).
    - Non-numeric cells are skipped via a narrow ``(TypeError, ValueError)``
      handler instead of a broad ``except Exception`` that printed noise.
    """
    means_per_row: List[Optional[float]] = []
    for row in table.rows:
        total: float = 0.0
        count: int = 0
        for cell in row:
            # Skip empty cells, but keep zeros: use an explicit None check,
            # not truthiness.
            if cell.value is None:
                continue
            try:
                numeric = float(cell.value)
            except (TypeError, ValueError):
                # Best-effort: cells with non-numeric content (e.g. string
                # annotations) simply do not contribute to the mean.
                continue
            total += numeric
            count += 1
        means_per_row.append(total / count if count else None)
    return means_per_row


AGGREGATE_WIN_RATE_COLUMN = 1


Expand Down Expand Up @@ -881,6 +907,7 @@ def create_group_table(
sub_split: Optional[str] = None,
bold_columns: bool = True,
add_win_rate: bool = False,
selected_agg_strat: int = 0,
) -> Table:
"""
Create a table for where each row is an adapter (for which we have a set of runs) and columns are pairs of
Expand Down Expand Up @@ -1063,7 +1090,14 @@ def _adapter_spec_sort_key(spec):

table = Table(title=title, header=header, rows=rows, links=links, name=name)

if add_win_rate:
add_mean_col = (
selected_agg_strat >= 2
) # values 2 or 3 indicate we should include mean (see AggregationStrategy enum)
add_mwr = (
selected_agg_strat % 2 != 0 or add_win_rate
) # values 1 or 3 say to include mwr (see AggregationStrategy enum)

if add_mwr:
# add overall win rate as the second column
WIN_RATE_AGGREGATION = "mean"
win_rates = compute_aggregate_row_win_rates(table, aggregation=WIN_RATE_AGGREGATION)
Expand All @@ -1078,6 +1112,22 @@ def _adapter_spec_sort_key(spec):
)
for row, win_rate in zip(table.rows, win_rates):
row.insert(AGGREGATE_WIN_RATE_COLUMN, Cell(win_rate))
if add_mean_col:
means = compute_aggregate_row_means(table)
description = "An average over columns representing the mean performance"
insertion_column = AGGREGATE_WIN_RATE_COLUMN
if add_mwr:
insertion_column += 1
table.header.insert(
insertion_column,
HeaderCell(
"Mean Performance",
description=description,
lower_is_better=False,
),
)
for row, row_mean in zip(table.rows, means):
row.insert(insertion_column, Cell(row_mean))

if bold_columns:
for i, header_cell in enumerate(table.header):
Expand Down Expand Up @@ -1126,13 +1176,15 @@ def create_group_tables_by_metric_group(self, group: RunGroup) -> List[Table]:
if len(adapter_to_runs) > 0:
for metric_group in all_metric_groups:
display_name = self.schema.name_to_metric_group[metric_group].get_short_display_name()
agg_strat = self.schema.name_to_metric_group[metric_group].aggregation_strategy or 1
table = self.create_group_table(
name=metric_group,
title=display_name,
adapter_to_runs=adapter_to_runs,
columns=[(subgroup, metric_group) for subgroup in subgroups],
is_scenario_table=False,
add_win_rate=not self.schema.name_to_metric_group[metric_group].hide_win_rates,
selected_agg_strat=int(agg_strat),
)
tables.append(table)
return tables
Expand Down
1 change: 1 addition & 0 deletions src/helm/benchmark/static/schema_safety.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ perturbations: []
metric_groups:
- name: accuracy
display_name: Accuracy
aggregation_strategy: 3
farzaank marked this conversation as resolved.
Show resolved Hide resolved
metrics:
- name: ${main_name}
split: ${main_split}
Expand Down